From ddcc822e37035685b5dd44c9418c512e399e4756 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=20Juan=20Pe=C3=B1as?= Date: Thu, 16 Jun 2022 17:46:28 +0200 Subject: [PATCH 001/526] feature: added get feature group by session or role --- src/sagemaker/__init__.py | 2 + src/sagemaker/feature_group_utils.py | 97 ++++++++++++++ .../feature_store/feature_definition.py | 2 +- src/sagemaker/feature_store/feature_group.py | 11 +- src/sagemaker/feature_store/inputs.py | 2 +- src/sagemaker/utils.py | 85 ++++++++---- tests/integ/test_feature_store.py | 125 ++++++++++++++---- tests/integ/test_model_monitor.py | 4 +- tests/integ/test_monitoring_files.py | 12 +- 9 files changed, 272 insertions(+), 68 deletions(-) create mode 100644 src/sagemaker/feature_group_utils.py diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 7e45d7c721..12fd7051f1 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -62,4 +62,6 @@ from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput # noqa: F401 from sagemaker.automl.candidate_estimator import CandidateEstimator, CandidateStep # noqa: F401 +from sagemaker.feature_group_utils import get_feature_group_as_dataframe + __version__ = importlib_metadata.version("sagemaker") diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py new file mode 100644 index 0000000000..ea81a925fb --- /dev/null +++ b/src/sagemaker/feature_group_utils.py @@ -0,0 +1,97 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Utilities for working with FeatureGroup and FeatureStores. + + +""" + +import re +import logging + +from pandas import DataFrame + +from sagemaker.feature_store.feature_group import FeatureGroup +from sagemaker.utils import get_session_from_role + +logger = logging.getLogger(__name__) + + +def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, + query: str = str('SELECT * FROM "sagemaker_featurestore"."#{table}" WHERE ' + + 'is_deleted=False'), + role: str = None, region: str = None, session=None, + event_time_feature_name: str = None, latest_ingestion: bool = True, + logger_level: int = logging.INFO, + **pandas_read_csv_kwargs) -> DataFrame: + """ + Description: + Method to run an athena query over a Feature Group in a Feature Store to retrieve its data. + It needs the sagemaker.Session linked to a role or the role and region used to work Feature Stores. + Returns a dataframe with the data. + + Args: + region (str): region of the target feature store + feature_group_name (str): feature store name + query (str): query to run. By default, it will take the latest ingest with data that wasn't deleted. + If latest_ingestion is False it will take all the data in the feature group that wasn't + deleted. + athena_bucket (str): S3 bucket for running the query + role (str): role of the account used to extract data from feature store + session (str): session of SageMaker used to work with the feature store + event_time_feature_name (str): eventTimeId feature. 
Mandatory only if the latest ingestion is True + latest_ingestion (bool): if True it will get the data only from the latest ingestion. If False it + will take whatever is specified in the query, or if not specify it, it will + get all the data that wasn't deleted. + logger_level (int): logger level used by lib logging. + + Returns: + dataset (pandas.DataFrame): dataset with the data retrieved from feature group + """ + logger.setLevel(logger_level) + + if latest_ingestion: + if event_time_feature_name is not None: + query += str(f'AND {event_time_feature_name}=(SELECT MAX({event_time_feature_name}) FROM ' + + f'"sagemaker_featurestore"."{feature_group_name}")') + query += ';' + + if session is not None: + sagemaker_session = session + elif role is not None and region is not None: + sagemaker_session = get_session_from_role(role=role, region=region) + else: + exc = Exception('Argument Session or role and region must be specified.') + logger.exception(exc) + raise exc + + logger.info(f'Feature Group used: {feature_group_name}\n') + + fg = FeatureGroup(name=feature_group_name, + sagemaker_session=sagemaker_session) + + sample_query = fg.athena_query() + query_string = re.sub(r'#\{(table)\}', sample_query.table_name, query) + + logger.info(f"Running query:\n\t{sample_query} \n\n\t-> Save on bucket {athena_bucket}\n") + + sample_query.run(query_string=query_string, + output_location=athena_bucket) + + sample_query.wait() + + # run Athena query. The output is loaded to a Pandas dataframe. + dataset = sample_query.as_dataframe(**pandas_read_csv_kwargs) + + logger.info(f'Data shape retrieve from {feature_group_name}: {dataset.shape}') + + return dataset \ No newline at end of file diff --git a/src/sagemaker/feature_store/feature_definition.py b/src/sagemaker/feature_store/feature_definition.py index 5a91a9c512..b7c9aacda4 100644 --- a/src/sagemaker/feature_store/feature_definition.py +++ b/src/sagemaker/feature_store/feature_definition.py @@ -46,7 +46,7 @@ class FeatureDefinition(Config): This instantiates a Feature Definition object where FeatureDefinition is a subclass of Config. Attributes: - feature_name (str): The name of the feature + feature_group_name (str): The name of the feature feature_type (FeatureTypeEnum): The type of the feature """ diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 41bdcd764c..654f07e97a 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -123,9 +123,13 @@ def get_query_execution(self) -> Dict[str, Any]: query_execution_id=self._current_query_execution_id ) - def as_dataframe(self) -> DataFrame: + def as_dataframe(self, **pandas_read_csv_kwargs) -> DataFrame: """Download the result of the current query and load it into a DataFrame. + Args: + pandas_read_csv_kwargs: key arguments used for the method pandas.read_csv to be able to have a better + tuning on data. For more info about this methods visit: + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html Returns: A pandas DataFrame contains the query result. 
""" @@ -146,7 +150,10 @@ def as_dataframe(self) -> DataFrame: query_execution_id=self._current_query_execution_id, filename=output_filename, ) - return pd.read_csv(output_filename, delimiter=",") + + # Assuring delimiter used by default + pandas_read_csv_kwargs.pop('delimiter', None) + return pd.read_csv(output_filename, delimiter=",", **pandas_read_csv_kwargs) @attr.s diff --git a/src/sagemaker/feature_store/inputs.py b/src/sagemaker/feature_store/inputs.py index 1f31caa4ae..b9c509b48e 100644 --- a/src/sagemaker/feature_store/inputs.py +++ b/src/sagemaker/feature_store/inputs.py @@ -190,7 +190,7 @@ class FeatureValue(Config): """FeatureValue for FeatureStore. Attributes: - feature_name (str): name of the Feature. + feature_group_name (str): name of the Feature. value_as_string (str): value of the Feature in string form. """ diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 1d2e9fe5cb..97c6b05a90 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -29,12 +29,12 @@ from datetime import datetime import botocore +import boto3 from six.moves.urllib import parse -from sagemaker import deprecations +from sagemaker import deprecations, Session from sagemaker.session_settings import SessionSettings - ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" MAX_BUCKET_PATHS_COUNT = 5 S3_PREFIX = "s3://" @@ -83,7 +83,7 @@ def name_from_base(base, max_length=63, short=False): def unique_name_from_base(base, max_length=63): """Placeholder Docstring""" random.seed(int(uuid.uuid4())) # using uuid to randomize, otherwise system timestamp is used. - unique = "%04x" % random.randrange(16**4) # 4-digit hex + unique = "%04x" % random.randrange(16 ** 4) # 4-digit hex ts = str(int(time.time())) available_length = max_length - 2 - len(ts) - len(unique) trimmed = base[:available_length] @@ -187,8 +187,8 @@ def secondary_training_status_changed(current_job_description, prev_job_descript """ current_secondary_status_transitions = current_job_description.get("SecondaryStatusTransitions") if ( - current_secondary_status_transitions is None - or len(current_secondary_status_transitions) == 0 + current_secondary_status_transitions is None + or len(current_secondary_status_transitions) == 0 ): return False @@ -201,7 +201,7 @@ def secondary_training_status_changed(current_job_description, prev_job_descript last_message = ( prev_job_secondary_status_transitions[-1]["StatusMessage"] if prev_job_secondary_status_transitions is not None - and len(prev_job_secondary_status_transitions) > 0 + and len(prev_job_secondary_status_transitions) > 0 else "" ) @@ -222,9 +222,9 @@ def secondary_training_status_message(job_description, prev_description): """ if ( - job_description is None - or job_description.get("SecondaryStatusTransitions") is None - or len(job_description.get("SecondaryStatusTransitions")) == 0 + job_description is None + or job_description.get("SecondaryStatusTransitions") is None + or len(job_description.get("SecondaryStatusTransitions")) == 0 ): return "" @@ -244,8 +244,8 @@ def secondary_training_status_message(job_description, prev_description): else: # Secondary status is changed we need to print all the entries. 
transitions_to_print = current_transitions[ - prev_transitions_num - len(current_transitions) : - ] + prev_transitions_num - len(current_transitions): + ] status_strs = [] for transition in transitions_to_print: @@ -308,7 +308,7 @@ def _download_files_under_prefix(bucket_name, prefix, target, s3): if obj_sum.key.endswith("/"): continue obj = s3.Object(obj_sum.bucket_name, obj_sum.key) - s3_relative_path = obj_sum.key[len(prefix) :].lstrip("/") + s3_relative_path = obj_sum.key[len(prefix):].lstrip("/") file_path = os.path.join(target, s3_relative_path) try: @@ -365,13 +365,13 @@ def _tmpdir(suffix="", prefix="tmp"): def repack_model( - inference_script, - source_directory, - dependencies, - model_uri, - repacked_model_uri, - sagemaker_session, - kms_key=None, + inference_script, + source_directory, + dependencies, + model_uri, + repacked_model_uri, + sagemaker_session, + kms_key=None, ): """Unpack model tarball and creates a new model tarball with the provided code script. @@ -458,7 +458,7 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key): def _create_or_update_code_dir( - model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp + model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp ): """Placeholder docstring""" code_dir = os.path.join(model_dir, "code") @@ -554,9 +554,9 @@ def sts_regional_endpoint(region): def retries( - max_retry_count, - exception_message_prefix, - seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS, + max_retry_count, + exception_message_prefix, + seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS, ): """Retries until max retry count is reached. @@ -653,6 +653,37 @@ def _module_import_error(py_module, feature, extras): return error_msg.format(py_module, feature, extras) +def get_session_from_role(role: str, region: str): + boto_session = boto3.Session(region_name=region) + + sts = boto_session.client('sts', + region_name=region, + endpoint_url='https://sts.eu-west-1.amazonaws.com') + + metadata = sts.assume_role(RoleArn=role, + RoleSessionName='SagemakerExecution') + + access_key_id = metadata['Credentials']['AccessKeyId'] + secret_access_key = metadata['Credentials']['SecretAccessKey'] + session_token = metadata['Credentials']['SessionToken'] + + boto_session = boto3.session.Session(region_name=region, + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + aws_session_token=session_token) + + # Sessions + sagemaker_client = boto_session.client('sagemaker') + sagemaker_runtime = boto_session.client('sagemaker-runtime') + sagemaker_featurestore_runtime_client = boto_session.client(service_name='sagemaker-featurestore-runtime') + sagemaker_session = Session(boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=sagemaker_runtime, + sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client) + + return sagemaker_session + + class DataConfig(abc.ABC): """Abstract base class for accessing data config hosted in AWS resources. @@ -672,10 +703,10 @@ class S3DataConfig(DataConfig): """This class extends the DataConfig class to fetch a data config file hosted on S3""" def __init__( - self, - sagemaker_session, - bucket_name, - prefix, + self, + sagemaker_session, + bucket_name, + prefix, ): """Initialize a ``S3DataConfig`` instance. 
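For reference, a minimal sketch of how the helpers added above are intended to be used together (get_session_from_role from sagemaker.utils and get_feature_group_as_dataframe from sagemaker.feature_group_utils); the role ARN, region, bucket, and feature group name are placeholders rather than values taken from this patch:

    from sagemaker.feature_group_utils import get_feature_group_as_dataframe
    from sagemaker.utils import get_session_from_role

    # Assume a role with access to the Feature Store (placeholder ARN and region).
    sagemaker_session = get_session_from_role(
        role="arn:aws:iam::111122223333:role/ExampleSageMakerRole",
        region="eu-west-1",
    )

    # Query the feature group's offline store through Athena; with
    # latest_ingestion=True the event-time feature name is required so the
    # query can keep only the most recent, non-deleted ingestion.
    df = get_feature_group_as_dataframe(
        feature_group_name="example-feature-group",
        athena_bucket="s3://example-bucket/athena-query-results",
        session=sagemaker_session,
        event_time_feature_name="data_as_of_date",
        latest_ingestion=True,
    )
    print(df.shape)
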
diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index 15c1db41ab..0d8da861e3 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -26,6 +26,8 @@ from sagemaker.feature_store.inputs import FeatureValue from sagemaker.session import get_execution_role, Session from tests.integ.timeout import timeout +from sagemaker.feature_group_utils import get_feature_group_as_dataframe +from sagemaker.utils import get_session_from_role BUCKET_POLICY = { "Version": "2012-10-17", @@ -76,7 +78,7 @@ def feature_store_session(): @pytest.fixture def feature_group_name(): - return f"my-feature-group-{int(time.time() * 10**7)}" + return f"my-feature-group-{int(time.time() * 10 ** 7)}" @pytest.fixture @@ -147,10 +149,10 @@ def create_table_ddl(): def test_create_feature_store_online_only( - feature_store_session, - role, - feature_group_name, - pandas_data_frame, + feature_store_session, + role, + feature_group_name, + pandas_data_frame, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -169,13 +171,13 @@ def test_create_feature_store_online_only( def test_create_feature_store( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame, - record, - create_table_ddl, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, + record, + create_table_ddl, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -226,23 +228,23 @@ def test_create_feature_store( for is_na in nans.items(): assert is_na assert ( - create_table_ddl.format( - feature_group_name=feature_group_name, - region=feature_store_session.boto_session.region_name, - account=feature_store_session.account_id(), - resolved_output_s3_uri=resolved_output_s3_uri, - ) - == feature_group.as_hive_ddl() + create_table_ddl.format( + feature_group_name=feature_group_name, + region=feature_store_session.boto_session.region_name, + account=feature_store_session.account_id(), + resolved_output_s3_uri=resolved_output_s3_uri, + ) + == feature_group.as_hive_ddl() ) assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}") def test_ingest_without_string_feature( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame_without_string, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame_without_string, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame_without_string) @@ -266,11 +268,11 @@ def test_ingest_without_string_feature( def test_ingest_multi_process( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -304,6 +306,71 @@ def _wait_for_feature_group_create(feature_group: FeatureGroup): print(f"FeatureGroup {feature_group.name} successfully created.") +def test_get_feature_group_with_role_region( + feature_store_session, + role, + feature_group_name, + 
offline_store_s3_uri, + pandas_data_frame, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + with cleanup_feature_group(feature_group): + output = feature_group.create( + s3_uri=offline_store_s3_uri, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + + feature_group.ingest( + data_frame=pandas_data_frame, max_workers=3, max_processes=2, wait=True + ) + + dataset = get_feature_group_as_dataframe(feature_group_name=feature_group_name, + region=region_name, role=role, + event_time_feature_name="feature3", + latest_ingestion=True, + athena_bucket=f'{offline_store_s3_uri}/query') + + assert dataset.empty == False + +def test_get_feature_group_with_session( + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + with cleanup_feature_group(feature_group): + output = feature_group.create( + s3_uri=offline_store_s3_uri, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + + feature_group.ingest( + data_frame=pandas_data_frame, max_workers=3, max_processes=2, wait=True + ) + + dataset = get_feature_group_as_dataframe(feature_group_name=feature_group_name, + session=feature_store_session, + event_time_feature_name="feature3", + latest_ingestion=True, + athena_bucket=f'{offline_store_s3_uri}/query') + + assert dataset.empty == False + + @contextmanager def cleanup_feature_group(feature_group: FeatureGroup): try: diff --git a/tests/integ/test_model_monitor.py b/tests/integ/test_model_monitor.py index 4b6d3a39ae..a051f11d34 100644 --- a/tests/integ/test_model_monitor.py +++ b/tests/integ/test_model_monitor.py @@ -842,7 +842,7 @@ def test_default_monitor_monitoring_execution_interactions( ) constraint_violations = my_attached_monitor.latest_monitoring_constraint_violations() - assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" @pytest.mark.skipif( @@ -1406,7 +1406,7 @@ def test_byoc_monitor_monitoring_execution_interactions( ) constraint_violations = my_attached_monitor.latest_monitoring_constraint_violations() - assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" def _wait_for_schedule_changes_to_apply(monitor): diff --git a/tests/integ/test_monitoring_files.py b/tests/integ/test_monitoring_files.py index 08dffd99c9..ff164f5587 100644 --- a/tests/integ/test_monitoring_files.py +++ b/tests/integ/test_monitoring_files.py @@ -315,7 +315,7 @@ def test_constraint_violations_object_creation_from_file_path_with_customization assert constraint_violations.file_s3_uri.startswith("s3://") assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == 
"store_and_fwd_flag" def test_constraint_violations_object_creation_from_file_path_without_customizations( @@ -331,7 +331,7 @@ def test_constraint_violations_object_creation_from_file_path_without_customizat assert constraint_violations.file_s3_uri.startswith("s3://") assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" def test_constraint_violations_object_creation_from_string_with_customizations( @@ -350,7 +350,7 @@ def test_constraint_violations_object_creation_from_string_with_customizations( assert constraint_violations.file_s3_uri.startswith("s3://") assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" def test_constraint_violations_object_creation_from_string_without_customizations( @@ -366,7 +366,7 @@ def test_constraint_violations_object_creation_from_string_without_customization assert constraint_violations.file_s3_uri.startswith("s3://") assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" def test_constraint_violations_object_creation_from_s3_uri_with_customizations( @@ -400,7 +400,7 @@ def test_constraint_violations_object_creation_from_s3_uri_with_customizations( assert constraint_violations.file_s3_uri.startswith("s3://") assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" def test_constraint_violations_object_creation_from_s3_uri_without_customizations( @@ -429,4 +429,4 @@ def test_constraint_violations_object_creation_from_s3_uri_without_customization assert constraint_violations.file_s3_uri.startswith("s3://") assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" From 5bd33c9a6001abd8ca719c7a5831bea4460dda77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=20Juan=20Pe=C3=B1as?= Date: Thu, 16 Jun 2022 17:49:15 +0200 Subject: [PATCH 002/526] added tests and kwargs for pandas.read_csv --- tests/integ/test_feature_store.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index 0d8da861e3..8e0eb3757d 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -361,12 +361,14 @@ def test_get_feature_group_with_session( feature_group.ingest( data_frame=pandas_data_frame, max_workers=3, max_processes=2, wait=True ) - + dataset = get_feature_group_as_dataframe(feature_group_name=feature_group_name, session=feature_store_session, event_time_feature_name="feature3", latest_ingestion=True, - athena_bucket=f'{offline_store_s3_uri}/query') + 
athena_bucket=f'{offline_store_s3_uri}/query', + low_memory=False) # Using kwargs to pass a parameter to + # pandas.read_csv assert dataset.empty == False From d0c582ada48a0050dfce033cda540cbdbe3e3546 Mon Sep 17 00:00:00 2001 From: JoseJuan98 Date: Thu, 16 Jun 2022 17:56:00 +0200 Subject: [PATCH 003/526] feature: added extra arguments to get feature group as dataframe --- src/sagemaker/feature_store/feature_group.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 41bdcd764c..fcaaec362c 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -123,9 +123,13 @@ def get_query_execution(self) -> Dict[str, Any]: query_execution_id=self._current_query_execution_id ) - def as_dataframe(self) -> DataFrame: + def as_dataframe(self, **kwargs) -> DataFrame: """Download the result of the current query and load it into a DataFrame. + Args: + kwargs: key arguments used for the method pandas.read_csv to be able to have a better tuning on data. + For more info read https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html + Returns: A pandas DataFrame contains the query result. """ @@ -146,7 +150,9 @@ def as_dataframe(self) -> DataFrame: query_execution_id=self._current_query_execution_id, filename=output_filename, ) - return pd.read_csv(output_filename, delimiter=",") + + kwargs.pop('delimiter', None) + return pd.read_csv(filepath_or_buffer=output_filename, delimiter=",", **kwargs) @attr.s From d53f2111b1e3c413159b63c7a1b9603c19d87342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=20Juan=20Pe=C3=B1as?= Date: Tue, 15 Nov 2022 12:27:07 +0100 Subject: [PATCH 004/526] Added more documentation --- src/sagemaker/feature_group_utils.py | 77 +++++++++++++++++++++++----- 1 file changed, 64 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index 5f385bba45..68edcfe8d2 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -10,14 +10,16 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Utilities for working with FeatureGroup and FeatureStores. +""" +Utilities for working with FeatureGroups and FeatureStores. """ - import re import logging import json +from typing import Union +from pathlib import Path import pandas from pandas import DataFrame, Series, read_csv @@ -46,7 +48,8 @@ def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, feature_group_name (str): feature store name query (str): query to run. By default, it will take the latest ingest with data that wasn't deleted. If latest_ingestion is False it will take all the data in the feature group that wasn't - deleted. + deleted. It needs to use the keyword "#{table}" to refer to the table. 
e.g.: + 'SELECT * FROM "sagemaker_featurestore"."#{table}"' athena_bucket (str): S3 bucket for running the query role (str): role of the account used to extract data from feature store session (str): session of SageMaker used to work with the feature store @@ -98,29 +101,72 @@ def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, return dataset -def _format_column_names(data: pandas.DataFrame): + +def _format_column_names(data: pandas.DataFrame) -> pandas.DataFrame: + """ + Module to format correctly the name of the columns of a DataFrame to later generate the features names + of a Feature Group + + Args: + data (pandas.DataFrame): dataframe used + + Returns: + pandas.DataFrame + """ data.rename(columns=lambda x: x.replace(' ', '_').replace('.', '').lower()[:62], inplace=True) return data -def _cast_object_to_string(data_frame: pandas.DataFrame): +def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: + """ + Method to convert 'object' and 'O' column dtypes of a pandas.DataFrame to a valid string type recognized + by Feature Groups. + + Args: + data_frame: dataframe used + + Returns: + pandas.DataFrame + """ for label in data_frame.select_dtypes(['object', 'O']).columns.tolist(): data_frame[label] = data_frame[label].astype("str").astype("string") return data_frame -def get_fg_schema(dataframe_or_path, record_id: str, fg_name: str, role: str, region: str, saving_file_path: str = '', - event_id: str = 'data_as_of_date', mode: str = 'display', - logger_level: int = logging.INFO, - **pandas_read_csv_kwargs): + +def get_fg_schema(dataframe_or_path: Union[str, Path, pandas.DataFrame], + role: str, region: str, + mode: str = 'display', record_id: str = '@index', + event_id: str = 'data_as_of_date', + saving_file_path: str = '', verbose: bool = False, + **pandas_read_csv_kwargs) -> None: """ + Method to generate the schema of a Feature Group from a pandas.DataFrame. It has two modes (`mode`): + - display: the schema is printed on the display + - make_file: it generates a file with the schema inside. Recommended if it has a lot of features. Then + argument `saving_file_path` must be specified. + + Args: + dataframe_or_path (str, Path, pandas.DataFrame) : pandas.DataFrame or path to the data + mode (str) : it changes how the output is displayed or stored, as explained before. By default, + mode='display', and the other mode is `make_file`. + verbose (bool) : True for displaying messages, False for silent method. + record_id (str, '@index'): (Optional) Feature identifier of the rows. If specified each value of that feature + has to be unique. If not specified or record_id='@index', then it will create + a new feature from the index of the pandas.DataFrame. + event_id (str) : (Optional) Feature with the time of the creation of data rows. If not specified it + will create one with the current time called `data_as_of_date` + role (str) : role used to get the session + region (str) : region used to get the session + saving_file_path (str) : required if mode='make_file', file path to save the output. 
Returns: Save text into a file or displays the feature group schema by teh screen """ MODE = ['display', 'make_file'] - logger.setLevel(logger_level) - + logger.setLevel(logging.WARNING) + if verbose: + logger.setLevel(logging.INFO) mode = mode.lower() if mode not in MODE: @@ -141,10 +187,14 @@ def get_fg_schema(dataframe_or_path, record_id: str, fg_name: str, role: str, re logger.exception(exc) raise exc - # Formating cols + # Formatting cols data = _format_column_names(data=data) data = _cast_object_to_string(data_frame=data) + if record_id == '@index': + record_id = 'index' + data[record_id] = data.index + lg_uniq = len(data[record_id].unique()) lg_id = len(data[record_id]) @@ -156,7 +206,7 @@ def get_fg_schema(dataframe_or_path, record_id: str, fg_name: str, role: str, re session = get_session_from_role(role=role, region=region) feature_group = FeatureGroup( - name=fg_name, sagemaker_session=session + name='temporalFG', sagemaker_session=session ) if event_id not in data.columns: @@ -186,6 +236,7 @@ def get_fg_schema(dataframe_or_path, record_id: str, fg_name: str, role: str, re with open(saving_file_path, 'w') as f: f.write(json.dumps(def_list)) f.close() + logger.info('Finished!.') else: exc = Exception(str(f'Parameter saving_file_path mandatory if mode {MODE[1]} is specified.')) logger.exception(exc) From ff8c62a4fa6310313e85cc3fc34bbb3505418187 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=20Juan=20Pe=C3=B1as?= Date: Tue, 15 Nov 2022 12:29:33 +0100 Subject: [PATCH 005/526] Added more documentation --- src/sagemaker/feature_group_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index 68edcfe8d2..6dd8267d34 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -35,7 +35,7 @@ def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, + 'is_deleted=False'), role: str = None, region: str = None, session=None, event_time_feature_name: str = None, latest_ingestion: bool = True, - logger_level: int = logging.INFO, + verbose: bool = True, **pandas_read_csv_kwargs) -> DataFrame: """ Description: @@ -57,12 +57,15 @@ def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, latest_ingestion (bool): if True it will get the data only from the latest ingestion. If False it will take whatever is specified in the query, or if not specify it, it will get all the data that wasn't deleted. - logger_level (int): logger level used by lib logging. + verbose (bool): if True show messages, if False is silent. 
Returns: dataset (pandas.DataFrame): dataset with the data retrieved from feature group """ - logger.setLevel(logger_level) + + logger.setLevel(logging.WARNING) + if verbose: + logger.setLevel(logging.INFO) if latest_ingestion: if event_time_feature_name is not None: From 3cfb6b16eaf749b14cf6ae9af294fdb9f1e1feac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=20Juan=20Pe=C3=B1as?= Date: Fri, 18 Nov 2022 15:37:58 +0100 Subject: [PATCH 006/526] fix: circular import --- src/sagemaker/utils.py | 34 ++-------------------------------- tests/conftest.py | 3 ++- 2 files changed, 4 insertions(+), 33 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 36b28ec9c8..4f48bc0af2 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -33,7 +33,8 @@ import boto3 from six.moves.urllib import parse -from sagemaker import deprecations, Session +from sagemaker import deprecations + from sagemaker.session_settings import SessionSettings from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string @@ -664,37 +665,6 @@ def _module_import_error(py_module, feature, extras): return error_msg.format(py_module, feature, extras) -def get_session_from_role(role: str, region: str): - boto_session = boto3.Session(region_name=region) - - sts = boto_session.client('sts', - region_name=region, - endpoint_url='https://sts.eu-west-1.amazonaws.com') - - metadata = sts.assume_role(RoleArn=role, - RoleSessionName='SagemakerExecution') - - access_key_id = metadata['Credentials']['AccessKeyId'] - secret_access_key = metadata['Credentials']['SecretAccessKey'] - session_token = metadata['Credentials']['SessionToken'] - - boto_session = boto3.session.Session(region_name=region, - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - aws_session_token=session_token) - - # Sessions - sagemaker_client = boto_session.client('sagemaker') - sagemaker_runtime = boto_session.client('sagemaker-runtime') - sagemaker_featurestore_runtime_client = boto_session.client(service_name='sagemaker-featurestore-runtime') - sagemaker_session = Session(boto_session=boto_session, - sagemaker_client=sagemaker_client, - sagemaker_runtime_client=sagemaker_runtime, - sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client) - - return sagemaker_session - - class DataConfig(abc.ABC): """Abstract base class for accessing data config hosted in AWS resources. 
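The circular import fixed here stems from the chain introduced earlier in this series: sagemaker/__init__.py imports feature_group_utils, feature_group_utils imported get_session_from_role from sagemaker.utils, and utils.py in turn imported Session from the sagemaker package root while that package was still initialising. The sketch below shows the import-safe pattern, the same change applied to tests/conftest.py just after this hunk; a later commit in the series reintroduces the helper as a private _get_session_from_role inside feature_group_utils.py:

    # Import Session from the module that defines it instead of from the
    # sagemaker package root, so importing sagemaker.utils no longer
    # re-enters sagemaker/__init__.py during package initialisation.
    from sagemaker.session import Session
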
diff --git a/tests/conftest.py b/tests/conftest.py index e92d98112b..ec7918840e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,8 @@ from botocore.config import Config from packaging.version import Version -from sagemaker import Session, image_uris, utils +from sagemaker.session import Session +from sagemaker import image_uris, utils from sagemaker.local import LocalSession from sagemaker.workflow.pipeline_context import PipelineSession, LocalPipelineSession From db700cb5dc02b09cf8ee96724be5a974f1674da6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=20Juan=20Pe=C3=B1as?= Date: Fri, 18 Nov 2022 15:39:03 +0100 Subject: [PATCH 007/526] feature: utils to prepare, create and extract feature groups --- src/sagemaker/feature_group_utils.py | 143 ++++++++++-------- tests/integ/test_feature_store.py | 27 ++-- .../feature_store/test_feature_group_utils.py | 84 ++++++++++ 3 files changed, 181 insertions(+), 73 deletions(-) create mode 100644 tests/unit/sagemaker/feature_store/test_feature_group_utils.py diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index 6dd8267d34..e56bb383fd 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -17,19 +17,63 @@ """ import re import logging -import json + from typing import Union from pathlib import Path import pandas +import boto3 from pandas import DataFrame, Series, read_csv from sagemaker.feature_store.feature_group import FeatureGroup -from sagemaker.utils import get_session_from_role +from sagemaker.session import Session logger = logging.getLogger(__name__) +def _get_session_from_role(role: str, region: str): + """ + Method use to get the sagemaker session from a role and a region. Helpful in case it's + invoke from a session with a role without permission it can assume another role temporarily + to perform certain taks. 
+ + Args: + role: role name + region: region name + + Returns: + + """ + boto_session = boto3.Session(region_name=region) + + sts = boto_session.client('sts', + region_name=region, + endpoint_url='https://sts.eu-west-1.amazonaws.com') + + metadata = sts.assume_role(RoleArn=role, + RoleSessionName='SagemakerExecution') + + access_key_id = metadata['Credentials']['AccessKeyId'] + secret_access_key = metadata['Credentials']['SecretAccessKey'] + session_token = metadata['Credentials']['SessionToken'] + + boto_session = boto3.session.Session(region_name=region, + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + aws_session_token=session_token) + + # Sessions + sagemaker_client = boto_session.client('sagemaker') + sagemaker_runtime = boto_session.client('sagemaker-runtime') + sagemaker_featurestore_runtime_client = boto_session.client(service_name='sagemaker-featurestore-runtime') + sagemaker_session = Session(boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=sagemaker_runtime, + sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client) + + return sagemaker_session + + def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, query: str = str('SELECT * FROM "sagemaker_featurestore"."#{table}" WHERE ' + 'is_deleted=False'), @@ -76,7 +120,7 @@ def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, if session is not None: sagemaker_session = session elif role is not None and region is not None: - sagemaker_session = get_session_from_role(role=role, region=region) + sagemaker_session = _get_session_from_role(role=role, region=region) else: exc = Exception('Argument Session or role and region must be specified.') logger.exception(exc) @@ -136,47 +180,40 @@ def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: return data_frame -def get_fg_schema(dataframe_or_path: Union[str, Path, pandas.DataFrame], - role: str, region: str, - mode: str = 'display', record_id: str = '@index', - event_id: str = 'data_as_of_date', - saving_file_path: str = '', verbose: bool = False, - **pandas_read_csv_kwargs) -> None: +def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas.DataFrame], + feature_group_name: str, + role: str = None, region: str = None, session=None, + record_id: str = 'record_id', event_id: str = 'data_as_of_date', + verbose: bool = False, + **pandas_read_csv_kwargs) -> FeatureGroup: """ - Method to generate the schema of a Feature Group from a pandas.DataFrame. It has two modes (`mode`): - - display: the schema is printed on the display - - make_file: it generates a file with the schema inside. Recommended if it has a lot of features. Then - argument `saving_file_path` must be specified. + Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame or a path to + a file with proper dtypes, feature names and mandatory features (record_id, event_id). + It needs the sagemaker.Session linked to a role or the role and region used to work Feature Stores. + If record_id or event_id are not specified it will create ones by default with the names + Args: + feature_group_name (str): feature group name dataframe_or_path (str, Path, pandas.DataFrame) : pandas.DataFrame or path to the data - mode (str) : it changes how the output is displayed or stored, as explained before. By default, - mode='display', and the other mode is `make_file`. verbose (bool) : True for displaying messages, False for silent method. 
- record_id (str, '@index'): (Optional) Feature identifier of the rows. If specified each value of that feature - has to be unique. If not specified or record_id='@index', then it will create + record_id (str, 'record_id'): (Optional) Feature identifier of the rows. If specified each value of that feature + has to be unique. If not specified or record_id='record_id', then it will create a new feature from the index of the pandas.DataFrame. event_id (str) : (Optional) Feature with the time of the creation of data rows. If not specified it will create one with the current time called `data_as_of_date` - role (str) : role used to get the session - region (str) : region used to get the session - saving_file_path (str) : required if mode='make_file', file path to save the output. + role (str) : role used to get the session. + region (str) : region used to get the session. + session (str): session of SageMaker used to work with the feature store Returns: - Save text into a file or displays the feature group schema by teh screen + FeatureGroup: FG prepared with all the methods and definitions properly defined """ - MODE = ['display', 'make_file'] logger.setLevel(logging.WARNING) if verbose: logger.setLevel(logging.INFO) - mode = mode.lower() - if mode not in MODE: - exc = Exception(f'Invalid value {mode} for parameter mode.\nMode must be in {MODE}') - logger.exception(exc) - raise exc - from sagemaker.feature_store.feature_group import FeatureGroup if isinstance(dataframe_or_path, DataFrame): @@ -194,8 +231,7 @@ def get_fg_schema(dataframe_or_path: Union[str, Path, pandas.DataFrame], data = _format_column_names(data=data) data = _cast_object_to_string(data_frame=data) - if record_id == '@index': - record_id = 'index' + if record_id == 'record_id' and record_id not in data.columns: data[record_id] = data.index lg_uniq = len(data[record_id].unique()) @@ -207,40 +243,25 @@ def get_fg_schema(dataframe_or_path: Union[str, Path, pandas.DataFrame], logger.exception(exc) raise exc - session = get_session_from_role(role=role, region=region) - feature_group = FeatureGroup( - name='temporalFG', sagemaker_session=session - ) - if event_id not in data.columns: import time current_time_sec = int(round(time.time())) data[event_id] = Series([current_time_sec] * lg_id, dtype="float64") - definitions = feature_group.load_feature_definitions(data_frame=data) - - def_list = [] - for ele in definitions: - def_list.append({'FeatureName': ele.feature_name, 'FeatureType': ele.feature_type.name}) - - if mode == MODE[0]: # display - logger.info('[') - for ele in def_list: - _to_print = json.dumps(ele) - if ele != def_list[-1]: - _to_print += ',' - - logger.info(f'{_to_print}') - logger.info(']') - elif mode == MODE[1]: # make_file - if saving_file_path: - logger.info(f'Saving schema to {saving_file_path}') - with open(saving_file_path, 'w') as f: - f.write(json.dumps(def_list)) - f.close() - logger.info('Finished!.') - else: - exc = Exception(str(f'Parameter saving_file_path mandatory if mode {MODE[1]} is specified.')) - logger.exception(exc) - raise exc + if session is not None: + sagemaker_session = session + elif role is not None and region is not None: + sagemaker_session = _get_session_from_role(role=role, region=region) + else: + exc = Exception('Argument Session or role and region must be specified.') + logger.exception(exc) + raise exc + + feature_group = FeatureGroup( + name=feature_group_name, sagemaker_session=sagemaker_session + ) + + feature_group.load_feature_definitions(data_frame=data) + + return 
feature_group diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index 622ccdc220..baa4d07935 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -241,11 +241,11 @@ def test_create_feature_store( def test_update_feature_group( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -269,11 +269,11 @@ def test_update_feature_group( def test_feature_metadata( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -427,6 +427,8 @@ def test_get_feature_group_with_role_region( athena_bucket=f'{offline_store_s3_uri}/query') assert dataset.empty == False + assert isinstance(dataset, DataFrame) + def test_get_feature_group_with_session( feature_store_session, @@ -457,10 +459,11 @@ def test_get_feature_group_with_session( event_time_feature_name="feature3", latest_ingestion=True, athena_bucket=f'{offline_store_s3_uri}/query', - low_memory=False) # Using kwargs to pass a parameter to - # pandas.read_csv + low_memory=False) # Using kwargs to pass a parameter to + # pandas.read_csv assert dataset.empty == False + assert isinstance(dataset, DataFrame) @contextmanager diff --git a/tests/unit/sagemaker/feature_store/test_feature_group_utils.py b/tests/unit/sagemaker/feature_store/test_feature_group_utils.py new file mode 100644 index 0000000000..124a3715ee --- /dev/null +++ b/tests/unit/sagemaker/feature_store/test_feature_group_utils.py @@ -0,0 +1,84 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import pandas as pd +import pytest +from mock import Mock + +from sagemaker.feature_group_utils import _cast_object_to_string, prepare_fg_from_dataframe_or_file +from sagemaker.feature_store.feature_definition import ( + FeatureTypeEnum, +) +from sagemaker.feature_store.feature_group import ( + FeatureGroup, +) + + +class PicklableMock(Mock): + def __reduce__(self): + return (Mock, ()) + + +@pytest.fixture +def sagemaker_session_mock(): + return Mock() + + +def test_convert_unsupported_types_to_supported(sagemaker_session_mock): + feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame( + { + "float": pd.Series([2.0], dtype="float64"), + "int": pd.Series([2], dtype="int64"), + "object": pd.Series(["f1"], dtype="object"), + } + ) + # Converting object or O type to string + df = _cast_object_to_string(data_frame=df) + + feature_definitions = feature_group.load_feature_definitions(data_frame=df) + types = [fd.feature_type for fd in feature_definitions] + + assert types == [ + FeatureTypeEnum.FRACTIONAL, + FeatureTypeEnum.INTEGRAL, + FeatureTypeEnum.STRING, + ] + + +def test_prepare_fg_from_dataframe(sagemaker_session_mock): + very_long_name = 'long'*20 + df = pd.DataFrame( + { + "space feature": pd.Series([2.0], dtype="float64"), + "dot.feature": pd.Series([2], dtype="int64"), + very_long_name: pd.Series(["f1"], dtype="string"), + } + ) + + feature_group = prepare_fg_from_dataframe_or_file(dataframe_or_path=df, session=sagemaker_session_mock, + feature_group_name='testFG') + + names = [fd.feature_name for fd in feature_group.feature_definitions] + types = [fd.feature_type for fd in feature_group.feature_definitions] + + assert names == ["space_feature", "dotfeature", very_long_name[:62], "index", "data_as_of_date"] + assert types == [ + FeatureTypeEnum.FRACTIONAL, + FeatureTypeEnum.INTEGRAL, + FeatureTypeEnum.STRING, + FeatureTypeEnum.INTEGRAL, + FeatureTypeEnum.FRACTIONAL, + ] From aad7888409fd08f2ae858bb4816287a3eace79cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=20Juan=20Pe=C3=B1as?= Date: Fri, 18 Nov 2022 15:56:21 +0100 Subject: [PATCH 008/526] fix: refactoring changed name of many docstrings --- src/sagemaker/feature_store/feature_definition.py | 2 +- src/sagemaker/feature_store/inputs.py | 2 +- tests/integ/test_model_monitor.py | 4 ++-- tests/integ/test_monitoring_files.py | 12 ++++++------ .../feature_store/test_feature_group_utils.py | 15 +++++++++++---- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/sagemaker/feature_store/feature_definition.py b/src/sagemaker/feature_store/feature_definition.py index b7c9aacda4..5a91a9c512 100644 --- a/src/sagemaker/feature_store/feature_definition.py +++ b/src/sagemaker/feature_store/feature_definition.py @@ -46,7 +46,7 @@ class FeatureDefinition(Config): This instantiates a Feature Definition object where FeatureDefinition is a subclass of Config. Attributes: - feature_group_name (str): The name of the feature + feature_name (str): The name of the feature feature_type (FeatureTypeEnum): The type of the feature """ diff --git a/src/sagemaker/feature_store/inputs.py b/src/sagemaker/feature_store/inputs.py index 561d760eda..75cb99b5f6 100644 --- a/src/sagemaker/feature_store/inputs.py +++ b/src/sagemaker/feature_store/inputs.py @@ -190,7 +190,7 @@ class FeatureValue(Config): """FeatureValue for FeatureStore. Attributes: - feature_group_name (str): name of the Feature. + feature_name (str): name of the Feature. 
value_as_string (str): value of the Feature in string form. """ diff --git a/tests/integ/test_model_monitor.py b/tests/integ/test_model_monitor.py index 5cb2ac0a83..f6d5ee88ed 100644 --- a/tests/integ/test_model_monitor.py +++ b/tests/integ/test_model_monitor.py @@ -922,7 +922,7 @@ def test_default_monitor_monitoring_execution_interactions( ) constraint_violations = my_attached_monitor.latest_monitoring_constraint_violations() - assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" @pytest.mark.skipif( @@ -1486,7 +1486,7 @@ def test_byoc_monitor_monitoring_execution_interactions( ) constraint_violations = my_attached_monitor.latest_monitoring_constraint_violations() - assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" def _wait_for_schedule_changes_to_apply(monitor): diff --git a/tests/integ/test_monitoring_files.py b/tests/integ/test_monitoring_files.py index ff164f5587..08dffd99c9 100644 --- a/tests/integ/test_monitoring_files.py +++ b/tests/integ/test_monitoring_files.py @@ -315,7 +315,7 @@ def test_constraint_violations_object_creation_from_file_path_with_customization assert constraint_violations.file_s3_uri.startswith("s3://") assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" def test_constraint_violations_object_creation_from_file_path_without_customizations( @@ -331,7 +331,7 @@ def test_constraint_violations_object_creation_from_file_path_without_customizat assert constraint_violations.file_s3_uri.startswith("s3://") assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" def test_constraint_violations_object_creation_from_string_with_customizations( @@ -350,7 +350,7 @@ def test_constraint_violations_object_creation_from_string_with_customizations( assert constraint_violations.file_s3_uri.startswith("s3://") assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" def test_constraint_violations_object_creation_from_string_without_customizations( @@ -366,7 +366,7 @@ def test_constraint_violations_object_creation_from_string_without_customization assert constraint_violations.file_s3_uri.startswith("s3://") assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" def test_constraint_violations_object_creation_from_s3_uri_with_customizations( @@ -400,7 +400,7 @@ def test_constraint_violations_object_creation_from_s3_uri_with_customizations( assert constraint_violations.file_s3_uri.startswith("s3://") 
assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" def test_constraint_violations_object_creation_from_s3_uri_without_customizations( @@ -429,4 +429,4 @@ def test_constraint_violations_object_creation_from_s3_uri_without_customization assert constraint_violations.file_s3_uri.startswith("s3://") assert constraint_violations.file_s3_uri.endswith("constraint_violations.json") - assert constraint_violations.body_dict["violations"][0]["feature_group_name"] == "store_and_fwd_flag" + assert constraint_violations.body_dict["violations"][0]["feature_name"] == "store_and_fwd_flag" diff --git a/tests/unit/sagemaker/feature_store/test_feature_group_utils.py b/tests/unit/sagemaker/feature_store/test_feature_group_utils.py index 124a3715ee..632c4a0f26 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_group_utils.py +++ b/tests/unit/sagemaker/feature_store/test_feature_group_utils.py @@ -59,7 +59,7 @@ def test_convert_unsupported_types_to_supported(sagemaker_session_mock): def test_prepare_fg_from_dataframe(sagemaker_session_mock): - very_long_name = 'long'*20 + very_long_name = "long" * 20 df = pd.DataFrame( { "space feature": pd.Series([2.0], dtype="float64"), @@ -68,13 +68,20 @@ def test_prepare_fg_from_dataframe(sagemaker_session_mock): } ) - feature_group = prepare_fg_from_dataframe_or_file(dataframe_or_path=df, session=sagemaker_session_mock, - feature_group_name='testFG') + feature_group = prepare_fg_from_dataframe_or_file( + dataframe_or_path=df, session=sagemaker_session_mock, feature_group_name="testFG" + ) names = [fd.feature_name for fd in feature_group.feature_definitions] types = [fd.feature_type for fd in feature_group.feature_definitions] - assert names == ["space_feature", "dotfeature", very_long_name[:62], "index", "data_as_of_date"] + assert names == [ + "space_feature", + "dotfeature", + very_long_name[:62], + "record_id", + "data_as_of_date", + ] assert types == [ FeatureTypeEnum.FRACTIONAL, FeatureTypeEnum.INTEGRAL, From 70a2990afa0fc218e101a4e6fb9bb81b79e8ae6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=20Juan=20Pe=C3=B1as?= Date: Mon, 21 Nov 2022 12:03:27 +0100 Subject: [PATCH 009/526] change: linux line separator --- .../text/machine_translation_hugging_face.rst | 36 + .../text/question_answering_pytorch.rst | 101 + ...tence_pair_classification_hugging_face.rst | 116 + ...entence_pair_classification_tensorflow.rst | 61 + .../text/text_classification_tensorflow.rst | 201 ++ .../text/text_generation_hugging_face.rst | 36 + .../text/text_summarization_hugging_face.rst | 51 + .../vision/image_classification_pytorch.rst | 146 + .../image_classification_tensorflow.rst | 311 +++ .../vision/image_embedding_tensorflow.rst | 271 ++ .../vision/instance_segmentation_mxnet.rst | 31 + .../vision/object_detection_mxnet.rst | 96 + .../vision/object_detection_pytorch.rst | 16 + .../vision/object_detection_tensorflow.rst | 191 ++ .../vision/semantic_segmentation_mxnet.rst | 31 + .../text_embedding_tensorflow_mxnet.rst | 181 ++ doc/doc_utils/pretrainedmodels.rst | 2396 +++++++++++++++++ doc/make.bat | 72 +- src/sagemaker/feature_group_utils.py | 136 +- src/sagemaker/feature_store/feature_group.py | 2 +- src/sagemaker/utils.py | 50 +- tests/data/upload_data_tests/file1.py | 6 +- tests/data/upload_data_tests/file2.py | 6 +- 
.../upload_data_tests/nested_dir/file3.py | 6 +- .../upload_data_tests/nested_dir/file4.py | 6 +- tests/data/workflow/dummy_data.csv | 12 +- tests/integ/test_feature_store.py | 123 +- 27 files changed, 4496 insertions(+), 195 deletions(-) diff --git a/doc/algorithms/text/machine_translation_hugging_face.rst b/doc/algorithms/text/machine_translation_hugging_face.rst index d533d0e64d..622ab4bdac 100644 --- a/doc/algorithms/text/machine_translation_hugging_face.rst +++ b/doc/algorithms/text/machine_translation_hugging_face.rst @@ -8,3 +8,39 @@ This is a supervised machine translation algorithm which supports many pre-train demonstrates how to use the Sagemaker Python SDK for Machine Translation for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK `. + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? + - Latest Version + - Min SDK Version + - Source + * - huggingface-translation-opus-mt-en-es + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-translation-opus-mt-en-vi + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-translation-t5-base + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-translation-t5-large + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-translation-t5-small + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ diff --git a/doc/algorithms/text/question_answering_pytorch.rst b/doc/algorithms/text/question_answering_pytorch.rst index 9d9d74ccb1..4ad2205f40 100644 --- a/doc/algorithms/text/question_answering_pytorch.rst +++ b/doc/algorithms/text/question_answering_pytorch.rst @@ -7,3 +7,104 @@ This is a supervised question answering algorithm which supports fine-tuning of demonstrates how to use the Sagemaker Python SDK for Question Answering for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? 
+ - Latest Version + - Min SDK Version + - Source + * - pytorch-eqa-bert-base-cased + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-bert-base-multilingual-cased + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-bert-base-multilingual-uncased + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-bert-base-uncased + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-bert-large-cased + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-bert-large-cased-whole-word-masking + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-bert-large-cased-whole-word-masking-finetuned-squad + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-bert-large-uncased + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-bert-large-uncased-whole-word-masking + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-bert-large-uncased-whole-word-masking-finetuned-squad + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-distilbert-base-cased + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-distilbert-base-multilingual-cased + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-distilbert-base-uncased + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-distilroberta-base + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-roberta-base + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-roberta-base-openai-detector + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-roberta-large + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-eqa-roberta-large-openai-detector + - True + - 1.2.1 + - 2.75.0 + - `Pytorch Hub `__ diff --git a/doc/algorithms/text/sentence_pair_classification_hugging_face.rst b/doc/algorithms/text/sentence_pair_classification_hugging_face.rst index 2892b9d516..40d09854ab 100644 --- a/doc/algorithms/text/sentence_pair_classification_hugging_face.rst +++ b/doc/algorithms/text/sentence_pair_classification_hugging_face.rst @@ -7,3 +7,119 @@ This is a supervised sentence pair classification algorithm which supports fine- demonstrates how to use the Sagemaker Python SDK for Sentence Pair Classification for using these algorithms. For detailed documentation please refer `Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK `__ + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? 
+ - Latest Version + - Min SDK Version + - Source + * - huggingface-spc-bert-base-cased + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-bert-base-multilingual-cased + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-bert-base-multilingual-uncased + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-bert-base-uncased + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-bert-large-cased + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-bert-large-cased-whole-word-masking + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-bert-large-uncased + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-bert-large-uncased-whole-word-masking + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-distilbert-base-cased + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-distilbert-base-multilingual-cased + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-distilbert-base-uncased + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-distilroberta-base + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-roberta-base + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-roberta-base-openai-detector + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-roberta-large + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-roberta-large-openai-detector + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-xlm-clm-ende-1024 + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-xlm-mlm-ende-1024 + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-xlm-mlm-enro-1024 + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-xlm-mlm-tlm-xnli15-1024 + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-spc-xlm-mlm-xnli15-1024 + - True + - 1.2.3 + - 2.75.0 + - `HuggingFace `__ diff --git a/doc/algorithms/text/sentence_pair_classification_tensorflow.rst b/doc/algorithms/text/sentence_pair_classification_tensorflow.rst index 80264e84f3..70d15aed1c 100644 --- a/doc/algorithms/text/sentence_pair_classification_tensorflow.rst +++ b/doc/algorithms/text/sentence_pair_classification_tensorflow.rst @@ -7,3 +7,64 @@ This is a supervised sentence pair classification algorithm which supports fine- demonstrates how to use the Sagemaker Python SDK for Sentence Pair Classification for using these algorithms. For detailed documentation please refer `Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK `__ + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? 
+ - Latest Version + - Min SDK Version + - Source + * - tensorflow-spc-bert-en-cased-L-12-H-768-A-12-2 + - True + - 1.2.3 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-spc-bert-en-uncased-L-12-H-768-A-12-2 + - True + - 1.2.3 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-spc-bert-en-uncased-L-24-H-1024-A-16-2 + - True + - 1.2.3 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-spc-bert-en-wwm-cased-L-24-H-1024-A-16-2 + - True + - 1.2.3 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-spc-bert-en-wwm-uncased-L-24-H-1024-A-16-2 + - True + - 1.2.3 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-spc-bert-multi-cased-L-12-H-768-A-12-2 + - True + - 1.2.3 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-spc-electra-base-1 + - True + - 1.2.3 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-spc-electra-small-1 + - True + - 1.2.3 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-spc-experts-bert-pubmed-1 + - True + - 1.2.3 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-spc-experts-bert-wiki-books-1 + - True + - 1.2.3 + - 2.75.0 + - `Tensorflow Hub `__ diff --git a/doc/algorithms/text/text_classification_tensorflow.rst b/doc/algorithms/text/text_classification_tensorflow.rst index c60a5b3e1c..891cfa7bbc 100644 --- a/doc/algorithms/text/text_classification_tensorflow.rst +++ b/doc/algorithms/text/text_classification_tensorflow.rst @@ -7,3 +7,204 @@ This is a supervised text classification algorithm which supports fine-tuning of demonstrates how to use the Sagemaker Python SDK for Text Classification for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? 
+ - Latest Version + - Min SDK Version + - Source + * - tensorflow-tc-albert-en-base + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-bert-en-cased-L-12-H-768-A-12-2 + - True + - 2.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-bert-en-cased-L-24-H-1024-A-16-2 + - True + - 2.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2 + - True + - 2.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-bert-en-uncased-L-24-H-1024-A-16-2 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-bert-en-wwm-cased-L-24-H-1024-A-16-2 + - True + - 2.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-bert-en-wwm-uncased-L-24-H-1024-A-16-2 + - True + - 2.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-bert-multi-cased-L-12-H-768-A-12-2 + - True + - 2.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-electra-base-1 + - True + - 2.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-electra-small-1 + - True + - 2.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-experts-bert-pubmed-1 + - True + - 2.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-experts-bert-wiki-books-1 + - True + - 2.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-768-A-12 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-768-A-12 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-768-A-12 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-768-A-12 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-768-A-12 + - True + - 1.0.1 
+ - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-8-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-8-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-8-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-small-bert-bert-en-uncased-L-8-H-768-A-12 + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-talking-heads-base + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-tc-talking-heads-large + - True + - 1.0.1 + - 2.80.0 + - `Tensorflow Hub `__ diff --git a/doc/algorithms/text/text_generation_hugging_face.rst b/doc/algorithms/text/text_generation_hugging_face.rst index 30fae26196..d2c6e3ac52 100644 --- a/doc/algorithms/text/text_generation_hugging_face.rst +++ b/doc/algorithms/text/text_generation_hugging_face.rst @@ -7,3 +7,39 @@ This is a supervised text generation algorithm which supports many pre-trained m demonstrates how to use the Sagemaker Python SDK for Text Generation for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? + - Latest Version + - Min SDK Version + - Source + * - huggingface-textgeneration-bloom-1b1 + - False + - 1.0.1 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-textgeneration-bloom-1b7 + - False + - 1.0.1 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-textgeneration-bloom-560m + - False + - 1.0.1 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-textgeneration-distilgpt2 + - False + - 1.2.1 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-textgeneration-gpt2 + - False + - 1.2.1 + - 2.75.0 + - `HuggingFace `__ diff --git a/doc/algorithms/text/text_summarization_hugging_face.rst b/doc/algorithms/text/text_summarization_hugging_face.rst index 206c880ba3..1a0fecc8ae 100644 --- a/doc/algorithms/text/text_summarization_hugging_face.rst +++ b/doc/algorithms/text/text_summarization_hugging_face.rst @@ -7,3 +7,54 @@ This is a supervised text summarization algorithm which supports many pre-traine demonstrates how to use the Sagemaker Python SDK for Text Summarization for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? 
+ - Latest Version + - Min SDK Version + - Source + * - huggingface-summarization-bart-large-cnn-samsum + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-summarization-bert-small2bert-small-finetuned-cnn-daily-mail-summarization + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-summarization-bigbird-pegasus-large-arxiv + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-summarization-bigbird-pegasus-large-pubmed + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-summarization-distilbart-cnn-12-6 + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-summarization-distilbart-cnn-6-6 + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-summarization-distilbart-xsum-1-1 + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ + * - huggingface-summarization-distilbart-xsum-12-3 + - False + - 1.1.0 + - 2.75.0 + - `HuggingFace `__ diff --git a/doc/algorithms/vision/image_classification_pytorch.rst b/doc/algorithms/vision/image_classification_pytorch.rst index 3c154c6cfe..f1c04bb758 100644 --- a/doc/algorithms/vision/image_classification_pytorch.rst +++ b/doc/algorithms/vision/image_classification_pytorch.rst @@ -7,3 +7,149 @@ This is a supervised image clasification algorithm which supports fine-tuning of demonstrates how to use the Sagemaker Python SDK for Image Classification for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? + - Latest Version + - Min SDK Version + - Source + * - pytorch-ic-alexnet + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-densenet121 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-densenet161 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-densenet169 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-densenet201 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-googlenet + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-mobilenet-v2 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-resnet101 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-resnet152 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-resnet18 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-resnet34 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-resnet50 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-resnext101-32x8d + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-resnext50-32x4d + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-shufflenet-v2-x1-0 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-squeezenet1-0 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-squeezenet1-1 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-vgg11 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-vgg11-bn + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-vgg13 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-vgg13-bn + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-vgg16 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-vgg16-bn + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - 
pytorch-ic-vgg19 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-vgg19-bn + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-wide-resnet101-2 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ + * - pytorch-ic-wide-resnet50-2 + - True + - 2.2.4 + - 2.75.0 + - `Pytorch Hub `__ diff --git a/doc/algorithms/vision/image_classification_tensorflow.rst b/doc/algorithms/vision/image_classification_tensorflow.rst index e49820ee50..80be3fa4ed 100644 --- a/doc/algorithms/vision/image_classification_tensorflow.rst +++ b/doc/algorithms/vision/image_classification_tensorflow.rst @@ -7,3 +7,314 @@ This is a supervised image clasification algorithm which supports fine-tuning of demonstrates how to use the Sagemaker Python SDK for Image Classification for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? + - Latest Version + - Min SDK Version + - Source + * - tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-bit-m-r101x1-imagenet21k-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-bit-m-r101x3-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-bit-m-r101x3-imagenet21k-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-bit-m-r50x1-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-bit-m-r50x1-imagenet21k-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-bit-m-r50x3-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-bit-m-r50x3-imagenet21k-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-bit-s-r101x1-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-bit-s-r101x3-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-bit-s-r50x1-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-bit-s-r50x3-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-b0-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-b1-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-b2-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-b3-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-b4-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-b5-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-b6-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-b7-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-lite0-classification-2 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-lite1-classification-2 + - True + - 2.0.5 + - 2.80.0 + - 
`Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-lite2-classification-2 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-lite3-classification-2 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-efficientnet-lite4-classification-2 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-inception-resnet-v2-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-inception-v1-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-inception-v2-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-inception-v3-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-025-128-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-025-160-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-025-192-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-025-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-050-128-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-050-160-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-050-192-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-050-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-075-128-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-075-160-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-075-192-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-075-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-100-128-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-100-160-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-100-192-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v1-100-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v2-035-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v2-050-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v2-075-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v2-100-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v2-130-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-mobilenet-v2-140-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-resnet-v1-101-classification-4 + - True + - 2.0.5 + - 
2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-resnet-v1-152-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-resnet-v1-50-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-resnet-v2-101-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-resnet-v2-152-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-imagenet-resnet-v2-50-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-resnet-50-classification-1 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-tf2-preview-inception-v3-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-ic-tf2-preview-mobilenet-v2-classification-4 + - True + - 2.0.5 + - 2.80.0 + - `Tensorflow Hub `__ diff --git a/doc/algorithms/vision/image_embedding_tensorflow.rst b/doc/algorithms/vision/image_embedding_tensorflow.rst index 0938377354..4d7941bcea 100644 --- a/doc/algorithms/vision/image_embedding_tensorflow.rst +++ b/doc/algorithms/vision/image_embedding_tensorflow.rst @@ -7,3 +7,274 @@ This is a supervised image embedding algorithm which supports many pre-trained m demonstrates how to use the Sagemaker Python SDK for Image Embedding for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? + - Latest Version + - Min SDK Version + - Source + * - tensorflow-icembedding-bit-m-r101x1-ilsvrc2012-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-bit-m-r101x3-imagenet21k-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-bit-m-r50x1-ilsvrc2012-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-bit-m-r50x3-imagenet21k-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-bit-s-r101x1-ilsvrc2012-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-bit-s-r101x3-ilsvrc2012-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-bit-s-r50x1-ilsvrc2012-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-bit-s-r50x3-ilsvrc2012-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-efficientnet-b0-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-efficientnet-b1-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-efficientnet-b2-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-efficientnet-b3-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-efficientnet-b6-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-efficientnet-lite0-featurevector-2 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-efficientnet-lite1-featurevector-2 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - 
tensorflow-icembedding-efficientnet-lite2-featurevector-2 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-efficientnet-lite3-featurevector-2 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-efficientnet-lite4-featurevector-2 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-inception-v1-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-inception-v2-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-inception-v3-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-025-128-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-025-160-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-025-192-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-025-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-128-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-160-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-192-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-128-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-160-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-192-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-128-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-160-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-192-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v2-035-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v2-050-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v2-075-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v2-100-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v2-130-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-mobilenet-v2-140-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + 
- `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-resnet-v1-101-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-resnet-v1-152-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-resnet-v1-50-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-resnet-v2-101-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-resnet-v2-152-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-imagenet-resnet-v2-50-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-resnet-50-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-tf2-preview-inception-v3-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-icembedding-tf2-preview-mobilenet-v2-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ diff --git a/doc/algorithms/vision/instance_segmentation_mxnet.rst b/doc/algorithms/vision/instance_segmentation_mxnet.rst index a38611bc9a..b698ffb13a 100644 --- a/doc/algorithms/vision/instance_segmentation_mxnet.rst +++ b/doc/algorithms/vision/instance_segmentation_mxnet.rst @@ -7,3 +7,34 @@ This is a supervised image segmentation algorithm which supports many pre-train demonstrates how to use the Sagemaker Python SDK for Image Segmentation for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? + - Latest Version + - Min SDK Version + - Source + * - mxnet-is-mask-rcnn-fpn-resnet101-v1d-coco + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-is-mask-rcnn-fpn-resnet18-v1b-coco + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-is-mask-rcnn-fpn-resnet50-v1b-coco + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-is-mask-rcnn-resnet18-v1b-coco + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ diff --git a/doc/algorithms/vision/object_detection_mxnet.rst b/doc/algorithms/vision/object_detection_mxnet.rst index 9ce52f992b..052ad100e1 100644 --- a/doc/algorithms/vision/object_detection_mxnet.rst +++ b/doc/algorithms/vision/object_detection_mxnet.rst @@ -7,3 +7,99 @@ This is a supervised object detection algorithm which supports fine-tuning of ma demonstrates how to use the Sagemaker Python SDK for Object Detection for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? 
+ - Latest Version + - Min SDK Version + - Source + * - mxnet-od-faster-rcnn-fpn-resnet101-v1d-coco + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-faster-rcnn-fpn-resnet50-v1b-coco + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-faster-rcnn-resnet101-v1d-coco + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-faster-rcnn-resnet50-v1b-coco + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-faster-rcnn-resnet50-v1b-voc + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-ssd-300-vgg16-atrous-coco + - True + - 1.3.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-ssd-300-vgg16-atrous-voc + - True + - 1.3.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-ssd-512-mobilenet1-0-coco + - True + - 1.3.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-ssd-512-mobilenet1-0-voc + - True + - 1.3.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-ssd-512-resnet50-v1-coco + - True + - 1.3.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-ssd-512-resnet50-v1-voc + - True + - 1.3.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-ssd-512-vgg16-atrous-coco + - True + - 1.3.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-ssd-512-vgg16-atrous-voc + - True + - 1.3.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-yolo3-darknet53-coco + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-yolo3-darknet53-voc + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-yolo3-mobilenet1-0-coco + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-od-yolo3-mobilenet1-0-voc + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ diff --git a/doc/algorithms/vision/object_detection_pytorch.rst b/doc/algorithms/vision/object_detection_pytorch.rst index aa703e74b5..b107e1f9c2 100644 --- a/doc/algorithms/vision/object_detection_pytorch.rst +++ b/doc/algorithms/vision/object_detection_pytorch.rst @@ -7,3 +7,19 @@ This is a supervised object detection algorithm which supports fine-tuning of ma demonstrates how to use the Sagemaker Python SDK for Object Detection for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? + - Latest Version + - Min SDK Version + - Source + * - pytorch-od-nvidia-ssd + - False + - 1.0.2 + - 2.75.0 + - `Pytorch Hub `__ diff --git a/doc/algorithms/vision/object_detection_tensorflow.rst b/doc/algorithms/vision/object_detection_tensorflow.rst index 2536322847..58bbe85593 100644 --- a/doc/algorithms/vision/object_detection_tensorflow.rst +++ b/doc/algorithms/vision/object_detection_tensorflow.rst @@ -7,3 +7,194 @@ This is a supervised object detection algorithm which supports fine-tuning of ma demonstrates how to use the Sagemaker Python SDK for Object Detection for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? 
+ - Latest Version + - Min SDK Version + - Source + * - tensorflow-od-centernet-hourglass-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-centernet-hourglass-1024x1024-kpts-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-centernet-hourglass-512x512-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-centernet-hourglass-512x512-kpts-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-centernet-resnet101v1-fpn-512x512-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-centernet-resnet50v1-fpn-512x512-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-centernet-resnet50v1-fpn-512x512-kpts-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-centernet-resnet50v2-512x512-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-centernet-resnet50v2-512x512-kpts-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-efficientdet-d0-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-efficientdet-d1-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-efficientdet-d2-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-efficientdet-d3-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-efficientdet-d4-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-efficientdet-d5-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-faster-rcnn-inception-resnet-v2-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-faster-rcnn-inception-resnet-v2-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-faster-rcnn-resnet101-v1-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-faster-rcnn-resnet101-v1-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-faster-rcnn-resnet101-v1-800x1333-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-faster-rcnn-resnet152-v1-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-faster-rcnn-resnet152-v1-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-faster-rcnn-resnet152-v1-800x1333-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-faster-rcnn-resnet50-v1-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-faster-rcnn-resnet50-v1-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-faster-rcnn-resnet50-v1-800x1333-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-retinanet-resnet101-v1-fpn-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-retinanet-resnet101-v1-fpn-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-retinanet-resnet152-v1-fpn-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-retinanet-resnet152-v1-fpn-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-retinanet-resnet50-v1-fpn-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-retinanet-resnet50-v1-fpn-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-ssd-mobilenet-v1-fpn-640x640-1 
+ - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-ssd-mobilenet-v2-2 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-ssd-mobilenet-v2-fpnlite-320x320-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ + * - tensorflow-od-ssd-mobilenet-v2-fpnlite-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - `Tensorflow Hub `__ diff --git a/doc/algorithms/vision/semantic_segmentation_mxnet.rst b/doc/algorithms/vision/semantic_segmentation_mxnet.rst index b0c60cd560..5f1ce6bf67 100644 --- a/doc/algorithms/vision/semantic_segmentation_mxnet.rst +++ b/doc/algorithms/vision/semantic_segmentation_mxnet.rst @@ -7,3 +7,34 @@ This is a supervised semantic segmentation algorithm which supports fine-tuning demonstrates how to use the Sagemaker Python SDK for Semantic Segmentation for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? + - Latest Version + - Min SDK Version + - Source + * - mxnet-semseg-fcn-resnet101-ade + - True + - 1.4.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-semseg-fcn-resnet101-coco + - True + - 1.4.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-semseg-fcn-resnet101-voc + - True + - 1.4.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-semseg-fcn-resnet50-ade + - True + - 1.4.1 + - 2.100.0 + - `GluonCV `__ diff --git a/doc/algorithms/vision/text_embedding_tensorflow_mxnet.rst b/doc/algorithms/vision/text_embedding_tensorflow_mxnet.rst index d015c2ef30..7336c2a341 100644 --- a/doc/algorithms/vision/text_embedding_tensorflow_mxnet.rst +++ b/doc/algorithms/vision/text_embedding_tensorflow_mxnet.rst @@ -7,3 +7,184 @@ This is a supervised text embedding algorithm which supports many pre-trained mo demonstrates how to use the Sagemaker Python SDK for Text Embedding for using these algorithms. For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` + +.. list-table:: Available Models + :widths: 50 20 20 20 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? 
+ - Latest Version + - Min SDK Version + - Source + * - mxnet-tcembedding-robertafin-base-uncased + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-tcembedding-robertafin-base-wiki-uncased + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-tcembedding-robertafin-large-uncased + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - mxnet-tcembedding-robertafin-large-wiki-uncased + - False + - 1.2.1 + - 2.100.0 + - `GluonCV `__ + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-128-A-2-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-256-A-4-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-128-A-2-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-256-A-4 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-768-A-12-4 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-128-A-2-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-256-A-4 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-128-A-2-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-256-A-4-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-128-A-2-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-256-A-4 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-8-H-256-A-4-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-8-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-en-uncased-L-8-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-wiki-books-mnli-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-bert-wiki-books-sst2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-talkheads-ggelu-bert-en-base-2 + - False + - 1.1.1 + - 2.75.0 + - 
`Tensorflow Hub `__ + * - tensorflow-tcembedding-talkheads-ggelu-bert-en-large-2 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-universal-sentence-encoder-cmlm-en-base-1 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ + * - tensorflow-tcembedding-universal-sentence-encoder-cmlm-en-large-1 + - False + - 1.1.1 + - 2.75.0 + - `Tensorflow Hub `__ diff --git a/doc/doc_utils/pretrainedmodels.rst b/doc/doc_utils/pretrainedmodels.rst index e69de29bb2..bfefc56c81 100644 --- a/doc/doc_utils/pretrainedmodels.rst +++ b/doc/doc_utils/pretrainedmodels.rst @@ -0,0 +1,2396 @@ +.. _all-pretrained-models: + +.. |external-link| raw:: html + + + +================================================ +Built-in Algorithms with pre-trained Model Table +================================================ + + The SageMaker Python SDK uses model IDs and model versions to access the necessary + utilities for pre-trained models. This table serves to provide the core material plus + some extra information that can be useful in selecting the correct model ID and + corresponding parameters. + + If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute. + We highly suggest pinning an exact model version however. + + These models are also available through the + `JumpStart UI in SageMaker Studio `__ + +.. list-table:: Available Models + :widths: 50 20 20 20 30 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? + - Latest Version + - Min SDK Version + - Problem Type + - Source + * - autogluon-classification-ensemble + - True + - 1.1.1 + - 2.103.0 + - Classification + - `GluonCV `__ |external-link| + * - autogluon-regression-ensemble + - True + - 1.1.1 + - 2.103.0 + - Regression + - `GluonCV `__ |external-link| + * - catboost-classification-model + - True + - 1.2.7 + - 2.75.0 + - Classification + - `Catboost `__ |external-link| + * - catboost-regression-model + - True + - 1.2.7 + - 2.75.0 + - Regression + - `Catboost `__ |external-link| + * - huggingface-eqa-bert-base-cased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-base-multilingual-cased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-base-multilingual-uncased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-base-uncased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-large-cased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-large-cased-whole-word-masking + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-large-uncased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-large-uncased-whole-word-masking + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-distilbert-base-cased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-distilbert-base-multilingual-cased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-distilbert-base-uncased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ 
|external-link| + * - huggingface-eqa-distilroberta-base + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-roberta-base + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-roberta-base-openai-detector + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-roberta-large + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-ner-distilbert-base-cased-finetuned-conll03-english + - False + - 1.1.0 + - 2.75.0 + - Named Entity Recognition + - `HuggingFace `__ |external-link| + * - huggingface-ner-distilbert-base-uncased-finetuned-conll03-english + - False + - 1.1.0 + - 2.75.0 + - Named Entity Recognition + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-base-cased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-base-multilingual-cased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-base-multilingual-uncased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-base-uncased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-large-cased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-large-cased-whole-word-masking + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-large-uncased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-large-uncased-whole-word-masking + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-distilbert-base-cased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-distilbert-base-multilingual-cased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-distilbert-base-uncased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-distilroberta-base + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-roberta-base + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-roberta-base-openai-detector + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-roberta-large + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-roberta-large-openai-detector + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-xlm-clm-ende-1024 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-xlm-mlm-ende-1024 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-xlm-mlm-enro-1024 + - True + - 
1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-xlm-mlm-tlm-xnli15-1024 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-xlm-mlm-xnli15-1024 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-summarization-bart-large-cnn-samsum + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-bert-small2bert-small-finetuned-cnn-daily-mail-summarization + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-bigbird-pegasus-large-arxiv + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-bigbird-pegasus-large-pubmed + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-distilbart-cnn-12-6 + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-distilbart-cnn-6-6 + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-distilbart-xsum-1-1 + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-distilbart-xsum-12-3 + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-textgeneration-bloom-1b1 + - False + - 1.0.1 + - 2.75.0 + - Text Generation + - `HuggingFace `__ |external-link| + * - huggingface-textgeneration-bloom-1b7 + - False + - 1.0.1 + - 2.75.0 + - Text Generation + - `HuggingFace `__ |external-link| + * - huggingface-textgeneration-bloom-560m + - False + - 1.0.1 + - 2.75.0 + - Text Generation + - `HuggingFace `__ |external-link| + * - huggingface-textgeneration-distilgpt2 + - False + - 1.2.1 + - 2.75.0 + - Text Generation + - `HuggingFace `__ |external-link| + * - huggingface-textgeneration-gpt2 + - False + - 1.2.1 + - 2.75.0 + - Text Generation + - `HuggingFace `__ |external-link| + * - huggingface-translation-opus-mt-en-es + - False + - 1.1.0 + - 2.75.0 + - Machine Translation + - `HuggingFace `__ |external-link| + * - huggingface-translation-opus-mt-en-vi + - False + - 1.1.0 + - 2.75.0 + - Machine Translation + - `HuggingFace `__ |external-link| + * - huggingface-translation-t5-base + - False + - 1.1.0 + - 2.75.0 + - Machine Translation + - `HuggingFace `__ |external-link| + * - huggingface-translation-t5-large + - False + - 1.1.0 + - 2.75.0 + - Machine Translation + - `HuggingFace `__ |external-link| + * - huggingface-translation-t5-small + - False + - 1.1.0 + - 2.75.0 + - Machine Translation + - `HuggingFace `__ |external-link| + * - huggingface-txt2img-stable-diffusion-v1-4 + - False + - 1.0.1 + - 2.75.0 + - Source + - `HuggingFace `__ |external-link| + * - lightgbm-classification-model + - True + - 1.2.6 + - 2.75.0 + - Classification + - `LightGBM `__ |external-link| + * - lightgbm-regression-model + - True + - 1.2.6 + - 2.75.0 + - Regression + - `LightGBM `__ |external-link| + * - model-txt2img-stabilityai-stable-diffusion-v1-4 + - False + - 1.0.0 + - 2.75.0 + - Source + - `HuggingFace `__ |external-link| + * - mxnet-is-mask-rcnn-fpn-resnet101-v1d-coco + - False + - 1.2.1 + - 2.100.0 + - Instance Segmentation + - `GluonCV `__ |external-link| + * - 
mxnet-is-mask-rcnn-fpn-resnet18-v1b-coco + - False + - 1.2.1 + - 2.100.0 + - Instance Segmentation + - `GluonCV `__ |external-link| + * - mxnet-is-mask-rcnn-fpn-resnet50-v1b-coco + - False + - 1.2.1 + - 2.100.0 + - Instance Segmentation + - `GluonCV `__ |external-link| + * - mxnet-is-mask-rcnn-resnet18-v1b-coco + - False + - 1.2.1 + - 2.100.0 + - Instance Segmentation + - `GluonCV `__ |external-link| + * - mxnet-od-faster-rcnn-fpn-resnet101-v1d-coco + - False + - 1.2.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-faster-rcnn-fpn-resnet50-v1b-coco + - False + - 1.2.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-faster-rcnn-resnet101-v1d-coco + - False + - 1.2.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-faster-rcnn-resnet50-v1b-coco + - False + - 1.2.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-faster-rcnn-resnet50-v1b-voc + - False + - 1.2.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-300-vgg16-atrous-coco + - True + - 1.3.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-300-vgg16-atrous-voc + - True + - 1.3.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-mobilenet1-0-coco + - True + - 1.3.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-mobilenet1-0-voc + - True + - 1.3.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-resnet50-v1-coco + - True + - 1.3.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-resnet50-v1-voc + - True + - 1.3.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-vgg16-atrous-coco + - True + - 1.3.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-vgg16-atrous-voc + - True + - 1.3.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-yolo3-darknet53-coco + - False + - 1.2.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-yolo3-darknet53-voc + - False + - 1.2.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-yolo3-mobilenet1-0-coco + - False + - 1.2.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-yolo3-mobilenet1-0-voc + - False + - 1.2.1 + - 2.100.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-semseg-fcn-resnet101-ade + - True + - 1.4.1 + - 2.100.0 + - Semantic Segmentation + - `GluonCV `__ |external-link| + * - mxnet-semseg-fcn-resnet101-coco + - True + - 1.4.1 + - 2.100.0 + - Semantic Segmentation + - `GluonCV `__ |external-link| + * - mxnet-semseg-fcn-resnet101-voc + - True + - 1.4.1 + - 2.100.0 + - Semantic Segmentation + - `GluonCV `__ |external-link| + * - mxnet-semseg-fcn-resnet50-ade + - True + - 1.4.1 + - 2.100.0 + - Semantic Segmentation + - `GluonCV `__ |external-link| + * - mxnet-tcembedding-robertafin-base-uncased + - False + - 1.2.1 + - 2.100.0 + - Text Embedding + - `GluonCV `__ |external-link| + * - mxnet-tcembedding-robertafin-base-wiki-uncased + - False + - 1.2.1 + - 2.100.0 + - Text Embedding + - `GluonCV `__ |external-link| + * - mxnet-tcembedding-robertafin-large-uncased + - False + - 1.2.1 + - 2.100.0 + - Text Embedding + - `GluonCV `__ |external-link| + * - mxnet-tcembedding-robertafin-large-wiki-uncased + - False + - 1.2.1 + - 
2.100.0 + - Text Embedding + - `GluonCV `__ |external-link| + * - pytorch-eqa-bert-base-cased + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-base-multilingual-cased + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-base-multilingual-uncased + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-base-uncased + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-cased + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-cased-whole-word-masking + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-cased-whole-word-masking-finetuned-squad + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-uncased + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-uncased-whole-word-masking + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-uncased-whole-word-masking-finetuned-squad + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-distilbert-base-cased + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-distilbert-base-multilingual-cased + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-distilbert-base-uncased + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-distilroberta-base + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-roberta-base + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-roberta-base-openai-detector + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-roberta-large + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-roberta-large-openai-detector + - True + - 1.2.1 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-alexnet + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-densenet121 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-densenet161 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-densenet169 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-densenet201 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-googlenet + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-mobilenet-v2 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnet101 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnet152 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - 
`Pytorch Hub `__ |external-link| + * - pytorch-ic-resnet18 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnet34 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnet50 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnext101-32x8d + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnext50-32x4d + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-shufflenet-v2-x1-0 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-squeezenet1-0 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-squeezenet1-1 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg11 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg11-bn + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg13 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg13-bn + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg16 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg16-bn + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg19 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg19-bn + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-wide-resnet101-2 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-wide-resnet50-2 + - True + - 2.2.4 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-od-nvidia-ssd + - False + - 1.0.2 + - 2.75.0 + - Object Detection + - `Pytorch Hub `__ |external-link| + * - pytorch-od1-fasterrcnn-mobilenet-v3-large-320-fpn + - False + - 1.0.0 + - 2.75.0 + - Object Detection + - `Pytorch Hub `__ |external-link| + * - pytorch-od1-fasterrcnn-mobilenet-v3-large-fpn + - False + - 1.0.0 + - 2.75.0 + - Object Detection + - `Pytorch Hub `__ |external-link| + * - pytorch-od1-fasterrcnn-resnet50-fpn + - True + - 1.3.2 + - 2.75.0 + - Object Detection + - `Pytorch Hub `__ |external-link| + * - pytorch-tabtransformerclassification-model + - True + - 1.0.4 + - 2.75.0 + - Source + - `Source `__ |external-link| + * - pytorch-tabtransformerregression-model + - True + - 1.0.3 + - 2.75.0 + - Source + - `Source `__ |external-link| + * - pytorch-textgeneration1-alexa20b + - False + - 1.0.0 + - 2.116.0 + - Source + - `Source `__ |external-link| + * - sklearn-classification-linear + - True + - 1.1.2 + - 2.75.0 + - Classification + - `ScikitLearn `__ |external-link| + * - sklearn-regression-linear + - True + - 1.1.2 + - 2.75.0 + - Regression + - `ScikitLearn `__ |external-link| + * - tensorflow-audioembedding-frill-1 + - False + - 1.0.1 + - 2.80.0 + - Source + - `Tensorflow Hub `__ |external-link| + * - tensorflow-audioembedding-trill-3 + - False + - 1.0.1 + - 2.80.0 + - Source + - `Tensorflow Hub `__ 
|external-link| + * - tensorflow-audioembedding-trill-distilled-3 + - False + - 1.0.1 + - 2.80.0 + - Source + - `Tensorflow Hub `__ |external-link| + * - tensorflow-audioembedding-trillsson1-1 + - False + - 1.0.1 + - 2.80.0 + - Source + - `Tensorflow Hub `__ |external-link| + * - tensorflow-audioembedding-trillsson2-1 + - False + - 1.0.1 + - 2.80.0 + - Source + - `Tensorflow Hub `__ |external-link| + * - tensorflow-audioembedding-trillsson3-1 + - False + - 1.0.1 + - 2.80.0 + - Source + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r101x1-imagenet21k-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r101x3-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r101x3-imagenet21k-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r50x1-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r50x1-imagenet21k-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r50x3-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r50x3-imagenet21k-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-s-r101x1-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-s-r101x3-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-s-r50x1-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-s-r50x3-ilsvrc2012-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b0-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b1-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b2-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b3-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b4-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b5-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b6-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b7-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - 
`Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-lite0-classification-2 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-lite1-classification-2 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-lite2-classification-2 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-lite3-classification-2 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-lite4-classification-2 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-inception-resnet-v2-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-inception-v1-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-inception-v2-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-inception-v3-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-025-128-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-025-160-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-025-192-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-025-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-050-128-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-050-160-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-050-192-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-050-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-075-128-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-075-160-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-075-192-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-075-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-100-128-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - 
tensorflow-ic-imagenet-mobilenet-v1-100-160-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-100-192-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-100-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-035-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-050-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-075-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-100-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-130-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-140-224-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v1-101-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v1-152-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v1-50-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v2-101-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v2-152-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v2-50-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-resnet-50-classification-1 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-tf2-preview-inception-v3-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-tf2-preview-mobilenet-v2-classification-4 + - True + - 2.0.5 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-m-r101x1-ilsvrc2012-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-m-r101x3-imagenet21k-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-m-r50x1-ilsvrc2012-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-m-r50x3-imagenet21k-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-s-r101x1-ilsvrc2012-featurevector-1 + - False 
+ - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-s-r101x3-ilsvrc2012-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-s-r50x1-ilsvrc2012-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-s-r50x3-ilsvrc2012-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-b0-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-b1-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-b2-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-b3-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-b6-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-lite0-featurevector-2 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-lite1-featurevector-2 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-lite2-featurevector-2 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-lite3-featurevector-2 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-lite4-featurevector-2 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-inception-v1-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-inception-v2-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-inception-v3-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-025-128-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-025-160-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-025-192-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-025-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-128-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-160-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ 
|external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-192-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-128-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-160-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-192-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-128-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-160-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-192-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-035-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-050-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-075-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-100-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-130-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-140-224-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v1-101-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v1-152-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v1-50-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v2-101-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v2-152-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v2-50-featurevector-4 + - False + - 2.0.2 + 
- 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-resnet-50-featurevector-1 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-tf2-preview-inception-v3-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-tf2-preview-mobilenet-v2-featurevector-4 + - False + - 2.0.2 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-hourglass-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-hourglass-1024x1024-kpts-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-hourglass-512x512-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-hourglass-512x512-kpts-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-resnet101v1-fpn-512x512-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-resnet50v1-fpn-512x512-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-resnet50v1-fpn-512x512-kpts-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-resnet50v2-512x512-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-resnet50v2-512x512-kpts-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d0-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d1-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d2-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d3-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d4-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d5-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-inception-resnet-v2-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-inception-resnet-v2-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet101-v1-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet101-v1-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet101-v1-800x1333-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet152-v1-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - 
tensorflow-od-faster-rcnn-resnet152-v1-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet152-v1-800x1333-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet50-v1-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet50-v1-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet50-v1-800x1333-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet101-v1-fpn-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet101-v1-fpn-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet152-v1-fpn-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet152-v1-fpn-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet50-v1-fpn-1024x1024-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet50-v1-fpn-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-ssd-mobilenet-v1-fpn-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-ssd-mobilenet-v2-2 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-ssd-mobilenet-v2-fpnlite-320x320-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-ssd-mobilenet-v2-fpnlite-640x640-1 + - False + - 2.0.2 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od1-ssd-efficientdet-d0-512x512-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-od1-ssd-efficientdet-d1-640x640-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-od1-ssd-efficientdet-d2-768x768-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-od1-ssd-efficientdet-d3-896x896-coco17-tpu-32 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-od1-ssd-mobilenet-v1-fpn-640x640-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-od1-ssd-mobilenet-v2-fpnlite-320x320-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-od1-ssd-mobilenet-v2-fpnlite-640x640-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-od1-ssd-resnet101-v1-fpn-1024x1024-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-od1-ssd-resnet101-v1-fpn-640x640-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - 
tensorflow-od1-ssd-resnet152-v1-fpn-1024x1024-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-od1-ssd-resnet152-v1-fpn-640x640-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-od1-ssd-resnet50-v1-fpn-1024x1024-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-od1-ssd-resnet50-v1-fpn-640x640-coco17-tpu-8 + - True + - 1.0.2 + - 2.75.0 + - Object Detection + - `Source `__ |external-link| + * - tensorflow-spc-bert-en-cased-L-12-H-768-A-12-2 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-bert-en-uncased-L-12-H-768-A-12-2 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-bert-en-uncased-L-24-H-1024-A-16-2 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-bert-en-wwm-cased-L-24-H-1024-A-16-2 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-bert-en-wwm-uncased-L-24-H-1024-A-16-2 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-bert-multi-cased-L-12-H-768-A-12-2 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-electra-base-1 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-electra-small-1 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-experts-bert-pubmed-1 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-experts-bert-wiki-books-1 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-albert-en-base + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-en-cased-L-12-H-768-A-12-2 + - True + - 2.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-en-cased-L-24-H-1024-A-16-2 + - True + - 2.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2 + - True + - 2.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-en-uncased-L-24-H-1024-A-16-2 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-en-wwm-cased-L-24-H-1024-A-16-2 + - True + - 2.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-en-wwm-uncased-L-24-H-1024-A-16-2 + - True + - 2.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-multi-cased-L-12-H-768-A-12-2 + - True + - 2.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-electra-base-1 + - True + - 2.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-electra-small-1 + - True + - 2.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - 
tensorflow-tc-experts-bert-pubmed-1 + - True + - 2.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-experts-bert-wiki-books-1 + - True + - 2.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-768-A-12 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-768-A-12 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-768-A-12 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-768-A-12 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-768-A-12 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-8-H-128-A-2 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - 
tensorflow-tc-small-bert-bert-en-uncased-L-8-H-256-A-4 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-8-H-512-A-8 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-small-bert-bert-en-uncased-L-8-H-768-A-12 + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-talking-heads-base + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-talking-heads-large + - True + - 1.0.1 + - 2.80.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-128-A-2-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-256-A-4-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-128-A-2-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-256-A-4 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-768-A-12-4 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-128-A-2-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-256-A-4 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-128-A-2-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-256-A-4-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-128-A-2-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-256-A-4 + - False + - 1.1.1 + - 2.75.0 + - Text 
Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-8-H-256-A-4-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-8-H-512-A-8-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-8-H-768-A-12-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-wiki-books-mnli-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-wiki-books-sst2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-talkheads-ggelu-bert-en-base-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-talkheads-ggelu-bert-en-large-2 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-universal-sentence-encoder-cmlm-en-base-1 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-universal-sentence-encoder-cmlm-en-large-1 + - False + - 1.1.1 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - xgboost-classification-model + - True + - 1.2.3 + - 2.75.0 + - Classification + - `XGBoost `__ |external-link| + * - xgboost-regression-model + - True + - 1.2.3 + - 2.75.0 + - Regression + - `XGBoost `__ |external-link| diff --git a/doc/make.bat b/doc/make.bat index e75ff38602..e81d02187d 100644 --- a/doc/make.bat +++ b/doc/make.bat @@ -1,36 +1,36 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=python -msphinx -) -set SOURCEDIR=. -set BUILDDIR=_build -set SPHINXPROJ=sagemaker - -if "%1" == "" goto help - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The Sphinx module was not found. Make sure you have Sphinx installed, - echo.then set the SPHINXBUILD environment variable to point to the full - echo.path of the 'sphinx-build' executable. Alternatively you may add the - echo.Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% - -:end -popd +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=python -msphinx +) +set SOURCEDIR=. +set BUILDDIR=_build +set SPHINXPROJ=sagemaker + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The Sphinx module was not found. Make sure you have Sphinx installed, + echo.then set the SPHINXBUILD environment variable to point to the full + echo.path of the 'sphinx-build' executable. Alternatively you may add the + echo.Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index e56bb383fd..3229082c20 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -46,41 +46,53 @@ def _get_session_from_role(role: str, region: str): """ boto_session = boto3.Session(region_name=region) - sts = boto_session.client('sts', - region_name=region, - endpoint_url='https://sts.eu-west-1.amazonaws.com') + sts = boto_session.client( + "sts", region_name=region, endpoint_url="https://sts.eu-west-1.amazonaws.com" + ) - metadata = sts.assume_role(RoleArn=role, - RoleSessionName='SagemakerExecution') + metadata = sts.assume_role(RoleArn=role, RoleSessionName="SagemakerExecution") - access_key_id = metadata['Credentials']['AccessKeyId'] - secret_access_key = metadata['Credentials']['SecretAccessKey'] - session_token = metadata['Credentials']['SessionToken'] + access_key_id = metadata["Credentials"]["AccessKeyId"] + secret_access_key = metadata["Credentials"]["SecretAccessKey"] + session_token = metadata["Credentials"]["SessionToken"] - boto_session = boto3.session.Session(region_name=region, - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - aws_session_token=session_token) + boto_session = boto3.session.Session( + region_name=region, + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + aws_session_token=session_token, + ) # Sessions - sagemaker_client = boto_session.client('sagemaker') - sagemaker_runtime = boto_session.client('sagemaker-runtime') - sagemaker_featurestore_runtime_client = boto_session.client(service_name='sagemaker-featurestore-runtime') - sagemaker_session = Session(boto_session=boto_session, - sagemaker_client=sagemaker_client, - sagemaker_runtime_client=sagemaker_runtime, - sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client) + sagemaker_client = boto_session.client("sagemaker") + sagemaker_runtime = boto_session.client("sagemaker-runtime") + sagemaker_featurestore_runtime_client = boto_session.client( + service_name="sagemaker-featurestore-runtime" + ) + sagemaker_session = Session( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=sagemaker_runtime, + sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client, + ) return sagemaker_session -def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, - query: str = str('SELECT * FROM "sagemaker_featurestore"."#{table}" WHERE ' - + 'is_deleted=False'), - role: str = None, region: str = None, session=None, - event_time_feature_name: str = None, latest_ingestion: bool = True, - verbose: bool = True, - **pandas_read_csv_kwargs) -> DataFrame: +def get_feature_group_as_dataframe( + feature_group_name: str, + athena_bucket: str, + query: str = str( + 'SELECT * FROM "sagemaker_featurestore"."#{table}" WHERE ' + "is_deleted=False" + ), + role: str = None, + region: str = None, + session=None, + event_time_feature_name: str = None, + latest_ingestion: bool = True, + verbose: bool = True, + **pandas_read_csv_kwargs, +) -> DataFrame: """ Description: Method to run an athena query over a Feature Group in a Feature Store to retrieve its data. 
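A minimal usage sketch of get_feature_group_as_dataframe as reworked above; the feature group name, Athena results bucket, and role ARN are hypothetical placeholders (not values from this change), and low_memory is simply forwarded to pandas.read_csv:

    from sagemaker.feature_group_utils import get_feature_group_as_dataframe

    # Pull the latest non-deleted ingestion of a feature group into a DataFrame.
    # All resource names below are placeholders.
    df = get_feature_group_as_dataframe(
        feature_group_name="my-feature-group",                    # hypothetical
        athena_bucket="s3://my-bucket/athena-query-results",      # hypothetical
        role="arn:aws:iam::123456789012:role/SageMakerExecRole",  # hypothetical
        region="eu-west-1",
        latest_ingestion=True,
        event_time_feature_name="data_as_of_date",  # required when latest_ingestion=True
        low_memory=False,  # forwarded to pandas.read_csv via **pandas_read_csv_kwargs
    )
    print(df.shape)

A SageMaker Session can be passed instead of role and region when one is already available.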
@@ -113,38 +125,38 @@ def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, if latest_ingestion: if event_time_feature_name is not None: - query += str(f'AND {event_time_feature_name}=(SELECT MAX({event_time_feature_name}) FROM ' - + f'"sagemaker_featurestore"."{feature_group_name}")') - query += ';' + query += str( + f"AND {event_time_feature_name}=(SELECT MAX({event_time_feature_name}) FROM " + + f'"sagemaker_featurestore"."{feature_group_name}")' + ) + query += ";" if session is not None: sagemaker_session = session elif role is not None and region is not None: sagemaker_session = _get_session_from_role(role=role, region=region) else: - exc = Exception('Argument Session or role and region must be specified.') + exc = Exception("Argument Session or role and region must be specified.") logger.exception(exc) raise exc - logger.info(f'Feature Group used: {feature_group_name}\n') + logger.info(f"Feature Group used: {feature_group_name}\n") - fg = FeatureGroup(name=feature_group_name, - sagemaker_session=sagemaker_session) + fg = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session) sample_query = fg.athena_query() - query_string = re.sub(r'#\{(table)\}', sample_query.table_name, query) + query_string = re.sub(r"#\{(table)\}", sample_query.table_name, query) logger.info(f"Running query:\n\t{sample_query} \n\n\t-> Save on bucket {athena_bucket}\n") - sample_query.run(query_string=query_string, - output_location=athena_bucket) + sample_query.run(query_string=query_string, output_location=athena_bucket) sample_query.wait() # run Athena query. The output is loaded to a Pandas dataframe. dataset = sample_query.as_dataframe(**pandas_read_csv_kwargs) - logger.info(f'Data shape retrieve from {feature_group_name}: {dataset.shape}') + logger.info(f"Data shape retrieve from {feature_group_name}: {dataset.shape}") return dataset @@ -160,7 +172,7 @@ def _format_column_names(data: pandas.DataFrame) -> pandas.DataFrame: Returns: pandas.DataFrame """ - data.rename(columns=lambda x: x.replace(' ', '_').replace('.', '').lower()[:62], inplace=True) + data.rename(columns=lambda x: x.replace(" ", "_").replace(".", "").lower()[:62], inplace=True) return data @@ -175,23 +187,28 @@ def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: Returns: pandas.DataFrame """ - for label in data_frame.select_dtypes(['object', 'O']).columns.tolist(): + for label in data_frame.select_dtypes(["object", "O"]).columns.tolist(): data_frame[label] = data_frame[label].astype("str").astype("string") return data_frame -def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas.DataFrame], - feature_group_name: str, - role: str = None, region: str = None, session=None, - record_id: str = 'record_id', event_id: str = 'data_as_of_date', - verbose: bool = False, - **pandas_read_csv_kwargs) -> FeatureGroup: +def prepare_fg_from_dataframe_or_file( + dataframe_or_path: Union[str, Path, pandas.DataFrame], + feature_group_name: str, + role: str = None, + region: str = None, + session=None, + record_id: str = "record_id", + event_id: str = "data_as_of_date", + verbose: bool = False, + **pandas_read_csv_kwargs, +) -> FeatureGroup: """ Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame or a path to a file with proper dtypes, feature names and mandatory features (record_id, event_id). It needs the sagemaker.Session linked to a role or the role and region used to work Feature Stores. 
If record_id or event_id are not specified it will create ones by default with the names - + Args: feature_group_name (str): feature group name @@ -219,11 +236,15 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas if isinstance(dataframe_or_path, DataFrame): data = dataframe_or_path elif isinstance(dataframe_or_path, str): - pandas_read_csv_kwargs.pop('filepath_or_buffer', None) + pandas_read_csv_kwargs.pop("filepath_or_buffer", None) data = read_csv(filepath_or_buffer=dataframe_or_path, **pandas_read_csv_kwargs) else: - exc = Exception(str(f'Invalid type {type(dataframe_or_path)} for argument dataframe_or_path.' + - f'\nParameter must be of type pandas.DataFrame or string')) + exc = Exception( + str( + f"Invalid type {type(dataframe_or_path)} for argument dataframe_or_path." + + f"\nParameter must be of type pandas.DataFrame or string" + ) + ) logger.exception(exc) raise exc @@ -231,20 +252,25 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas data = _format_column_names(data=data) data = _cast_object_to_string(data_frame=data) - if record_id == 'record_id' and record_id not in data.columns: + if record_id == "record_id" and record_id not in data.columns: data[record_id] = data.index lg_uniq = len(data[record_id].unique()) lg_id = len(data[record_id]) if lg_id != lg_uniq: - exc = Exception(str(f'Record identifier {record_id} have {abs(lg_id - lg_uniq)} duplicated rows.' + - f'\nRecord identifier must be unique in each row.')) + exc = Exception( + str( + f"Record identifier {record_id} have {abs(lg_id - lg_uniq)} duplicated rows." + + f"\nRecord identifier must be unique in each row." + ) + ) logger.exception(exc) raise exc if event_id not in data.columns: import time + current_time_sec = int(round(time.time())) data[event_id] = Series([current_time_sec] * lg_id, dtype="float64") @@ -254,13 +280,11 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas elif role is not None and region is not None: sagemaker_session = _get_session_from_role(role=role, region=region) else: - exc = Exception('Argument Session or role and region must be specified.') + exc = Exception("Argument Session or role and region must be specified.") logger.exception(exc) raise exc - feature_group = FeatureGroup( - name=feature_group_name, sagemaker_session=sagemaker_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session) feature_group.load_feature_definitions(data_frame=data) diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index c1fd2165a8..e9c8bc9dd4 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -157,7 +157,7 @@ def as_dataframe(self, **pandas_read_csv_kwargs) -> DataFrame: ) # Assuring delimiter used by default - pandas_read_csv_kwargs.pop('delimiter', None) + pandas_read_csv_kwargs.pop("delimiter", None) return pd.read_csv(output_filename, delimiter=",", **pandas_read_csv_kwargs) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 4f48bc0af2..40d5890810 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -86,7 +86,7 @@ def name_from_base(base, max_length=63, short=False): def unique_name_from_base(base, max_length=63): """Placeholder Docstring""" random.seed(int(uuid.uuid4())) # using uuid to randomize, otherwise system timestamp is used. 
- unique = "%04x" % random.randrange(16 ** 4) # 4-digit hex + unique = "%04x" % random.randrange(16**4) # 4-digit hex ts = str(int(time.time())) available_length = max_length - 2 - len(ts) - len(unique) trimmed = base[:available_length] @@ -199,8 +199,8 @@ def secondary_training_status_changed(current_job_description, prev_job_descript """ current_secondary_status_transitions = current_job_description.get("SecondaryStatusTransitions") if ( - current_secondary_status_transitions is None - or len(current_secondary_status_transitions) == 0 + current_secondary_status_transitions is None + or len(current_secondary_status_transitions) == 0 ): return False @@ -213,7 +213,7 @@ def secondary_training_status_changed(current_job_description, prev_job_descript last_message = ( prev_job_secondary_status_transitions[-1]["StatusMessage"] if prev_job_secondary_status_transitions is not None - and len(prev_job_secondary_status_transitions) > 0 + and len(prev_job_secondary_status_transitions) > 0 else "" ) @@ -234,9 +234,9 @@ def secondary_training_status_message(job_description, prev_description): """ if ( - job_description is None - or job_description.get("SecondaryStatusTransitions") is None - or len(job_description.get("SecondaryStatusTransitions")) == 0 + job_description is None + or job_description.get("SecondaryStatusTransitions") is None + or len(job_description.get("SecondaryStatusTransitions")) == 0 ): return "" @@ -256,8 +256,8 @@ def secondary_training_status_message(job_description, prev_description): else: # Secondary status is changed we need to print all the entries. transitions_to_print = current_transitions[ - prev_transitions_num - len(current_transitions): - ] + prev_transitions_num - len(current_transitions) : + ] status_strs = [] for transition in transitions_to_print: @@ -320,7 +320,7 @@ def _download_files_under_prefix(bucket_name, prefix, target, s3): if obj_sum.key.endswith("/"): continue obj = s3.Object(obj_sum.bucket_name, obj_sum.key) - s3_relative_path = obj_sum.key[len(prefix):].lstrip("/") + s3_relative_path = obj_sum.key[len(prefix) :].lstrip("/") file_path = os.path.join(target, s3_relative_path) try: @@ -377,13 +377,13 @@ def _tmpdir(suffix="", prefix="tmp"): def repack_model( - inference_script, - source_directory, - dependencies, - model_uri, - repacked_model_uri, - sagemaker_session, - kms_key=None, + inference_script, + source_directory, + dependencies, + model_uri, + repacked_model_uri, + sagemaker_session, + kms_key=None, ): """Unpack model tarball and creates a new model tarball with the provided code script. @@ -470,7 +470,7 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key): def _create_or_update_code_dir( - model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp + model_dir, inference_script, source_directory, dependencies, sagemaker_session, tmp ): """Placeholder docstring""" code_dir = os.path.join(model_dir, "code") @@ -566,9 +566,9 @@ def sts_regional_endpoint(region): def retries( - max_retry_count, - exception_message_prefix, - seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS, + max_retry_count, + exception_message_prefix, + seconds_to_sleep=DEFAULT_SLEEP_TIME_SECONDS, ): """Retries until max retry count is reached. 
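
The hunks above and below are behaviour-preserving reformatting; the feature-group helpers touched earlier in this commit are still invoked the same way as in the integration tests later in this series. A minimal usage sketch, assuming a valid execution role, offline-store bucket and feature group (every name, ARN and path below is a placeholder):

from sagemaker.feature_group_utils import (
    get_feature_group_as_dataframe,
    prepare_fg_from_dataframe_or_file,
)

# Pull the latest non-deleted ingestion of a feature group into pandas via Athena.
df = get_feature_group_as_dataframe(
    feature_group_name="my-feature-group",                  # placeholder
    athena_bucket="s3://my-offline-store/query",            # placeholder
    role="arn:aws:iam::111122223333:role/SageMakerRole",    # placeholder
    region="eu-west-1",
    event_time_feature_name="data_as_of_date",              # placeholder feature name
    latest_ingestion=True,
)

# Build a FeatureGroup with feature definitions loaded from a CSV (or a DataFrame);
# missing record_id/event_id columns are generated and object columns cast to string.
fg = prepare_fg_from_dataframe_or_file(
    dataframe_or_path="data/my_table.csv",                  # placeholder
    feature_group_name="my-feature-group",                  # placeholder
    role="arn:aws:iam::111122223333:role/SageMakerRole",    # placeholder
    region="eu-west-1",
)
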
@@ -684,10 +684,10 @@ class S3DataConfig(DataConfig): """This class extends the DataConfig class to fetch a data config file hosted on S3""" def __init__( - self, - sagemaker_session, - bucket_name, - prefix, + self, + sagemaker_session, + bucket_name, + prefix, ): """Initialize a ``S3DataConfig`` instance. diff --git a/tests/data/upload_data_tests/file1.py b/tests/data/upload_data_tests/file1.py index a2f227b32d..2fb6da6d11 100644 --- a/tests/data/upload_data_tests/file1.py +++ b/tests/data/upload_data_tests/file1.py @@ -1,3 +1,3 @@ -""" -This is a file used in the upload_data tests in the test_session.py unit tests -""" +""" +This is a file used in the upload_data tests in the test_session.py unit tests +""" diff --git a/tests/data/upload_data_tests/file2.py b/tests/data/upload_data_tests/file2.py index a2f227b32d..2fb6da6d11 100644 --- a/tests/data/upload_data_tests/file2.py +++ b/tests/data/upload_data_tests/file2.py @@ -1,3 +1,3 @@ -""" -This is a file used in the upload_data tests in the test_session.py unit tests -""" +""" +This is a file used in the upload_data tests in the test_session.py unit tests +""" diff --git a/tests/data/upload_data_tests/nested_dir/file3.py b/tests/data/upload_data_tests/nested_dir/file3.py index a2f227b32d..2fb6da6d11 100644 --- a/tests/data/upload_data_tests/nested_dir/file3.py +++ b/tests/data/upload_data_tests/nested_dir/file3.py @@ -1,3 +1,3 @@ -""" -This is a file used in the upload_data tests in the test_session.py unit tests -""" +""" +This is a file used in the upload_data tests in the test_session.py unit tests +""" diff --git a/tests/data/upload_data_tests/nested_dir/file4.py b/tests/data/upload_data_tests/nested_dir/file4.py index a2f227b32d..2fb6da6d11 100644 --- a/tests/data/upload_data_tests/nested_dir/file4.py +++ b/tests/data/upload_data_tests/nested_dir/file4.py @@ -1,3 +1,3 @@ -""" -This is a file used in the upload_data tests in the test_session.py unit tests -""" +""" +This is a file used in the upload_data tests in the test_session.py unit tests +""" diff --git a/tests/data/workflow/dummy_data.csv b/tests/data/workflow/dummy_data.csv index 9935d338be..31fdc46ab1 100644 --- a/tests/data/workflow/dummy_data.csv +++ b/tests/data/workflow/dummy_data.csv @@ -1,7 +1,7 @@ -Class,Age,Sex,SurvivalStatus -1st,"Quantity[29., ""Years""]",female,survived -1st,"Quantity[0.9167, ""Years""]",male,survived -2nd,"Quantity[30., ""Years""]",male,died -2nd,"Quantity[28., ""Years""]",female,survived -3rd,"Quantity[16., ""Years""]",male,died +Class,Age,Sex,SurvivalStatus +1st,"Quantity[29., ""Years""]",female,survived +1st,"Quantity[0.9167, ""Years""]",male,survived +2nd,"Quantity[30., ""Years""]",male,died +2nd,"Quantity[28., ""Years""]",female,survived +3rd,"Quantity[16., ""Years""]",male,died 3rd,"Quantity[35., ""Years""]",female,survived \ No newline at end of file diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index baa4d07935..06026f9547 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -150,10 +150,10 @@ def create_table_ddl(): def test_create_feature_store_online_only( - feature_store_session, - role, - feature_group_name, - pandas_data_frame, + feature_store_session, + role, + feature_group_name, + pandas_data_frame, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -172,13 +172,13 @@ def test_create_feature_store_online_only( def 
test_create_feature_store( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame, - record, - create_table_ddl, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, + record, + create_table_ddl, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -229,23 +229,23 @@ def test_create_feature_store( for is_na in nans.items(): assert is_na assert ( - create_table_ddl.format( - feature_group_name=feature_group_name, - region=feature_store_session.boto_session.region_name, - account=feature_store_session.account_id(), - resolved_output_s3_uri=resolved_output_s3_uri, - ) - == feature_group.as_hive_ddl() + create_table_ddl.format( + feature_group_name=feature_group_name, + region=feature_store_session.boto_session.region_name, + account=feature_store_session.account_id(), + resolved_output_s3_uri=resolved_output_s3_uri, + ) + == feature_group.as_hive_ddl() ) assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}") def test_update_feature_group( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -269,11 +269,11 @@ def test_update_feature_group( def test_feature_metadata( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -318,11 +318,11 @@ def test_feature_metadata( def test_ingest_without_string_feature( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame_without_string, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame_without_string, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame_without_string) @@ -346,11 +346,11 @@ def test_ingest_without_string_feature( def test_ingest_multi_process( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -397,11 +397,11 @@ def _wait_for_feature_group_update(feature_group: FeatureGroup): def test_get_feature_group_with_role_region( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -420,22 +420,25 @@ def test_get_feature_group_with_role_region( data_frame=pandas_data_frame, 
max_workers=3, max_processes=2, wait=True ) - dataset = get_feature_group_as_dataframe(feature_group_name=feature_group_name, - region=region_name, role=role, - event_time_feature_name="feature3", - latest_ingestion=True, - athena_bucket=f'{offline_store_s3_uri}/query') + dataset = get_feature_group_as_dataframe( + feature_group_name=feature_group_name, + region=region_name, + role=role, + event_time_feature_name="feature3", + latest_ingestion=True, + athena_bucket=f"{offline_store_s3_uri}/query", + ) assert dataset.empty == False assert isinstance(dataset, DataFrame) def test_get_feature_group_with_session( - feature_store_session, - role, - feature_group_name, - offline_store_s3_uri, - pandas_data_frame, + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, ): feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) @@ -454,12 +457,14 @@ def test_get_feature_group_with_session( data_frame=pandas_data_frame, max_workers=3, max_processes=2, wait=True ) - dataset = get_feature_group_as_dataframe(feature_group_name=feature_group_name, - session=feature_store_session, - event_time_feature_name="feature3", - latest_ingestion=True, - athena_bucket=f'{offline_store_s3_uri}/query', - low_memory=False) # Using kwargs to pass a parameter to + dataset = get_feature_group_as_dataframe( + feature_group_name=feature_group_name, + session=feature_store_session, + event_time_feature_name="feature3", + latest_ingestion=True, + athena_bucket=f"{offline_store_s3_uri}/query", + low_memory=False, + ) # Using kwargs to pass a parameter to # pandas.read_csv assert dataset.empty == False From b3262306662589d65e36fab2786d06fa4e04456c Mon Sep 17 00:00:00 2001 From: JoseJuan98 Date: Wed, 23 Nov 2022 23:10:34 +0100 Subject: [PATCH 010/526] change: if latest_ingestion=True, event_time_feature_name in get_feature_group_as_dataframe must specified --- src/sagemaker/feature_group_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index e56bb383fd..ed5b86d336 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -115,6 +115,10 @@ def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, if event_time_feature_name is not None: query += str(f'AND {event_time_feature_name}=(SELECT MAX({event_time_feature_name}) FROM ' + f'"sagemaker_featurestore"."{feature_group_name}")') + else: + exc = Exception('Argument event_time_feature_name must be specified when using latest_ingestion=True.') + logger.exception(exc) + raise exc query += ';' if session is not None: From c97c4678b89ff89bda6f1ea4573803d7e1bcb947 Mon Sep 17 00:00:00 2001 From: Kevin Date: Fri, 2 Dec 2022 12:48:09 -0800 Subject: [PATCH 011/526] fix: type hint of PySparkProcessor __init__ (#3297) From de589419595fbf7bf76e55745f454864cc5998be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Perez?= Date: Fri, 2 Dec 2022 22:01:39 +0100 Subject: [PATCH 012/526] fix: fix PySparkProcessor __init__ params type (#3354) From 41dd3305c2673a4f85e54eec9858f37393c89431 Mon Sep 17 00:00:00 2001 From: Shreya Pandit Date: Fri, 2 Dec 2022 13:18:14 -0800 Subject: [PATCH 013/526] fix: Allow Py 3.7 for MMS Test Docker env (#3080) Co-authored-by: Mufaddal Rohawala --- tests/data/multimodel/container/Dockerfile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git 
a/tests/data/multimodel/container/Dockerfile b/tests/data/multimodel/container/Dockerfile index 4792a429c1..71c38a6605 100644 --- a/tests/data/multimodel/container/Dockerfile +++ b/tests/data/multimodel/container/Dockerfile @@ -1,4 +1,5 @@ -FROM public.ecr.aws/ubuntu/ubuntu:18.04 +# added latest image from https://gallery.ecr.aws/lts/ubuntu +FROM public.ecr.aws/ubuntu/ubuntu:22.04 # Set a docker label to advertise multi-model support on the container LABEL com.amazonaws.sagemaker.capabilities.multi-models=true @@ -15,7 +16,7 @@ RUN apt-get update && \ curl \ vim \ && rm -rf /var/lib/apt/lists/* \ - && curl -O https://bootstrap.pypa.io/pip/3.6/get-pip.py \ + && curl -O https://bootstrap.pypa.io/pip/get-pip.py \ && python3 get-pip.py RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1 From 1e23a3f6a7cf554aa537c5c4e21e35548053a6ee Mon Sep 17 00:00:00 2001 From: maldil Date: Fri, 2 Dec 2022 13:19:59 -0800 Subject: [PATCH 014/526] refactoring : using with statement (#3286) --- src/sagemaker/git_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/git_utils.py b/src/sagemaker/git_utils.py index 80bd62d5be..c424753286 100644 --- a/src/sagemaker/git_utils.py +++ b/src/sagemaker/git_utils.py @@ -279,9 +279,8 @@ def _run_clone_command(repo_url, dest_dir): subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) elif repo_url.startswith("git@"): with tempfile.NamedTemporaryFile() as sshnoprompt: - write_pipe = open(sshnoprompt.name, "w") - write_pipe.write("ssh -oBatchMode=yes $@") - write_pipe.close() + with open(sshnoprompt.name, "w") as write_pipe: + write_pipe.write("ssh -oBatchMode=yes $@") os.chmod(sshnoprompt.name, 0o511) my_env["GIT_SSH"] = sshnoprompt.name subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) From 19efadf043678a6c7da4122368d6141e1ec2df10 Mon Sep 17 00:00:00 2001 From: Shreya Pandit Date: Fri, 2 Dec 2022 13:21:34 -0800 Subject: [PATCH 015/526] Update local_requirements.txt PyYAML version (#3095) Co-authored-by: Basil Beirouti Co-authored-by: Kalyani Nikure <110067132+knikure@users.noreply.github.com> --- requirements/extras/local_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/extras/local_requirements.txt b/requirements/extras/local_requirements.txt index 5304d82b2a..5f2c85c2fe 100644 --- a/requirements/extras/local_requirements.txt +++ b/requirements/extras/local_requirements.txt @@ -1,4 +1,4 @@ urllib3==1.26.8 docker-compose==1.29.2 docker>=5.0.2,<7.0.0 -PyYAML==5.4.1 +PyYAML==6.0.0 From 76f7782db112b38cb7e058dffb1508f2d34fb50b Mon Sep 17 00:00:00 2001 From: arjkesh <33526713+arjkesh@users.noreply.github.com> Date: Fri, 2 Dec 2022 13:22:35 -0800 Subject: [PATCH 016/526] feature: Update TF 2.9 and TF 2.10 inference DLCs (#3465) --- .../image_uri_config/tensorflow.json | 66 ++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 6a01c3e3e6..0122dcd3ca 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -285,7 +285,9 @@ "2.5": "2.5.1", "2.6": "2.6.3", "2.7": "2.7.0", - "2.8": "2.8.0" + "2.8": "2.8.0", + "2.9": "2.9.2", + "2.10": "2.10.0" }, "versions": { "1.10.0": { @@ -1468,6 +1470,68 @@ "us-west-2": "763104351884" }, "repository": "tensorflow-inference" + }, + "2.9.2": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": 
"871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, + "2.10.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" } } }, From fde07388dc26cb270a0a0dfba91439c64e87751a Mon Sep 17 00:00:00 2001 From: Keshav Chandak Date: Sat, 3 Dec 2022 03:41:10 +0530 Subject: [PATCH 017/526] feature: Added transform with monitoring pipeline step in transformer (#3438) Co-authored-by: Keshav Chandak --- src/sagemaker/transformer.py | 158 +++++++++++++++++++++++++++++++- tests/integ/test_transformer.py | 66 ++++++++++++- 2 files changed, 220 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index cfcc637b99..97278abdd0 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -14,14 +14,17 @@ from __future__ import absolute_import from typing import Union, Optional, List, Dict -from botocore import exceptions +import logging +import copy +import time +from botocore import exceptions from sagemaker.job import _Job -from sagemaker.session import Session +from sagemaker.session import Session, get_execution_role from sagemaker.inputs import BatchDataCaptureConfig from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.functions import Join -from sagemaker.workflow.pipeline_context import runnable_by_pipeline +from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.utils import base_name_from_image, name_from_base @@ -266,6 +269,155 @@ def transform( if wait: self.latest_transform_job.wait(logs=logs) + def transform_with_monitoring( + self, + monitoring_config, + monitoring_resource_config, + data: str, + data_type: str = "S3Prefix", + content_type: str = None, + compression_type: str = None, + 
split_type: str = None, + input_filter: str = None, + output_filter: str = None, + join_source: str = None, + model_client_config: Dict[str, str] = None, + batch_data_capture_config: BatchDataCaptureConfig = None, + monitor_before_transform: bool = False, + supplied_baseline_statistics: str = None, + supplied_baseline_constraints: str = None, + wait: bool = True, + pipeline_name: str = None, + role: str = None, + ): + """Runs a transform job with monitoring job. + + Note that this function will not start a transform job immediately, + instead, it will create a SageMaker Pipeline and execute it. + If you provide an existing pipeline_name, no new pipeline will be created, otherwise, + each transform_with_monitoring call will create a new pipeline and execute. + + Args: + monitoring_config (Union[ + `sagemaker.workflow.quality_check_step.QualityCheckConfig`, + `sagemaker.workflow.quality_check_step.ClarifyCheckConfig` + ]): the monitoring configuration used for run model monitoring. + monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`): + the check job (processing job) cluster resource configuration. + transform_step_args (_JobStepArguments): the transform step transform arguments. + data (str): Input data location in S3 for the transform job + data_type (str): What the S3 location defines (default: 'S3Prefix'). + Valid values: + * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix + will be used as inputs for the transform job. + * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 + object to use as an input for the transform job. + content_type (str): MIME type of the input data (default: None). + compression_type (str): Compression type of the input data, if + compressed (default: None). Valid values: 'Gzip', None. + split_type (str): The record delimiter for the input object + (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and + 'TFRecord'. + input_filter (str): A JSONPath to select a portion of the input to + pass to the algorithm container for inference. If you omit the + field, it gets the value '$', representing the entire input. + For CSV data, each row is taken as a JSON array, + so only index-based JSONPaths can be applied, e.g. $[0], $[1:]. + CSV data should follow the `RFC format `_. + See `Supported JSONPath Operators + `_ + for a table of supported JSONPath operators. + For more information, see the SageMaker API documentation for + `CreateTransformJob + `_. + Some examples: "$[1:]", "$.features" (default: None). + output_filter (str): A JSONPath to select a portion of the + joined/original output to return as the output. + For more information, see the SageMaker API documentation for + `CreateTransformJob + `_. + Some examples: "$[1:]", "$.prediction" (default: None). + join_source (str): The source of data to be joined to the transform + output. It can be set to 'Input' meaning the entire input record + will be joined to the inference result. You can use OutputFilter + to select the useful portion before uploading to S3. (default: + None). Valid values: Input, None. + model_client_config (dict[str, str]): Model configuration. + Dictionary contains two optional keys, + 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. + (default: ``None``). + batch_data_capture_config (BatchDataCaptureConfig): Configuration object which + specifies the configurations related to the batch data capture for the transform job + (default: ``None``). 
+ monitor_before_transform (bgool): If to run data quality + or model explainability monitoring type, + a true value of this flag indicates running the check step before the transform job. + fail_on_violation (Union[bool, PipelineVariable]): A opt-out flag to not to fail the + check step when a violation is detected. + supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path + to the supplied statistics object representing the statistics JSON file + which will be used for drift to check (default: None). + supplied_baseline_constraints (Union[str, PipelineVariable]): The S3 path + to the supplied constraints object representing the constraints JSON file + which will be used for drift to check (default: None). + wait (bool): To determine if needed to wait for the pipeline execution to complete + pipeline_name (str): The name of the Pipeline for the monitoring and transfrom step + role (str): Execution role + """ + + transformer = self + if not isinstance(self.sagemaker_session, PipelineSession): + sagemaker_session = self.sagemaker_session + self.sagemaker_session = None + transformer = copy.deepcopy(self) + transformer.sagemaker_session = PipelineSession() + self.sagemaker_session = sagemaker_session + + transform_step_args = transformer.transform( + data=data, + data_type=data_type, + content_type=content_type, + compression_type=compression_type, + split_type=split_type, + input_filter=input_filter, + output_filter=output_filter, + batch_data_capture_config=batch_data_capture_config, + join_source=join_source, + model_client_config=model_client_config, + ) + + from sagemaker.workflow.monitor_batch_transform_step import MonitorBatchTransformStep + + monitoring_batch_step = MonitorBatchTransformStep( + name="MonitorBatchTransformStep", + display_name="MonitorBatchTransformStep", + description="", + transform_step_args=transform_step_args, + monitor_configuration=monitoring_config, + check_job_configuration=monitoring_resource_config, + monitor_before_transform=monitor_before_transform, + supplied_baseline_constraints=supplied_baseline_constraints, + supplied_baseline_statistics=supplied_baseline_statistics, + ) + + pipeline_name = ( + pipeline_name if pipeline_name else f"TransformWithMonitoring{int(time.time())}" + ) + # if pipeline exists, just start the execution + from sagemaker.workflow.pipeline import Pipeline + + pipeline = Pipeline( + name=pipeline_name, + steps=[monitoring_batch_step], + sagemaker_session=transformer.sagemaker_session, + ) + pipeline.upsert(role_arn=role if role else get_execution_role()) + execution = pipeline.start() + if wait: + logging.info("Waiting for transform with monitoring to execute ...") + execution.wait() + return execution + def delete_model(self): """Delete the corresponding SageMaker model for this Transformer.""" self.sagemaker_session.delete_model(self.model_name) diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index a0e37ffc77..1de333b987 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -25,6 +25,7 @@ from sagemaker.transformer import Transformer from sagemaker.estimator import Estimator from sagemaker.inputs import BatchDataCaptureConfig +from sagemaker.xgboost import XGBoostModel from sagemaker.utils import unique_name_from_base from tests.integ import ( datasets, @@ -36,7 +37,7 @@ from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer from tests.integ.vpc_test_utils import get_or_create_vpc_resources -from sagemaker.model_monitor 
import DatasetFormat, Statistics +from sagemaker.model_monitor import DatasetFormat, Statistics, Constraints from sagemaker.workflow.check_job_config import CheckJobConfig from sagemaker.workflow.quality_check_step import ( @@ -645,3 +646,66 @@ def _create_transformer_and_transform_job( job_name=unique_name_from_base("test-transform"), ) return transformer + + +def test_transformer_and_monitoring_job( + pipeline_session, + sagemaker_session, + role, + pipeline_name, + check_job_config, + data_bias_check_config, +): + xgb_model_data_s3 = pipeline_session.upload_data( + path=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + data_bias_supplied_baseline_constraints = Constraints.from_file_path( + constraints_file_path=os.path.join( + DATA_DIR, "pipeline/clarify_check_step/data_bias/good_cases/analysis.json" + ), + sagemaker_session=sagemaker_session, + ).file_s3_uri + + xgb_model = XGBoostModel( + model_data=xgb_model_data_s3, + framework_version="1.3-1", + role=role, + sagemaker_session=sagemaker_session, + entry_point=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "inference.py"), + enable_network_isolation=True, + ) + + xgb_model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE) + + transform_output = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform" + transformer = Transformer( + model_name=xgb_model.name, + strategy="SingleRecord", + instance_type="ml.m5.xlarge", + instance_count=1, + output_path=transform_output, + sagemaker_session=pipeline_session, + ) + + transform_input = pipeline_session.upload_data( + path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"), + key_prefix="integ-test-data/xgboost_abalone/abalone", + ) + + execution = transformer.transform_with_monitoring( + monitoring_config=data_bias_check_config, + monitoring_resource_config=check_job_config, + data=transform_input, + content_type="text/libsvm", + supplied_baseline_constraints=data_bias_supplied_baseline_constraints, + role=role, + ) + + execution_steps = execution.list_steps() + assert len(execution_steps) == 2 + + for execution_step in execution_steps: + assert execution_step["StepStatus"] == "Succeeded" + + xgb_model.delete_model() From 7f9f3b04b6704a4d2378b5d9aa3d37de9db45729 Mon Sep 17 00:00:00 2001 From: Clayton Parnell <42805768+claytonparnell@users.noreply.github.com> Date: Fri, 2 Dec 2022 17:12:34 -0500 Subject: [PATCH 018/526] fix: Fix bug forcing uploaded tar to be named sourcedir (#3412) --- src/sagemaker/processing.py | 19 +++++++++++-------- tests/integ/test_xgboost.py | 20 ++++++++++++++++++++ 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index db6ce2badd..308783578d 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -1587,13 +1587,13 @@ def run( # type: ignore[override] framework script to run.Path (absolute or relative) to the local Python source file which should be executed as the entry point to training. When `code` is an S3 URI, ignore `source_dir`, - `dependencies, and `git_config`. If ``source_dir`` is specified, + `dependencies`, and `git_config`. If ``source_dir`` is specified, then ``code`` must point to a file located at the root of ``source_dir``. source_dir (str): Path (absolute, relative or an S3 URI) to a directory with any other processing source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. 
Structure within this directory are preserved - when processing on Amazon SageMaker (default: None). + point to a file named `sourcedir.tar.gz`. Structure within this directory + are preserved when processing on Amazon SageMaker (default: None). dependencies (list[str]): A list of paths to directories (absolute or relative) with any additional libraries that will be exported to the container (default: []). The library folders will be @@ -1730,12 +1730,15 @@ def _pack_and_upload_code( "sagemaker_session unspecified when creating your Processor to have one set up " "automatically." ) + if "/sourcedir.tar.gz" in estimator.uploaded_code.s3_prefix: + # Upload the bootstrapping code as s3://.../jobname/source/runproc.sh. + entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace( + "sourcedir.tar.gz", + "runproc.sh", + ) + else: + raise RuntimeError("S3 source_dir file must be named `sourcedir.tar.gz.`") - # Upload the bootstrapping code as s3://.../jobname/source/runproc.sh. - entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace( - "sourcedir.tar.gz", - "runproc.sh", - ) script = estimator.uploaded_code.script_name s3_runproc_sh = S3Uploader.upload_string_as_file_body( self._generate_framework_script(script), diff --git a/tests/integ/test_xgboost.py b/tests/integ/test_xgboost.py index 733ab4665a..df06a8863a 100644 --- a/tests/integ/test_xgboost.py +++ b/tests/integ/test_xgboost.py @@ -40,6 +40,26 @@ def xgboost_training_job( ) +def test_sourcedir_naming( + sagemaker_session, + xgboost_latest_version, + xgboost_latest_py_version, + cpu_instance_type, +): + with pytest.raises(RuntimeError): + processor = XGBoostProcessor( + framework_version=xgboost_latest_version, + role=ROLE, + instance_count=1, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + ) + processor.run( + source_dir="s3://bucket/deps.tar.gz", + code="main_script.py", + ) + + @pytest.mark.release def test_framework_processing_job_with_deps( sagemaker_session, From 5d5976726cb8e0cf7143d86b4abb4b665842fd14 Mon Sep 17 00:00:00 2001 From: Navin Soni Date: Fri, 2 Dec 2022 14:32:01 -0800 Subject: [PATCH 019/526] feature: Add Code Owners file (#3503) Co-authored-by: Navin Soni --- CODEOWNERS | 1 + requirements/extras/local_requirements.txt | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 CODEOWNERS diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000000..7f7ac28644 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @aws/sagemaker-ml-frameworks diff --git a/requirements/extras/local_requirements.txt b/requirements/extras/local_requirements.txt index 5f2c85c2fe..5304d82b2a 100644 --- a/requirements/extras/local_requirements.txt +++ b/requirements/extras/local_requirements.txt @@ -1,4 +1,4 @@ urllib3==1.26.8 docker-compose==1.29.2 docker>=5.0.2,<7.0.0 -PyYAML==6.0.0 +PyYAML==5.4.1 From fb56c1d670358f478df0a319a6199287890c041a Mon Sep 17 00:00:00 2001 From: "jose-juan.pena-gomez@capgemini.com" Date: Fri, 2 Dec 2022 23:59:54 +0100 Subject: [PATCH 020/526] fix: linting and code style --- src/sagemaker/__init__.py | 2 +- src/sagemaker/feature_group_utils.py | 9 ++++----- src/sagemaker/utils.py | 2 +- tests/integ/test_feature_store.py | 7 ++++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 12fd7051f1..5d3a3680db 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -62,6 +62,6 @@ from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput # noqa: F401 from 
sagemaker.automl.candidate_estimator import CandidateEstimator, CandidateStep # noqa: F401 -from sagemaker.feature_group_utils import get_feature_group_as_dataframe +from sagemaker.feature_group_utils import get_feature_group_as_dataframe # noqa: F401 __version__ = importlib_metadata.version("sagemaker") diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index ed5b86d336..f0ddca828b 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -195,7 +195,6 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas a file with proper dtypes, feature names and mandatory features (record_id, event_id). It needs the sagemaker.Session linked to a role or the role and region used to work Feature Stores. If record_id or event_id are not specified it will create ones by default with the names - Args: feature_group_name (str): feature group name @@ -226,8 +225,8 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas pandas_read_csv_kwargs.pop('filepath_or_buffer', None) data = read_csv(filepath_or_buffer=dataframe_or_path, **pandas_read_csv_kwargs) else: - exc = Exception(str(f'Invalid type {type(dataframe_or_path)} for argument dataframe_or_path.' + - f'\nParameter must be of type pandas.DataFrame or string')) + exc = Exception(str(f'Invalid type {type(dataframe_or_path)} for argument dataframe_or_path.' + '\nParameter must be of type pandas.DataFrame or string')) logger.exception(exc) raise exc @@ -242,8 +241,8 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas lg_id = len(data[record_id]) if lg_id != lg_uniq: - exc = Exception(str(f'Record identifier {record_id} have {abs(lg_id - lg_uniq)} duplicated rows.' + - f'\nRecord identifier must be unique in each row.')) + exc = Exception(str(f'Record identifier {record_id} have {abs(lg_id - lg_uniq)} duplicated rows.' 
+ '\nRecord identifier must be unique in each row.')) logger.exception(exc) raise exc diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 5a838fcc21..faf83ed6cc 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -30,7 +30,7 @@ from typing import Optional import botocore -import boto3 + from six.moves.urllib import parse from sagemaker import deprecations diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index e1323b8259..a6bc9f7864 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -28,7 +28,6 @@ from sagemaker.session import get_execution_role, Session from tests.integ.timeout import timeout from sagemaker.feature_group_utils import get_feature_group_as_dataframe -from sagemaker.utils import get_session_from_role BUCKET_POLICY = { "Version": "2012-10-17", @@ -479,7 +478,8 @@ def test_get_feature_group_with_role_region( athena_bucket=f"{offline_store_s3_uri}/query", ) - assert dataset.empty == False + assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}") + assert not dataset.empty assert isinstance(dataset, DataFrame) @@ -517,7 +517,8 @@ def test_get_feature_group_with_session( ) # Using kwargs to pass a parameter to # pandas.read_csv - assert dataset.empty == False + assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}") + assert not dataset.empty assert isinstance(dataset, DataFrame) From 9a5996531efec1e484e44ec858d25e95d232c08f Mon Sep 17 00:00:00 2001 From: "jose-juan.pena-gomez@capgemini.com" Date: Sat, 3 Dec 2022 00:01:59 +0100 Subject: [PATCH 021/526] fix: __future__ anotations --- src/sagemaker/feature_group_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index f0ddca828b..6f801ee3b4 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -15,6 +15,8 @@ Utilities for working with FeatureGroups and FeatureStores. """ +from __future__ import absolute_import + import re import logging From 2809a354bd867b1b20993a763b7c0b284d73b64d Mon Sep 17 00:00:00 2001 From: JoseJuan98 Date: Sat, 3 Dec 2022 01:10:31 +0100 Subject: [PATCH 022/526] fix: docstyle --- src/sagemaker/feature_group_utils.py | 223 ++++++++++-------- src/sagemaker/utils.py | 48 +++- tests/integ/test_feature_store.py | 72 ++++-- .../feature_store/test_feature_definition.py | 7 +- .../feature_store/test_feature_group_utils.py | 18 +- .../feature_store/test_feature_store.py | 120 +++++++--- 6 files changed, 330 insertions(+), 158 deletions(-) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index 6f801ee3b4..cfbbbab310 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -10,11 +10,7 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -""" - -Utilities for working with FeatureGroups and FeatureStores. - -""" +"""Utilities for working with FeatureGroups and FeatureStores.""" from __future__ import absolute_import import re @@ -34,10 +30,10 @@ def _get_session_from_role(role: str, region: str): - """ - Method use to get the sagemaker session from a role and a region. Helpful in case it's - invoke from a session with a role without permission it can assume another role temporarily - to perform certain taks. 
+ """Method use to get the sagemaker session from a role and a region. + + Helpful in case it's invoke from a session with a role without permission it can assume + another role temporarily to perform certain taks. Args: role: role name @@ -48,61 +44,75 @@ def _get_session_from_role(role: str, region: str): """ boto_session = boto3.Session(region_name=region) - sts = boto_session.client('sts', - region_name=region, - endpoint_url='https://sts.eu-west-1.amazonaws.com') + sts = boto_session.client( + "sts", region_name=region, endpoint_url="https://sts.eu-west-1.amazonaws.com" + ) - metadata = sts.assume_role(RoleArn=role, - RoleSessionName='SagemakerExecution') + metadata = sts.assume_role(RoleArn=role, RoleSessionName="SagemakerExecution") - access_key_id = metadata['Credentials']['AccessKeyId'] - secret_access_key = metadata['Credentials']['SecretAccessKey'] - session_token = metadata['Credentials']['SessionToken'] + access_key_id = metadata["Credentials"]["AccessKeyId"] + secret_access_key = metadata["Credentials"]["SecretAccessKey"] + session_token = metadata["Credentials"]["SessionToken"] - boto_session = boto3.session.Session(region_name=region, - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - aws_session_token=session_token) + boto_session = boto3.session.Session( + region_name=region, + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + aws_session_token=session_token, + ) # Sessions - sagemaker_client = boto_session.client('sagemaker') - sagemaker_runtime = boto_session.client('sagemaker-runtime') - sagemaker_featurestore_runtime_client = boto_session.client(service_name='sagemaker-featurestore-runtime') - sagemaker_session = Session(boto_session=boto_session, - sagemaker_client=sagemaker_client, - sagemaker_runtime_client=sagemaker_runtime, - sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client) + sagemaker_client = boto_session.client("sagemaker") + sagemaker_runtime = boto_session.client("sagemaker-runtime") + runtime_client = boto_session.client(service_name="sagemaker-featurestore-runtime") + sagemaker_session = Session( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=sagemaker_runtime, + sagemaker_featurestore_runtime_client=runtime_client, + ) return sagemaker_session -def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, - query: str = str('SELECT * FROM "sagemaker_featurestore"."#{table}" WHERE ' - + 'is_deleted=False'), - role: str = None, region: str = None, session=None, - event_time_feature_name: str = None, latest_ingestion: bool = True, - verbose: bool = True, - **pandas_read_csv_kwargs) -> DataFrame: - """ +def get_feature_group_as_dataframe( + feature_group_name: str, + athena_bucket: str, + query: str = str( + "SELECT * FROM " '"sagemaker_featurestore"."#{table}" ' "WHERE is_deleted=False" + ), + role: str = None, + region: str = None, + session=None, + event_time_feature_name: str = None, + latest_ingestion: bool = True, + verbose: bool = True, + **pandas_read_csv_kwargs, +) -> DataFrame: + """Get a feature group as a pandas.DataFrame + Description: - Method to run an athena query over a Feature Group in a Feature Store to retrieve its data. - It needs the sagemaker.Session linked to a role or the role and region used to work Feature Stores. - Returns a dataframe with the data. 
+ Method to run an athena query over a Feature Group in a Feature Store + to retrieve its data.It needs the sagemaker.Session linked to a role + or the role and region used to work Feature Stores.Returns a dataframe + with the data. Args: region (str): region of the target feature store feature_group_name (str): feature store name - query (str): query to run. By default, it will take the latest ingest with data that wasn't deleted. - If latest_ingestion is False it will take all the data in the feature group that wasn't - deleted. It needs to use the keyword "#{table}" to refer to the table. e.g.: + query (str): query to run. By default, it will take the latest ingest with data that + wasn't deleted. If latest_ingestion is False it will take all the data + in the feature group that wasn't deleted. It needs to use the keyword + "#{table}" to refer to the table. e.g.: 'SELECT * FROM "sagemaker_featurestore"."#{table}"' athena_bucket (str): S3 bucket for running the query role (str): role of the account used to extract data from feature store session (str): session of SageMaker used to work with the feature store - event_time_feature_name (str): eventTimeId feature. Mandatory only if the latest ingestion is True - latest_ingestion (bool): if True it will get the data only from the latest ingestion. If False it - will take whatever is specified in the query, or if not specify it, it will - get all the data that wasn't deleted. + event_time_feature_name (str): eventTimeId feature. Mandatory only if the + latest ingestion is True + latest_ingestion (bool): if True it will get the data only from the latest ingestion. + If False it will take whatever is specified in the query, or + if not specify it, it will get all the data that wasn't deleted. verbose (bool): if True show messages, if False is silent. Returns: @@ -115,50 +125,57 @@ def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, if latest_ingestion: if event_time_feature_name is not None: - query += str(f'AND {event_time_feature_name}=(SELECT MAX({event_time_feature_name}) FROM ' - + f'"sagemaker_featurestore"."{feature_group_name}")') + query += str( + f"AND {event_time_feature_name}=(SELECT " + f"MAX({event_time_feature_name}) FROM " + + f'"sagemaker_featurestore"."{feature_group_name}")' + ) else: - exc = Exception('Argument event_time_feature_name must be specified when using latest_ingestion=True.') + exc = Exception( + "Argument event_time_feature_name must be specified " + "when using latest_ingestion=True." 
+ ) logger.exception(exc) raise exc - query += ';' + query += ";" if session is not None: sagemaker_session = session elif role is not None and region is not None: sagemaker_session = _get_session_from_role(role=role, region=region) else: - exc = Exception('Argument Session or role and region must be specified.') + exc = Exception("Argument Session or role and region must be specified.") logger.exception(exc) raise exc - logger.info(f'Feature Group used: {feature_group_name}\n') + logger.info(f"Feature Group used: {feature_group_name}") - fg = FeatureGroup(name=feature_group_name, - sagemaker_session=sagemaker_session) + fg = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session) sample_query = fg.athena_query() - query_string = re.sub(r'#\{(table)\}', sample_query.table_name, query) + query_string = re.sub(r"#\{(table)\}", sample_query.table_name, query) - logger.info(f"Running query:\n\t{sample_query} \n\n\t-> Save on bucket {athena_bucket}\n") + logger.info( + f"Running query:\n\t{sample_query} \n\n\t-> Save on bucket {athena_bucket}\n" + ) - sample_query.run(query_string=query_string, - output_location=athena_bucket) + sample_query.run(query_string=query_string, output_location=athena_bucket) sample_query.wait() # run Athena query. The output is loaded to a Pandas dataframe. dataset = sample_query.as_dataframe(**pandas_read_csv_kwargs) - logger.info(f'Data shape retrieve from {feature_group_name}: {dataset.shape}') + logger.info(f"Data shape retrieve from {feature_group_name}: {dataset.shape}") return dataset def _format_column_names(data: pandas.DataFrame) -> pandas.DataFrame: - """ - Module to format correctly the name of the columns of a DataFrame to later generate the features names - of a Feature Group + """Format the column names for a FeatureGroup + + Module to format correctly the name of the columns of a DataFrame + to later generate the features names of a Feature Group Args: data (pandas.DataFrame): dataframe used @@ -166,14 +183,18 @@ def _format_column_names(data: pandas.DataFrame) -> pandas.DataFrame: Returns: pandas.DataFrame """ - data.rename(columns=lambda x: x.replace(' ', '_').replace('.', '').lower()[:62], inplace=True) + data.rename( + columns=lambda x: x.replace(" ", "_").replace(".", "").lower()[:62], + inplace=True, + ) return data def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: - """ - Method to convert 'object' and 'O' column dtypes of a pandas.DataFrame to a valid string type recognized - by Feature Groups. + """Cast properly pandas object types to strings + + Method to convert 'object' and 'O' column dtypes of a pandas.DataFrame to + a valid string type recognized by Feature Groups. 
Args: data_frame: dataframe used @@ -181,32 +202,41 @@ def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: Returns: pandas.DataFrame """ - for label in data_frame.select_dtypes(['object', 'O']).columns.tolist(): + for label in data_frame.select_dtypes(["object", "O"]).columns.tolist(): data_frame[label] = data_frame[label].astype("str").astype("string") return data_frame -def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas.DataFrame], - feature_group_name: str, - role: str = None, region: str = None, session=None, - record_id: str = 'record_id', event_id: str = 'data_as_of_date', - verbose: bool = False, - **pandas_read_csv_kwargs) -> FeatureGroup: - """ - Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame or a path to - a file with proper dtypes, feature names and mandatory features (record_id, event_id). - It needs the sagemaker.Session linked to a role or the role and region used to work Feature Stores. - If record_id or event_id are not specified it will create ones by default with the names +def prepare_fg_from_dataframe_or_file( + dataframe_or_path: Union[str, Path, pandas.DataFrame], + feature_group_name: str, + role: str = None, + region: str = None, + session=None, + record_id: str = "record_id", + event_id: str = "data_as_of_date", + verbose: bool = False, + **pandas_read_csv_kwargs, +) -> FeatureGroup: + """Module to prepare a dataframe before creating Feature Group + + Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame + or a path to a file with proper dtypes, feature names and mandatory features (record_id, + event_id). It needs the sagemaker.Session linked to a role or the role and region used + to work Feature Stores. If record_id or event_id are not specified it will create ones + by default with the names Args: feature_group_name (str): feature group name dataframe_or_path (str, Path, pandas.DataFrame) : pandas.DataFrame or path to the data verbose (bool) : True for displaying messages, False for silent method. - record_id (str, 'record_id'): (Optional) Feature identifier of the rows. If specified each value of that feature - has to be unique. If not specified or record_id='record_id', then it will create - a new feature from the index of the pandas.DataFrame. - event_id (str) : (Optional) Feature with the time of the creation of data rows. If not specified it - will create one with the current time called `data_as_of_date` + record_id (str, 'record_id'): (Optional) Feature identifier of the rows. If specified each + value of that feature has to be unique. If not specified or + record_id='record_id', then it will create a new feature from + the index of the pandas.DataFrame. + event_id (str) : (Optional) Feature with the time of the creation of data rows. + If not specified it will create one with the current time + called `data_as_of_date` role (str) : role used to get the session. region (str) : region used to get the session. 
session (str): session of SageMaker used to work with the feature store @@ -219,16 +249,19 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas if verbose: logger.setLevel(logging.INFO) - from sagemaker.feature_store.feature_group import FeatureGroup - if isinstance(dataframe_or_path, DataFrame): data = dataframe_or_path elif isinstance(dataframe_or_path, str): - pandas_read_csv_kwargs.pop('filepath_or_buffer', None) + pandas_read_csv_kwargs.pop("filepath_or_buffer", None) data = read_csv(filepath_or_buffer=dataframe_or_path, **pandas_read_csv_kwargs) else: - exc = Exception(str(f'Invalid type {type(dataframe_or_path)} for argument dataframe_or_path.' - '\nParameter must be of type pandas.DataFrame or string')) + exc = Exception( + str( + f"Invalid type {type(dataframe_or_path)} for " + "argument dataframe_or_path. \nParameter must be" + " of type pandas.DataFrame or string" + ) + ) logger.exception(exc) raise exc @@ -236,20 +269,26 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas data = _format_column_names(data=data) data = _cast_object_to_string(data_frame=data) - if record_id == 'record_id' and record_id not in data.columns: + if record_id == "record_id" and record_id not in data.columns: data[record_id] = data.index lg_uniq = len(data[record_id].unique()) lg_id = len(data[record_id]) if lg_id != lg_uniq: - exc = Exception(str(f'Record identifier {record_id} have {abs(lg_id - lg_uniq)} duplicated rows.' - '\nRecord identifier must be unique in each row.')) + exc = Exception( + str( + f"Record identifier {record_id} have {abs(lg_id - lg_uniq)} " + "duplicated rows. \nRecord identifier must be unique" + " in each row." + ) + ) logger.exception(exc) raise exc if event_id not in data.columns: import time + current_time_sec = int(round(time.time())) data[event_id] = Series([current_time_sec] * lg_id, dtype="float64") @@ -259,7 +298,7 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas elif role is not None and region is not None: sagemaker_session = _get_session_from_role(role=role, region=region) else: - exc = Exception('Argument Session or role and region must be specified.') + exc = Exception("Argument Session or role and region must be specified.") logger.exception(exc) raise exc diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index faf83ed6cc..b9f57c8650 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -82,17 +82,19 @@ def name_from_base(base, max_length=63, short=False): """ timestamp = sagemaker_short_timestamp() if short else sagemaker_timestamp() trimmed_base = base[: max_length - len(timestamp) - 1] - return "{}-{}".format(trimmed_base, timestamp) + return f"{trimmed_base}-{timestamp}" def unique_name_from_base(base, max_length=63): """Placeholder Docstring""" - random.seed(int(uuid.uuid4())) # using uuid to randomize, otherwise system timestamp is used. + random.seed( + int(uuid.uuid4()) + ) # using uuid to randomize, otherwise system timestamp is used. unique = "%04x" % random.randrange(16**4) # 4-digit hex ts = str(int(time.time())) available_length = max_length - 2 - len(ts) - len(unique) trimmed = base[:available_length] - return "{}-{}-{}".format(trimmed, ts, unique) + return f"{trimmed}-{ts}-{unique}" def base_name_from_image(image, default_base_name=None): @@ -199,7 +201,9 @@ def secondary_training_status_changed(current_job_description, prev_job_descript boolean: Whether the secondary status message of a training job changed or not. 
""" - current_secondary_status_transitions = current_job_description.get("SecondaryStatusTransitions") + current_secondary_status_transitions = current_job_description.get( + "SecondaryStatusTransitions" + ) if ( current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0 @@ -243,7 +247,9 @@ def secondary_training_status_message(job_description, prev_description): return "" prev_description_secondary_transitions = ( - prev_description.get("SecondaryStatusTransitions") if prev_description is not None else None + prev_description.get("SecondaryStatusTransitions") + if prev_description is not None + else None ) prev_transitions_num = ( len(prev_description["SecondaryStatusTransitions"]) @@ -443,7 +449,9 @@ def repack_model( with tarfile.open(tmp_model_path, mode="w:gz") as t: t.add(model_dir, arcname=os.path.sep) - _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key=kms_key) + _save_model( + repacked_model_uri, tmp_model_path, sagemaker_session, kms_key=kms_key + ) def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key): @@ -451,10 +459,14 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key): if repacked_model_uri.lower().startswith("s3://"): url = parse.urlparse(repacked_model_uri) bucket, key = url.netloc, url.path.lstrip("/") - new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri)) + new_key = key.replace( + os.path.basename(key), os.path.basename(repacked_model_uri) + ) settings = ( - sagemaker_session.settings if sagemaker_session is not None else SessionSettings() + sagemaker_session.settings + if sagemaker_session is not None + else SessionSettings() ) encrypt_artifact = settings.encrypt_repacked_artifacts @@ -501,7 +513,9 @@ def _create_or_update_code_dir( for dependency in dependencies: lib_dir = os.path.join(code_dir, "lib") if os.path.isdir(dependency): - shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency))) + shutil.copytree( + dependency, os.path.join(lib_dir, os.path.basename(dependency)) + ) else: if not os.path.exists(lib_dir): os.mkdir(lib_dir) @@ -731,7 +745,11 @@ def get_data_bucket(self, region_requested=None): """ config = self.fetch_data_config() - region = region_requested if region_requested else self.sagemaker_session.boto_region_name + region = ( + region_requested + if region_requested + else self.sagemaker_session.boto_region_name + ) return config[region] if region in config.keys() else config["default"] @@ -766,7 +784,11 @@ def update_container_with_inference_params( if container_list is not None: for obj in container_list: construct_container_object( - obj, data_input_configuration, framework, framework_version, nearest_model_name + obj, + data_input_configuration, + framework, + framework_version, + nearest_model_name, ) if container_def is not None: @@ -833,7 +855,9 @@ def construct_container_object( return obj -def pop_out_unused_kwarg(arg_name: str, kwargs: dict, override_val: Optional[str] = None): +def pop_out_unused_kwarg( + arg_name: str, kwargs: dict, override_val: Optional[str] = None +): """Pop out the unused key-word argument and give a warning. 
Args: diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index a6bc9f7864..55f89acf87 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -24,7 +24,11 @@ from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition from sagemaker.feature_store.feature_group import FeatureGroup -from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter, TableFormatEnum +from sagemaker.feature_store.inputs import ( + FeatureValue, + FeatureParameter, + TableFormatEnum, +) from sagemaker.session import get_execution_role, Session from tests.integ.timeout import timeout from sagemaker.feature_group_utils import get_feature_group_as_dataframe @@ -38,7 +42,9 @@ "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": ["s3:PutObject", "s3:PutObjectAcl"], "Resource": "arn:aws:s3:::{bucket_name}-{region_name}/*", - "Condition": {"StringEquals": {"s3:x-amz-acl": "bucket-owner-full-control"}}, + "Condition": { + "StringEquals": {"s3:x-amz-acl": "bucket-owner-full-control"} + }, }, { "Sid": "FeatureStoreOfflineStoreS3BucketPolicy", @@ -154,7 +160,9 @@ def test_create_feature_store_online_only( feature_group_name, pandas_data_frame, ): - feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup( + name=feature_group_name, sagemaker_session=feature_store_session + ) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -179,7 +187,9 @@ def test_create_feature_store( record, create_table_ddl, ): - feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup( + name=feature_group_name, sagemaker_session=feature_store_session + ) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -216,9 +226,9 @@ def test_create_feature_store( output_location=f"{offline_store_s3_uri}/query_results", ) athena_query.wait() - assert "SUCCEEDED" == athena_query.get_query_execution().get("QueryExecution").get( - "Status" - ).get("State") + assert "SUCCEEDED" == athena_query.get_query_execution().get( + "QueryExecution" + ).get("Status").get("State") df = athena_query.as_dataframe() print(f"Found {df.shape[0]} records.") time.sleep(60) @@ -246,7 +256,9 @@ def test_create_feature_group_iceberg_table_format( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup( + name=feature_group_name, sagemaker_session=feature_store_session + ) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -260,7 +272,9 @@ def test_create_feature_group_iceberg_table_format( ) _wait_for_feature_group_create(feature_group) - table_format = feature_group.describe().get("OfflineStoreConfig").get("TableFormat") + table_format = ( + feature_group.describe().get("OfflineStoreConfig").get("TableFormat") + ) assert table_format == "Iceberg" @@ -271,7 +285,9 @@ def test_create_feature_group_glue_table_format( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup( + name=feature_group_name, sagemaker_session=feature_store_session + ) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with 
cleanup_feature_group(feature_group): @@ -285,7 +301,9 @@ def test_create_feature_group_glue_table_format( ) _wait_for_feature_group_create(feature_group) - table_format = feature_group.describe().get("OfflineStoreConfig").get("TableFormat") + table_format = ( + feature_group.describe().get("OfflineStoreConfig").get("TableFormat") + ) assert table_format == "Glue" @@ -296,7 +314,9 @@ def test_update_feature_group( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup( + name=feature_group_name, sagemaker_session=feature_store_session + ) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -314,7 +334,9 @@ def test_update_feature_group( feature_group.update(new_features) _wait_for_feature_group_update(feature_group) feature_definitions = feature_group.describe().get("FeatureDefinitions") - assert any([True for elem in feature_definitions if new_feature_name in elem.values()]) + assert any( + [True for elem in feature_definitions if new_feature_name in elem.values()] + ) def test_feature_metadata( @@ -324,7 +346,9 @@ def test_feature_metadata( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup( + name=feature_group_name, sagemaker_session=feature_store_session + ) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -373,7 +397,9 @@ def test_ingest_without_string_feature( offline_store_s3_uri, pandas_data_frame_without_string, ): - feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup( + name=feature_group_name, sagemaker_session=feature_store_session + ) feature_group.load_feature_definitions(data_frame=pandas_data_frame_without_string) with cleanup_feature_group(feature_group): @@ -401,7 +427,9 @@ def test_ingest_multi_process( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup( + name=feature_group_name, sagemaker_session=feature_store_session + ) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -452,7 +480,9 @@ def test_get_feature_group_with_role_region( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup( + name=feature_group_name, sagemaker_session=feature_store_session + ) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -490,7 +520,9 @@ def test_get_feature_group_with_session( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup( + name=feature_group_name, sagemaker_session=feature_store_session + ) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -530,4 +562,6 @@ def cleanup_feature_group(feature_group: FeatureGroup): try: feature_group.delete() except Exception: - raise RuntimeError(f"Failed to delete feature group with name {feature_group.name}") + raise RuntimeError( + f"Failed to delete feature group 
with name {feature_group.name}" + ) diff --git a/tests/unit/sagemaker/feature_store/test_feature_definition.py b/tests/unit/sagemaker/feature_store/test_feature_definition.py index d2f6a24be7..3c8a4b9d81 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_definition.py +++ b/tests/unit/sagemaker/feature_store/test_feature_definition.py @@ -26,12 +26,13 @@ def ordered(obj): return sorted((k, ordered(v)) for k, v in obj.items()) if isinstance(obj, list): return sorted(ordered(x) for x in obj) - else: - return obj + return obj def test_feature_definition(): - definition = FeatureDefinition(feature_name="MyFeature", feature_type=FeatureTypeEnum.INTEGRAL) + definition = FeatureDefinition( + feature_name="MyFeature", feature_type=FeatureTypeEnum.INTEGRAL + ) assert ordered(definition.to_dict()) == ordered( { "FeatureName": "MyFeature", diff --git a/tests/unit/sagemaker/feature_store/test_feature_group_utils.py b/tests/unit/sagemaker/feature_store/test_feature_group_utils.py index 632c4a0f26..377ec090da 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_group_utils.py +++ b/tests/unit/sagemaker/feature_store/test_feature_group_utils.py @@ -11,13 +11,17 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. # language governing permissions and limitations under the License. +"""Test for Feature Group Utils""" from __future__ import absolute_import import pandas as pd import pytest from mock import Mock -from sagemaker.feature_group_utils import _cast_object_to_string, prepare_fg_from_dataframe_or_file +from sagemaker.feature_group_utils import ( + _cast_object_to_string, + prepare_fg_from_dataframe_or_file, +) from sagemaker.feature_store.feature_definition import ( FeatureTypeEnum, ) @@ -27,17 +31,23 @@ class PicklableMock(Mock): + """Mock class use for tests""" + def __reduce__(self): + """Method from class Mock""" return (Mock, ()) @pytest.fixture def sagemaker_session_mock(): + """Fixture Mock class""" return Mock() def test_convert_unsupported_types_to_supported(sagemaker_session_mock): - feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="FailedGroup", sagemaker_session=sagemaker_session_mock + ) df = pd.DataFrame( { "float": pd.Series([2.0], dtype="float64"), @@ -69,7 +79,9 @@ def test_prepare_fg_from_dataframe(sagemaker_session_mock): ) feature_group = prepare_fg_from_dataframe_or_file( - dataframe_or_path=df, session=sagemaker_session_mock, feature_group_name="testFG" + dataframe_or_path=df, + session=sagemaker_session_mock, + feature_group_name="testFG", ) names = [fd.feature_name for fd in feature_group.feature_definitions] diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py index 92ba35573c..180651e3bc 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_store.py +++ b/tests/unit/sagemaker/feature_store/test_feature_store.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. # language governing permissions and limitations under the License. 
+"""Test for Feature Store""" from __future__ import absolute_import @@ -38,6 +39,8 @@ class PicklableMock(Mock): + """Mock class use for tests""" + def __reduce__(self): return (Mock, ()) @@ -93,7 +96,9 @@ def create_table_ddl(): def test_feature_store_create( sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri ): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock + ) feature_group.feature_definitions = feature_group_dummy_definitions feature_group.create( s3_uri=s3_uri, @@ -121,7 +126,9 @@ def test_feature_store_create( def test_feature_store_create_iceberg_table_format( sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri ): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock + ) feature_group.feature_definitions = feature_group_dummy_definitions feature_group.create( s3_uri=s3_uri, @@ -152,7 +159,9 @@ def test_feature_store_create_iceberg_table_format( def test_feature_store_create_glue_table_format( sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri ): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock + ) feature_group.feature_definitions = feature_group_dummy_definitions feature_group.create( s3_uri=s3_uri, @@ -183,7 +192,9 @@ def test_feature_store_create_glue_table_format( def test_feature_store_create_online_only( sagemaker_session_mock, role_arn, feature_group_dummy_definitions ): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock + ) feature_group.feature_definitions = feature_group_dummy_definitions feature_group.create( s3_uri=False, @@ -205,7 +216,9 @@ def test_feature_store_create_online_only( def test_feature_store_delete(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock + ) feature_group.delete() sagemaker_session_mock.delete_feature_group.assert_called_with( feature_group_name="MyFeatureGroup" @@ -213,7 +226,9 @@ def test_feature_store_delete(sagemaker_session_mock): def test_feature_store_describe(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock + ) feature_group.describe() sagemaker_session_mock.describe_feature_group.assert_called_with( feature_group_name="MyFeatureGroup", next_token=None @@ -221,7 +236,9 @@ def test_feature_store_describe(sagemaker_session_mock): def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock + ) feature_group.update(feature_group_dummy_definitions) sagemaker_session_mock.update_feature_group.assert_called_with( feature_group_name="MyFeatureGroup", @@ -230,7 
+247,9 @@ def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_defini def test_feature_metadata_update(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock + ) parameter_additions = [FeatureParameter(key="key1", value="value1")] parameter_removals = ["key2"] @@ -248,7 +267,9 @@ def test_feature_metadata_update(sagemaker_session_mock): parameter_additions=[pa.to_dict() for pa in parameter_additions], parameter_removals=parameter_removals, ) - feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription") + feature_group.update_feature_metadata( + feature_name="Feature1", description="TestDescription" + ) sagemaker_session_mock.update_feature_metadata.assert_called_with( feature_group_name="MyFeatureGroup", feature_name="Feature1", @@ -259,7 +280,9 @@ def test_feature_metadata_update(sagemaker_session_mock): def test_feature_metadata_describe(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock + ) feature_group.describe_feature_metadata(feature_name="Feature1") sagemaker_session_mock.describe_feature_metadata.assert_called_with( feature_group_name="MyFeatureGroup", feature_name="Feature1" @@ -267,7 +290,9 @@ def test_feature_metadata_describe(sagemaker_session_mock): def test_put_record(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock + ) feature_group.put_record(record=[]) sagemaker_session_mock.put_record.assert_called_with( feature_group_name="MyFeatureGroup", record=[] @@ -275,7 +300,9 @@ def test_put_record(sagemaker_session_mock): def test_load_feature_definition(sagemaker_session_mock): - feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="SomeGroup", sagemaker_session=sagemaker_session_mock + ) df = pd.DataFrame( { "float": pd.Series([2.0], dtype="float64"), @@ -295,7 +322,9 @@ def test_load_feature_definition(sagemaker_session_mock): def test_load_feature_definition_unsupported_types(sagemaker_session_mock): - feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="FailedGroup", sagemaker_session=sagemaker_session_mock + ) df = pd.DataFrame( { "float": pd.Series([2.0], dtype="float64"), @@ -305,11 +334,16 @@ def test_load_feature_definition_unsupported_types(sagemaker_session_mock): ) with pytest.raises(ValueError) as error: feature_group.load_feature_definitions(data_frame=df) - assert "Failed to infer Feature type based on dtype object for column object." in str(error) + assert ( + "Failed to infer Feature type based on dtype object for column object." 
+ in str(error) + ) def test_ingest_zero_processes(): - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyGroup", sagemaker_session=sagemaker_session_mock + ) df = Mock() with pytest.raises(RuntimeError) as error: feature_group.ingest(data_frame=df, max_workers=1, max_processes=0) @@ -318,7 +352,9 @@ def test_ingest_zero_processes(): def test_ingest_zero_workers(): - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyGroup", sagemaker_session=sagemaker_session_mock + ) df = Mock() with pytest.raises(RuntimeError) as error: feature_group.ingest(data_frame=df, max_workers=0, max_processes=1) @@ -327,13 +363,19 @@ def test_ingest_zero_workers(): @patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") -def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock): +def test_ingest( + ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock +): sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( fs_runtime_client_config_mock ) - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) + feature_group = FeatureGroup( + name="MyGroup", sagemaker_session=sagemaker_session_mock + ) + df = pd.DataFrame( + dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)) + ) mock_ingestion_manager_instance = Mock() ingestion_manager_init.return_value = mock_ingestion_manager_instance @@ -359,8 +401,12 @@ def test_ingest_with_profile_name( fs_runtime_client_config_mock ) - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) + feature_group = FeatureGroup( + name="MyGroup", sagemaker_session=sagemaker_session_mock + ) + df = pd.DataFrame( + dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)) + ) mock_ingestion_manager_instance = Mock() ingestion_manager_init.return_value = mock_ingestion_manager_instance @@ -392,7 +438,9 @@ def test_as_hive_ddl_with_default_values( sagemaker_session_mock.account_id.return_value = "1234" sagemaker_session_mock.boto_session.region_name = "us-west-2" - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyGroup", sagemaker_session=sagemaker_session_mock + ) feature_group.feature_definitions = feature_group_dummy_definitions assert ( create_table_ddl.format( @@ -406,7 +454,9 @@ def test_as_hive_ddl_with_default_values( ) -def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock): +def test_as_hive_ddl( + create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock +): sagemaker_session_mock.describe_feature_group.return_value = { "OfflineStoreConfig": { "S3StorageConfig": { @@ -418,7 +468,9 @@ def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemake sagemaker_session_mock.account_id.return_value = "1234" sagemaker_session_mock.boto_session.region_name = "us-west-2" - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group = FeatureGroup( + name="MyGroup", sagemaker_session=sagemaker_session_mock + ) feature_group.feature_definitions = 
feature_group_dummy_definitions assert create_table_ddl.format( database="MyDatabase", @@ -442,7 +494,9 @@ def test_ingestion_manager_run_success(): ) manager.run(df) - manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None) + manager._run_multi_process.assert_called_once_with( + data_frame=df, wait=True, timeout=None + ) @patch( @@ -534,9 +588,13 @@ def query(sagemaker_session_mock): def test_athena_query_run(sagemaker_session_mock, query): WORKGROUP = "workgroup" - sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"} + sagemaker_session_mock.start_query_execution.return_value = { + "QueryExecutionId": "query_id" + } query.run( - query_string="query", output_location="s3://some-bucket/some-path", workgroup=WORKGROUP + query_string="query", + output_location="s3://some-bucket/some-path", + workgroup=WORKGROUP, ) sagemaker_session_mock.start_query_execution.assert_called_with( catalog="catalog", @@ -554,13 +612,17 @@ def test_athena_query_run(sagemaker_session_mock, query): def test_athena_query_wait(sagemaker_session_mock, query): query._current_query_execution_id = "query_id" query.wait() - sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id") + sagemaker_session_mock.wait_for_athena_query.assert_called_with( + query_execution_id="query_id" + ) def test_athena_query_get_query_execution(sagemaker_session_mock, query): query._current_query_execution_id = "query_id" query.get_query_execution() - sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id") + sagemaker_session_mock.get_query_execution.assert_called_with( + query_execution_id="query_id" + ) @patch("tempfile.gettempdir", Mock(return_value="tmp")) From 205888a707e97bea7ffb01ddd4e4459f393bde80 Mon Sep 17 00:00:00 2001 From: JoseJuan98 Date: Sat, 3 Dec 2022 01:13:42 +0100 Subject: [PATCH 023/526] fix: docstyle --- tests/unit/sagemaker/feature_store/test_inputs.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/unit/sagemaker/feature_store/test_inputs.py b/tests/unit/sagemaker/feature_store/test_inputs.py index 5d35b11bb9..5da21be61f 100644 --- a/tests/unit/sagemaker/feature_store/test_inputs.py +++ b/tests/unit/sagemaker/feature_store/test_inputs.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. # language governing permissions and limitations under the License. 
+"""Test for Feature Inputs""" from __future__ import absolute_import from sagemaker.feature_store.inputs import ( @@ -30,8 +31,7 @@ def ordered(obj): return sorted((k, ordered(v)) for k, v in obj.items()) if isinstance(obj, list): return sorted(ordered(x) for x in obj) - else: - return obj + return obj def test_online_store_security_config(): @@ -89,7 +89,8 @@ def test_offline_data_store_config(): def test_offline_data_store_config_with_glue_table_format(): config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig(s3_uri="uri"), table_format=TableFormatEnum.GLUE + s3_storage_config=S3StorageConfig(s3_uri="uri"), + table_format=TableFormatEnum.GLUE, ) assert ordered(config.to_dict()) == ordered( { @@ -102,7 +103,8 @@ def test_offline_data_store_config_with_glue_table_format(): def test_offline_data_store_config_with_iceberg_table_format(): config = OfflineStoreConfig( - s3_storage_config=S3StorageConfig(s3_uri="uri"), table_format=TableFormatEnum.ICEBERG + s3_storage_config=S3StorageConfig(s3_uri="uri"), + table_format=TableFormatEnum.ICEBERG, ) assert ordered(config.to_dict()) == ordered( { From aa1466f9eda1c4b691ed68baab18b092813e1292 Mon Sep 17 00:00:00 2001 From: JoseJuan98 Date: Sat, 3 Dec 2022 01:28:54 +0100 Subject: [PATCH 024/526] fix: linting - errors with typing library --- .pylintrc | 1 + src/sagemaker/feature_group_utils.py | 158 ++++++++----------- src/sagemaker/feature_store/feature_group.py | 7 +- 3 files changed, 72 insertions(+), 94 deletions(-) diff --git a/.pylintrc b/.pylintrc index 9c16afcc22..4ac25d48c3 100644 --- a/.pylintrc +++ b/.pylintrc @@ -94,6 +94,7 @@ disable= useless-object-inheritance, # TODO: Enable this check and fix code once Python 2 is no longer supported. super-with-arguments, raise-missing-from, + E1136 # typing unsubscriptable-oject [REPORTS] # Set the output format. 
Available formats are text, parseable, colorized, msvs diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index cfbbbab310..874774a28d 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -44,51 +44,44 @@ def _get_session_from_role(role: str, region: str): """ boto_session = boto3.Session(region_name=region) - sts = boto_session.client( - "sts", region_name=region, endpoint_url="https://sts.eu-west-1.amazonaws.com" - ) + sts = boto_session.client('sts', + region_name=region, + endpoint_url='https://sts.eu-west-1.amazonaws.com') - metadata = sts.assume_role(RoleArn=role, RoleSessionName="SagemakerExecution") + metadata = sts.assume_role(RoleArn=role, + RoleSessionName='SagemakerExecution') - access_key_id = metadata["Credentials"]["AccessKeyId"] - secret_access_key = metadata["Credentials"]["SecretAccessKey"] - session_token = metadata["Credentials"]["SessionToken"] + access_key_id = metadata['Credentials']['AccessKeyId'] + secret_access_key = metadata['Credentials']['SecretAccessKey'] + session_token = metadata['Credentials']['SessionToken'] - boto_session = boto3.session.Session( - region_name=region, - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - aws_session_token=session_token, - ) + boto_session = boto3.session.Session(region_name=region, + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + aws_session_token=session_token) # Sessions - sagemaker_client = boto_session.client("sagemaker") - sagemaker_runtime = boto_session.client("sagemaker-runtime") - runtime_client = boto_session.client(service_name="sagemaker-featurestore-runtime") - sagemaker_session = Session( - boto_session=boto_session, - sagemaker_client=sagemaker_client, - sagemaker_runtime_client=sagemaker_runtime, - sagemaker_featurestore_runtime_client=runtime_client, - ) + sagemaker_client = boto_session.client('sagemaker') + sagemaker_runtime = boto_session.client('sagemaker-runtime') + runtime_client = boto_session.client(service_name='sagemaker-featurestore-runtime') + sagemaker_session = Session(boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=sagemaker_runtime, + sagemaker_featurestore_runtime_client=runtime_client) return sagemaker_session -def get_feature_group_as_dataframe( - feature_group_name: str, - athena_bucket: str, - query: str = str( - "SELECT * FROM " '"sagemaker_featurestore"."#{table}" ' "WHERE is_deleted=False" - ), - role: str = None, - region: str = None, - session=None, - event_time_feature_name: str = None, - latest_ingestion: bool = True, - verbose: bool = True, - **pandas_read_csv_kwargs, -) -> DataFrame: +def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, + query: str = str('SELECT * FROM ' + '"sagemaker_featurestore"."#{table}" ' + 'WHERE is_deleted=False'), + role: str = None, region: str = None, + session=None, + event_time_feature_name: str = None, + latest_ingestion: bool = True, + verbose: bool = True, + **pandas_read_csv_kwargs) -> DataFrame: """Get a feature group as a pandas.DataFrame Description: @@ -125,48 +118,47 @@ def get_feature_group_as_dataframe( if latest_ingestion: if event_time_feature_name is not None: - query += str( - f"AND {event_time_feature_name}=(SELECT " - f"MAX({event_time_feature_name}) FROM " - + f'"sagemaker_featurestore"."{feature_group_name}")' - ) + query += str(f'AND {event_time_feature_name}=(SELECT ' + f'MAX({event_time_feature_name}) FROM ' + + 
f'"sagemaker_featurestore"."{feature_group_name}")') else: - exc = Exception( - "Argument event_time_feature_name must be specified " - "when using latest_ingestion=True." - ) + exc = Exception('Argument event_time_feature_name must be specified ' + 'when using latest_ingestion=True.') logger.exception(exc) raise exc - query += ";" + query += ';' if session is not None: sagemaker_session = session elif role is not None and region is not None: sagemaker_session = _get_session_from_role(role=role, region=region) else: - exc = Exception("Argument Session or role and region must be specified.") + exc = Exception('Argument Session or role and region must be specified.') logger.exception(exc) raise exc - logger.info(f"Feature Group used: {feature_group_name}") + msg = f'Feature Group used: {feature_group_name}' + logger.info(msg) - fg = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session) + fg = FeatureGroup(name=feature_group_name, + sagemaker_session=sagemaker_session) sample_query = fg.athena_query() - query_string = re.sub(r"#\{(table)\}", sample_query.table_name, query) + query_string = re.sub(r'#\{(table)\}', sample_query.table_name, query) - logger.info( - f"Running query:\n\t{sample_query} \n\n\t-> Save on bucket {athena_bucket}\n" - ) + msg = f"Running query:\n\t{sample_query} \n\n\t-> Save on bucket {athena_bucket}\n" + logger.info(msg) - sample_query.run(query_string=query_string, output_location=athena_bucket) + sample_query.run(query_string=query_string, + output_location=athena_bucket) sample_query.wait() # run Athena query. The output is loaded to a Pandas dataframe. dataset = sample_query.as_dataframe(**pandas_read_csv_kwargs) - logger.info(f"Data shape retrieve from {feature_group_name}: {dataset.shape}") + msg = f'Data shape retrieve from {feature_group_name}: {dataset.shape}' + logger.info(msg) return dataset @@ -183,10 +175,7 @@ def _format_column_names(data: pandas.DataFrame) -> pandas.DataFrame: Returns: pandas.DataFrame """ - data.rename( - columns=lambda x: x.replace(" ", "_").replace(".", "").lower()[:62], - inplace=True, - ) + data.rename(columns=lambda x: x.replace(' ', '_').replace('.', '').lower()[:62], inplace=True) return data @@ -202,29 +191,25 @@ def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: Returns: pandas.DataFrame """ - for label in data_frame.select_dtypes(["object", "O"]).columns.tolist(): + for label in data_frame.select_dtypes(['object', 'O']).columns.tolist(): data_frame[label] = data_frame[label].astype("str").astype("string") return data_frame -def prepare_fg_from_dataframe_or_file( - dataframe_or_path: Union[str, Path, pandas.DataFrame], - feature_group_name: str, - role: str = None, - region: str = None, - session=None, - record_id: str = "record_id", - event_id: str = "data_as_of_date", - verbose: bool = False, - **pandas_read_csv_kwargs, -) -> FeatureGroup: - """Module to prepare a dataframe before creating Feature Group +def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas.DataFrame], + feature_group_name: str, + role: str = None, region: str = None, session=None, + record_id: str = 'record_id', + event_id: str = 'data_as_of_date', + verbose: bool = False, + **pandas_read_csv_kwargs) -> FeatureGroup: + """Module to prepare a dataframe before creating Feature Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame or a path to a file with proper dtypes, feature names and mandatory features (record_id, event_id). 
It needs the sagemaker.Session linked to a role or the role and region used to work Feature Stores. If record_id or event_id are not specified it will create ones - by default with the names + by default with the names 'record_id' and 'data_as_of_date'. Args: feature_group_name (str): feature group name @@ -252,16 +237,12 @@ def prepare_fg_from_dataframe_or_file( if isinstance(dataframe_or_path, DataFrame): data = dataframe_or_path elif isinstance(dataframe_or_path, str): - pandas_read_csv_kwargs.pop("filepath_or_buffer", None) + pandas_read_csv_kwargs.pop('filepath_or_buffer', None) data = read_csv(filepath_or_buffer=dataframe_or_path, **pandas_read_csv_kwargs) else: - exc = Exception( - str( - f"Invalid type {type(dataframe_or_path)} for " - "argument dataframe_or_path. \nParameter must be" - " of type pandas.DataFrame or string" - ) - ) + exc = Exception(str(f'Invalid type {type(dataframe_or_path)} for ' + 'argument dataframe_or_path. \nParameter must be' + ' of type pandas.DataFrame or string')) logger.exception(exc) raise exc @@ -269,26 +250,21 @@ def prepare_fg_from_dataframe_or_file( data = _format_column_names(data=data) data = _cast_object_to_string(data_frame=data) - if record_id == "record_id" and record_id not in data.columns: + if record_id == 'record_id' and record_id not in data.columns: data[record_id] = data.index lg_uniq = len(data[record_id].unique()) lg_id = len(data[record_id]) if lg_id != lg_uniq: - exc = Exception( - str( - f"Record identifier {record_id} have {abs(lg_id - lg_uniq)} " - "duplicated rows. \nRecord identifier must be unique" - " in each row." - ) - ) + exc = Exception(str(f'Record identifier {record_id} have {abs(lg_id - lg_uniq)} ' + 'duplicated rows. \nRecord identifier must be unique' + ' in each row.')) logger.exception(exc) raise exc if event_id not in data.columns: import time - current_time_sec = int(round(time.time())) data[event_id] = Series([current_time_sec] * lg_id, dtype="float64") @@ -298,7 +274,7 @@ def prepare_fg_from_dataframe_or_file( elif role is not None and region is not None: sagemaker_session = _get_session_from_role(role=role, region=region) else: - exc = Exception("Argument Session or role and region must be specified.") + exc = Exception('Argument Session or role and region must be specified.') logger.exception(exc) raise exc diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 21837d1d4d..1de5d9a5c5 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -133,9 +133,10 @@ def as_dataframe(self, **pandas_read_csv_kwargs) -> DataFrame: """Download the result of the current query and load it into a DataFrame. Args: - pandas_read_csv_kwargs: key arguments used for the method pandas.read_csv to be able to have a better - tuning on data. For more info about this methods visit: - https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html + pandas_read_csv_kwargs: key arguments used for the method pandas.read_csv + to be able to have a better tuning on data. For more info + about this methods visit: + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html Returns: A pandas DataFrame contains the query result. 
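        Example (illustrative sketch; the query string, S3 output location and read_csv
        options below are assumptions, not defaults of the SDK):

            query = feature_group.athena_query()
            query.run(
                query_string=f'SELECT * FROM "{query.table_name}"',
                output_location="s3://example-athena-results/query_results",
            )
            query.wait()
            # Any pandas.read_csv keyword argument can be forwarded, e.g. to tune dtype
            # inference or NA handling while the Athena result file is loaded.
            df = query.as_dataframe(low_memory=False, na_values=["NULL"])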
""" From ee43d123361b0214c5824be63698be9ee439f09a Mon Sep 17 00:00:00 2001 From: JoseJuan98 Date: Sat, 3 Dec 2022 01:33:49 +0100 Subject: [PATCH 025/526] fix: black style formated --- src/sagemaker/feature_group_utils.py | 148 ++++++++++-------- src/sagemaker/utils.py | 38 ++--- tests/integ/test_feature_store.py | 66 +++----- .../feature_store/test_feature_definition.py | 4 +- .../feature_store/test_feature_group_utils.py | 4 +- .../feature_store/test_feature_store.py | 113 ++++--------- 6 files changed, 141 insertions(+), 232 deletions(-) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index 874774a28d..af3bf6327d 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -44,44 +44,51 @@ def _get_session_from_role(role: str, region: str): """ boto_session = boto3.Session(region_name=region) - sts = boto_session.client('sts', - region_name=region, - endpoint_url='https://sts.eu-west-1.amazonaws.com') + sts = boto_session.client( + "sts", region_name=region, endpoint_url="https://sts.eu-west-1.amazonaws.com" + ) - metadata = sts.assume_role(RoleArn=role, - RoleSessionName='SagemakerExecution') + metadata = sts.assume_role(RoleArn=role, RoleSessionName="SagemakerExecution") - access_key_id = metadata['Credentials']['AccessKeyId'] - secret_access_key = metadata['Credentials']['SecretAccessKey'] - session_token = metadata['Credentials']['SessionToken'] + access_key_id = metadata["Credentials"]["AccessKeyId"] + secret_access_key = metadata["Credentials"]["SecretAccessKey"] + session_token = metadata["Credentials"]["SessionToken"] - boto_session = boto3.session.Session(region_name=region, - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - aws_session_token=session_token) + boto_session = boto3.session.Session( + region_name=region, + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + aws_session_token=session_token, + ) # Sessions - sagemaker_client = boto_session.client('sagemaker') - sagemaker_runtime = boto_session.client('sagemaker-runtime') - runtime_client = boto_session.client(service_name='sagemaker-featurestore-runtime') - sagemaker_session = Session(boto_session=boto_session, - sagemaker_client=sagemaker_client, - sagemaker_runtime_client=sagemaker_runtime, - sagemaker_featurestore_runtime_client=runtime_client) + sagemaker_client = boto_session.client("sagemaker") + sagemaker_runtime = boto_session.client("sagemaker-runtime") + runtime_client = boto_session.client(service_name="sagemaker-featurestore-runtime") + sagemaker_session = Session( + boto_session=boto_session, + sagemaker_client=sagemaker_client, + sagemaker_runtime_client=sagemaker_runtime, + sagemaker_featurestore_runtime_client=runtime_client, + ) return sagemaker_session -def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, - query: str = str('SELECT * FROM ' - '"sagemaker_featurestore"."#{table}" ' - 'WHERE is_deleted=False'), - role: str = None, region: str = None, - session=None, - event_time_feature_name: str = None, - latest_ingestion: bool = True, - verbose: bool = True, - **pandas_read_csv_kwargs) -> DataFrame: +def get_feature_group_as_dataframe( + feature_group_name: str, + athena_bucket: str, + query: str = str( + "SELECT * FROM " '"sagemaker_featurestore"."#{table}" ' "WHERE is_deleted=False" + ), + role: str = None, + region: str = None, + session=None, + event_time_feature_name: str = None, + latest_ingestion: bool = True, + verbose: bool = True, + 
**pandas_read_csv_kwargs, +) -> DataFrame: """Get a feature group as a pandas.DataFrame Description: @@ -118,46 +125,48 @@ def get_feature_group_as_dataframe(feature_group_name: str, athena_bucket: str, if latest_ingestion: if event_time_feature_name is not None: - query += str(f'AND {event_time_feature_name}=(SELECT ' - f'MAX({event_time_feature_name}) FROM ' - + f'"sagemaker_featurestore"."{feature_group_name}")') + query += str( + f"AND {event_time_feature_name}=(SELECT " + f"MAX({event_time_feature_name}) FROM " + + f'"sagemaker_featurestore"."{feature_group_name}")' + ) else: - exc = Exception('Argument event_time_feature_name must be specified ' - 'when using latest_ingestion=True.') + exc = Exception( + "Argument event_time_feature_name must be specified " + "when using latest_ingestion=True." + ) logger.exception(exc) raise exc - query += ';' + query += ";" if session is not None: sagemaker_session = session elif role is not None and region is not None: sagemaker_session = _get_session_from_role(role=role, region=region) else: - exc = Exception('Argument Session or role and region must be specified.') + exc = Exception("Argument Session or role and region must be specified.") logger.exception(exc) raise exc - msg = f'Feature Group used: {feature_group_name}' + msg = f"Feature Group used: {feature_group_name}" logger.info(msg) - fg = FeatureGroup(name=feature_group_name, - sagemaker_session=sagemaker_session) + fg = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session) sample_query = fg.athena_query() - query_string = re.sub(r'#\{(table)\}', sample_query.table_name, query) + query_string = re.sub(r"#\{(table)\}", sample_query.table_name, query) msg = f"Running query:\n\t{sample_query} \n\n\t-> Save on bucket {athena_bucket}\n" logger.info(msg) - sample_query.run(query_string=query_string, - output_location=athena_bucket) + sample_query.run(query_string=query_string, output_location=athena_bucket) sample_query.wait() # run Athena query. The output is loaded to a Pandas dataframe. 
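    # Keyword arguments supplied by the caller via **pandas_read_csv_kwargs are handed to
    # pandas.read_csv when the result file is parsed (e.g. dtype=..., parse_dates=[...]);
    # the options named here are illustrative assumptions, not defaults of the SDK.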
dataset = sample_query.as_dataframe(**pandas_read_csv_kwargs) - msg = f'Data shape retrieve from {feature_group_name}: {dataset.shape}' + msg = f"Data shape retrieve from {feature_group_name}: {dataset.shape}" logger.info(msg) return dataset @@ -175,7 +184,7 @@ def _format_column_names(data: pandas.DataFrame) -> pandas.DataFrame: Returns: pandas.DataFrame """ - data.rename(columns=lambda x: x.replace(' ', '_').replace('.', '').lower()[:62], inplace=True) + data.rename(columns=lambda x: x.replace(" ", "_").replace(".", "").lower()[:62], inplace=True) return data @@ -191,18 +200,22 @@ def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: Returns: pandas.DataFrame """ - for label in data_frame.select_dtypes(['object', 'O']).columns.tolist(): + for label in data_frame.select_dtypes(["object", "O"]).columns.tolist(): data_frame[label] = data_frame[label].astype("str").astype("string") return data_frame -def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas.DataFrame], - feature_group_name: str, - role: str = None, region: str = None, session=None, - record_id: str = 'record_id', - event_id: str = 'data_as_of_date', - verbose: bool = False, - **pandas_read_csv_kwargs) -> FeatureGroup: +def prepare_fg_from_dataframe_or_file( + dataframe_or_path: Union[str, Path, pandas.DataFrame], + feature_group_name: str, + role: str = None, + region: str = None, + session=None, + record_id: str = "record_id", + event_id: str = "data_as_of_date", + verbose: bool = False, + **pandas_read_csv_kwargs, +) -> FeatureGroup: """Module to prepare a dataframe before creating Feature Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame @@ -237,12 +250,16 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas if isinstance(dataframe_or_path, DataFrame): data = dataframe_or_path elif isinstance(dataframe_or_path, str): - pandas_read_csv_kwargs.pop('filepath_or_buffer', None) + pandas_read_csv_kwargs.pop("filepath_or_buffer", None) data = read_csv(filepath_or_buffer=dataframe_or_path, **pandas_read_csv_kwargs) else: - exc = Exception(str(f'Invalid type {type(dataframe_or_path)} for ' - 'argument dataframe_or_path. \nParameter must be' - ' of type pandas.DataFrame or string')) + exc = Exception( + str( + f"Invalid type {type(dataframe_or_path)} for " + "argument dataframe_or_path. \nParameter must be" + " of type pandas.DataFrame or string" + ) + ) logger.exception(exc) raise exc @@ -250,21 +267,26 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas data = _format_column_names(data=data) data = _cast_object_to_string(data_frame=data) - if record_id == 'record_id' and record_id not in data.columns: + if record_id == "record_id" and record_id not in data.columns: data[record_id] = data.index lg_uniq = len(data[record_id].unique()) lg_id = len(data[record_id]) if lg_id != lg_uniq: - exc = Exception(str(f'Record identifier {record_id} have {abs(lg_id - lg_uniq)} ' - 'duplicated rows. \nRecord identifier must be unique' - ' in each row.')) + exc = Exception( + str( + f"Record identifier {record_id} have {abs(lg_id - lg_uniq)} " + "duplicated rows. \nRecord identifier must be unique" + " in each row." 
+ ) + ) logger.exception(exc) raise exc if event_id not in data.columns: import time + current_time_sec = int(round(time.time())) data[event_id] = Series([current_time_sec] * lg_id, dtype="float64") @@ -274,13 +296,11 @@ def prepare_fg_from_dataframe_or_file(dataframe_or_path: Union[str, Path, pandas elif role is not None and region is not None: sagemaker_session = _get_session_from_role(role=role, region=region) else: - exc = Exception('Argument Session or role and region must be specified.') + exc = Exception("Argument Session or role and region must be specified.") logger.exception(exc) raise exc - feature_group = FeatureGroup( - name=feature_group_name, sagemaker_session=sagemaker_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session) feature_group.load_feature_definitions(data_frame=data) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index b9f57c8650..bfd3f155ee 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -87,9 +87,7 @@ def name_from_base(base, max_length=63, short=False): def unique_name_from_base(base, max_length=63): """Placeholder Docstring""" - random.seed( - int(uuid.uuid4()) - ) # using uuid to randomize, otherwise system timestamp is used. + random.seed(int(uuid.uuid4())) # using uuid to randomize, otherwise system timestamp is used. unique = "%04x" % random.randrange(16**4) # 4-digit hex ts = str(int(time.time())) available_length = max_length - 2 - len(ts) - len(unique) @@ -201,9 +199,7 @@ def secondary_training_status_changed(current_job_description, prev_job_descript boolean: Whether the secondary status message of a training job changed or not. """ - current_secondary_status_transitions = current_job_description.get( - "SecondaryStatusTransitions" - ) + current_secondary_status_transitions = current_job_description.get("SecondaryStatusTransitions") if ( current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0 @@ -247,9 +243,7 @@ def secondary_training_status_message(job_description, prev_description): return "" prev_description_secondary_transitions = ( - prev_description.get("SecondaryStatusTransitions") - if prev_description is not None - else None + prev_description.get("SecondaryStatusTransitions") if prev_description is not None else None ) prev_transitions_num = ( len(prev_description["SecondaryStatusTransitions"]) @@ -449,9 +443,7 @@ def repack_model( with tarfile.open(tmp_model_path, mode="w:gz") as t: t.add(model_dir, arcname=os.path.sep) - _save_model( - repacked_model_uri, tmp_model_path, sagemaker_session, kms_key=kms_key - ) + _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key=kms_key) def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key): @@ -459,14 +451,10 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key): if repacked_model_uri.lower().startswith("s3://"): url = parse.urlparse(repacked_model_uri) bucket, key = url.netloc, url.path.lstrip("/") - new_key = key.replace( - os.path.basename(key), os.path.basename(repacked_model_uri) - ) + new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri)) settings = ( - sagemaker_session.settings - if sagemaker_session is not None - else SessionSettings() + sagemaker_session.settings if sagemaker_session is not None else SessionSettings() ) encrypt_artifact = settings.encrypt_repacked_artifacts @@ -513,9 +501,7 @@ def _create_or_update_code_dir( for dependency in dependencies: lib_dir = 
os.path.join(code_dir, "lib") if os.path.isdir(dependency): - shutil.copytree( - dependency, os.path.join(lib_dir, os.path.basename(dependency)) - ) + shutil.copytree(dependency, os.path.join(lib_dir, os.path.basename(dependency))) else: if not os.path.exists(lib_dir): os.mkdir(lib_dir) @@ -745,11 +731,7 @@ def get_data_bucket(self, region_requested=None): """ config = self.fetch_data_config() - region = ( - region_requested - if region_requested - else self.sagemaker_session.boto_region_name - ) + region = region_requested if region_requested else self.sagemaker_session.boto_region_name return config[region] if region in config.keys() else config["default"] @@ -855,9 +837,7 @@ def construct_container_object( return obj -def pop_out_unused_kwarg( - arg_name: str, kwargs: dict, override_val: Optional[str] = None -): +def pop_out_unused_kwarg(arg_name: str, kwargs: dict, override_val: Optional[str] = None): """Pop out the unused key-word argument and give a warning. Args: diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index 55f89acf87..f49af5a938 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -42,9 +42,7 @@ "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": ["s3:PutObject", "s3:PutObjectAcl"], "Resource": "arn:aws:s3:::{bucket_name}-{region_name}/*", - "Condition": { - "StringEquals": {"s3:x-amz-acl": "bucket-owner-full-control"} - }, + "Condition": {"StringEquals": {"s3:x-amz-acl": "bucket-owner-full-control"}}, }, { "Sid": "FeatureStoreOfflineStoreS3BucketPolicy", @@ -160,9 +158,7 @@ def test_create_feature_store_online_only( feature_group_name, pandas_data_frame, ): - feature_group = FeatureGroup( - name=feature_group_name, sagemaker_session=feature_store_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -187,9 +183,7 @@ def test_create_feature_store( record, create_table_ddl, ): - feature_group = FeatureGroup( - name=feature_group_name, sagemaker_session=feature_store_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -226,9 +220,9 @@ def test_create_feature_store( output_location=f"{offline_store_s3_uri}/query_results", ) athena_query.wait() - assert "SUCCEEDED" == athena_query.get_query_execution().get( - "QueryExecution" - ).get("Status").get("State") + assert "SUCCEEDED" == athena_query.get_query_execution().get("QueryExecution").get( + "Status" + ).get("State") df = athena_query.as_dataframe() print(f"Found {df.shape[0]} records.") time.sleep(60) @@ -256,9 +250,7 @@ def test_create_feature_group_iceberg_table_format( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup( - name=feature_group_name, sagemaker_session=feature_store_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -272,9 +264,7 @@ def test_create_feature_group_iceberg_table_format( ) _wait_for_feature_group_create(feature_group) - table_format = ( - feature_group.describe().get("OfflineStoreConfig").get("TableFormat") - ) + table_format = 
feature_group.describe().get("OfflineStoreConfig").get("TableFormat") assert table_format == "Iceberg" @@ -285,9 +275,7 @@ def test_create_feature_group_glue_table_format( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup( - name=feature_group_name, sagemaker_session=feature_store_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -301,9 +289,7 @@ def test_create_feature_group_glue_table_format( ) _wait_for_feature_group_create(feature_group) - table_format = ( - feature_group.describe().get("OfflineStoreConfig").get("TableFormat") - ) + table_format = feature_group.describe().get("OfflineStoreConfig").get("TableFormat") assert table_format == "Glue" @@ -314,9 +300,7 @@ def test_update_feature_group( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup( - name=feature_group_name, sagemaker_session=feature_store_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -334,9 +318,7 @@ def test_update_feature_group( feature_group.update(new_features) _wait_for_feature_group_update(feature_group) feature_definitions = feature_group.describe().get("FeatureDefinitions") - assert any( - [True for elem in feature_definitions if new_feature_name in elem.values()] - ) + assert any([True for elem in feature_definitions if new_feature_name in elem.values()]) def test_feature_metadata( @@ -346,9 +328,7 @@ def test_feature_metadata( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup( - name=feature_group_name, sagemaker_session=feature_store_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -397,9 +377,7 @@ def test_ingest_without_string_feature( offline_store_s3_uri, pandas_data_frame_without_string, ): - feature_group = FeatureGroup( - name=feature_group_name, sagemaker_session=feature_store_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame_without_string) with cleanup_feature_group(feature_group): @@ -427,9 +405,7 @@ def test_ingest_multi_process( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup( - name=feature_group_name, sagemaker_session=feature_store_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -480,9 +456,7 @@ def test_get_feature_group_with_role_region( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup( - name=feature_group_name, sagemaker_session=feature_store_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -520,9 +494,7 @@ def test_get_feature_group_with_session( offline_store_s3_uri, pandas_data_frame, ): - feature_group = FeatureGroup( - name=feature_group_name, 
sagemaker_session=feature_store_session - ) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) feature_group.load_feature_definitions(data_frame=pandas_data_frame) with cleanup_feature_group(feature_group): @@ -562,6 +534,4 @@ def cleanup_feature_group(feature_group: FeatureGroup): try: feature_group.delete() except Exception: - raise RuntimeError( - f"Failed to delete feature group with name {feature_group.name}" - ) + raise RuntimeError(f"Failed to delete feature group with name {feature_group.name}") diff --git a/tests/unit/sagemaker/feature_store/test_feature_definition.py b/tests/unit/sagemaker/feature_store/test_feature_definition.py index 3c8a4b9d81..a9c4a10b2a 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_definition.py +++ b/tests/unit/sagemaker/feature_store/test_feature_definition.py @@ -30,9 +30,7 @@ def ordered(obj): def test_feature_definition(): - definition = FeatureDefinition( - feature_name="MyFeature", feature_type=FeatureTypeEnum.INTEGRAL - ) + definition = FeatureDefinition(feature_name="MyFeature", feature_type=FeatureTypeEnum.INTEGRAL) assert ordered(definition.to_dict()) == ordered( { "FeatureName": "MyFeature", diff --git a/tests/unit/sagemaker/feature_store/test_feature_group_utils.py b/tests/unit/sagemaker/feature_store/test_feature_group_utils.py index 377ec090da..97c7442562 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_group_utils.py +++ b/tests/unit/sagemaker/feature_store/test_feature_group_utils.py @@ -45,9 +45,7 @@ def sagemaker_session_mock(): def test_convert_unsupported_types_to_supported(sagemaker_session_mock): - feature_group = FeatureGroup( - name="FailedGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock) df = pd.DataFrame( { "float": pd.Series([2.0], dtype="float64"), diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py index 180651e3bc..b1ad630461 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_store.py +++ b/tests/unit/sagemaker/feature_store/test_feature_store.py @@ -96,9 +96,7 @@ def create_table_ddl(): def test_feature_store_create( sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri ): - feature_group = FeatureGroup( - name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) feature_group.feature_definitions = feature_group_dummy_definitions feature_group.create( s3_uri=s3_uri, @@ -126,9 +124,7 @@ def test_feature_store_create( def test_feature_store_create_iceberg_table_format( sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri ): - feature_group = FeatureGroup( - name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) feature_group.feature_definitions = feature_group_dummy_definitions feature_group.create( s3_uri=s3_uri, @@ -159,9 +155,7 @@ def test_feature_store_create_iceberg_table_format( def test_feature_store_create_glue_table_format( sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri ): - feature_group = FeatureGroup( - name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) 
feature_group.feature_definitions = feature_group_dummy_definitions feature_group.create( s3_uri=s3_uri, @@ -192,9 +186,7 @@ def test_feature_store_create_glue_table_format( def test_feature_store_create_online_only( sagemaker_session_mock, role_arn, feature_group_dummy_definitions ): - feature_group = FeatureGroup( - name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) feature_group.feature_definitions = feature_group_dummy_definitions feature_group.create( s3_uri=False, @@ -216,9 +208,7 @@ def test_feature_store_create_online_only( def test_feature_store_delete(sagemaker_session_mock): - feature_group = FeatureGroup( - name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) feature_group.delete() sagemaker_session_mock.delete_feature_group.assert_called_with( feature_group_name="MyFeatureGroup" @@ -226,9 +216,7 @@ def test_feature_store_delete(sagemaker_session_mock): def test_feature_store_describe(sagemaker_session_mock): - feature_group = FeatureGroup( - name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) feature_group.describe() sagemaker_session_mock.describe_feature_group.assert_called_with( feature_group_name="MyFeatureGroup", next_token=None @@ -236,9 +224,7 @@ def test_feature_store_describe(sagemaker_session_mock): def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions): - feature_group = FeatureGroup( - name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) feature_group.update(feature_group_dummy_definitions) sagemaker_session_mock.update_feature_group.assert_called_with( feature_group_name="MyFeatureGroup", @@ -247,9 +233,7 @@ def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_defini def test_feature_metadata_update(sagemaker_session_mock): - feature_group = FeatureGroup( - name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) parameter_additions = [FeatureParameter(key="key1", value="value1")] parameter_removals = ["key2"] @@ -267,9 +251,7 @@ def test_feature_metadata_update(sagemaker_session_mock): parameter_additions=[pa.to_dict() for pa in parameter_additions], parameter_removals=parameter_removals, ) - feature_group.update_feature_metadata( - feature_name="Feature1", description="TestDescription" - ) + feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription") sagemaker_session_mock.update_feature_metadata.assert_called_with( feature_group_name="MyFeatureGroup", feature_name="Feature1", @@ -280,9 +262,7 @@ def test_feature_metadata_update(sagemaker_session_mock): def test_feature_metadata_describe(sagemaker_session_mock): - feature_group = FeatureGroup( - name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) feature_group.describe_feature_metadata(feature_name="Feature1") sagemaker_session_mock.describe_feature_metadata.assert_called_with( feature_group_name="MyFeatureGroup", feature_name="Feature1" @@ -290,9 +270,7 @@ def 
test_feature_metadata_describe(sagemaker_session_mock): def test_put_record(sagemaker_session_mock): - feature_group = FeatureGroup( - name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) feature_group.put_record(record=[]) sagemaker_session_mock.put_record.assert_called_with( feature_group_name="MyFeatureGroup", record=[] @@ -300,9 +278,7 @@ def test_put_record(sagemaker_session_mock): def test_load_feature_definition(sagemaker_session_mock): - feature_group = FeatureGroup( - name="SomeGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock) df = pd.DataFrame( { "float": pd.Series([2.0], dtype="float64"), @@ -322,9 +298,7 @@ def test_load_feature_definition(sagemaker_session_mock): def test_load_feature_definition_unsupported_types(sagemaker_session_mock): - feature_group = FeatureGroup( - name="FailedGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock) df = pd.DataFrame( { "float": pd.Series([2.0], dtype="float64"), @@ -334,16 +308,11 @@ def test_load_feature_definition_unsupported_types(sagemaker_session_mock): ) with pytest.raises(ValueError) as error: feature_group.load_feature_definitions(data_frame=df) - assert ( - "Failed to infer Feature type based on dtype object for column object." - in str(error) - ) + assert "Failed to infer Feature type based on dtype object for column object." in str(error) def test_ingest_zero_processes(): - feature_group = FeatureGroup( - name="MyGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) df = Mock() with pytest.raises(RuntimeError) as error: feature_group.ingest(data_frame=df, max_workers=1, max_processes=0) @@ -352,9 +321,7 @@ def test_ingest_zero_processes(): def test_ingest_zero_workers(): - feature_group = FeatureGroup( - name="MyGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) df = Mock() with pytest.raises(RuntimeError) as error: feature_group.ingest(data_frame=df, max_workers=0, max_processes=1) @@ -363,19 +330,13 @@ def test_ingest_zero_workers(): @patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") -def test_ingest( - ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock -): +def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock): sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( fs_runtime_client_config_mock ) - feature_group = FeatureGroup( - name="MyGroup", sagemaker_session=sagemaker_session_mock - ) - df = pd.DataFrame( - dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)) - ) + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) mock_ingestion_manager_instance = Mock() ingestion_manager_init.return_value = mock_ingestion_manager_instance @@ -401,12 +362,8 @@ def test_ingest_with_profile_name( fs_runtime_client_config_mock ) - feature_group = FeatureGroup( - name="MyGroup", sagemaker_session=sagemaker_session_mock - ) - df = pd.DataFrame( - dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in 
range(300)) - ) + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) mock_ingestion_manager_instance = Mock() ingestion_manager_init.return_value = mock_ingestion_manager_instance @@ -438,9 +395,7 @@ def test_as_hive_ddl_with_default_values( sagemaker_session_mock.account_id.return_value = "1234" sagemaker_session_mock.boto_session.region_name = "us-west-2" - feature_group = FeatureGroup( - name="MyGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) feature_group.feature_definitions = feature_group_dummy_definitions assert ( create_table_ddl.format( @@ -454,9 +409,7 @@ def test_as_hive_ddl_with_default_values( ) -def test_as_hive_ddl( - create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock -): +def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock): sagemaker_session_mock.describe_feature_group.return_value = { "OfflineStoreConfig": { "S3StorageConfig": { @@ -468,9 +421,7 @@ def test_as_hive_ddl( sagemaker_session_mock.account_id.return_value = "1234" sagemaker_session_mock.boto_session.region_name = "us-west-2" - feature_group = FeatureGroup( - name="MyGroup", sagemaker_session=sagemaker_session_mock - ) + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) feature_group.feature_definitions = feature_group_dummy_definitions assert create_table_ddl.format( database="MyDatabase", @@ -494,9 +445,7 @@ def test_ingestion_manager_run_success(): ) manager.run(df) - manager._run_multi_process.assert_called_once_with( - data_frame=df, wait=True, timeout=None - ) + manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None) @patch( @@ -588,9 +537,7 @@ def query(sagemaker_session_mock): def test_athena_query_run(sagemaker_session_mock, query): WORKGROUP = "workgroup" - sagemaker_session_mock.start_query_execution.return_value = { - "QueryExecutionId": "query_id" - } + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"} query.run( query_string="query", output_location="s3://some-bucket/some-path", @@ -612,17 +559,13 @@ def test_athena_query_run(sagemaker_session_mock, query): def test_athena_query_wait(sagemaker_session_mock, query): query._current_query_execution_id = "query_id" query.wait() - sagemaker_session_mock.wait_for_athena_query.assert_called_with( - query_execution_id="query_id" - ) + sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id") def test_athena_query_get_query_execution(sagemaker_session_mock, query): query._current_query_execution_id = "query_id" query.get_query_execution() - sagemaker_session_mock.get_query_execution.assert_called_with( - query_execution_id="query_id" - ) + sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id") @patch("tempfile.gettempdir", Mock(return_value="tmp")) From 0f5cf1824c0b116c9b218c803f3b94a85e09fd45 Mon Sep 17 00:00:00 2001 From: ci Date: Sat, 3 Dec 2022 03:22:39 +0000 Subject: [PATCH 026/526] prepare release v2.119.0 --- CHANGELOG.md | 28 ++++++++++++++++++++++++++++ VERSION | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 95e4a7b9cf..b8b3155231 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,33 @@ # Changelog +## v2.119.0 
(2022-12-03) + +### Features + + * Add Code Owners file + * Added transform with monitoring pipeline step in transformer + * Update TF 2.9 and TF 2.10 inference DLCs + * make estimator accept json file as modelparallel config + * SageMaker Training Compiler does not support p4de instances + * Add support for SparkML v3.3 + +### Bug Fixes and Other Changes + + * Fix bug forcing uploaded tar to be named sourcedir + * Update local_requirements.txt PyYAML version + * refactoring : using with statement + * Allow Py 3.7 for MMS Test Docker env + * fix PySparkProcessor __init__ params type + * type hint of PySparkProcessor __init__ + * Return ARM XGB/SKLearn tags if `image_scope` is `inference_graviton` + * Update scipy to 1.7.3 to support M1 development envs + * Fixing type hints for Spark processor that has instance type/count params in reverse order + * Add DeepAR ap-northeast-3 repository. + * Fix AsyncInferenceConfig documentation typo + * fix ml_inf to ml_inf1 in Neo multi-version support + * Fix type annotations + * add neo mvp region accounts + ## v2.118.0 (2022-12-01) ### Features diff --git a/VERSION b/VERSION index 34d47b7f52..23fe2bf317 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.118.1.dev0 +2.119.0 From f1f0013dc0375aa22805b3a59b82cd2b1a08d40a Mon Sep 17 00:00:00 2001 From: ci Date: Sat, 3 Dec 2022 03:22:41 +0000 Subject: [PATCH 027/526] update development version to v2.119.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 23fe2bf317..dda4128cf2 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.119.0 +2.119.1.dev0 From bb4b6897971a4e5ae0cbde948ef1682a64232b41 Mon Sep 17 00:00:00 2001 From: Radhika Bhat <78102284+RadhikaB-97@users.noreply.github.com> Date: Mon, 5 Dec 2022 10:06:58 -0800 Subject: [PATCH 028/526] feature: Add DXB region to frameworks by DLC (#3387) * Add DXB region * Remove change from neuron * Adding DXB to TF 2.1.0 and 2.1.1 --- src/sagemaker/image_uri_config/autogluon.json | 12 ++++ .../huggingface-training-compiler.json | 3 + .../image_uri_config/huggingface.json | 31 +++++++++ src/sagemaker/image_uri_config/mxnet.json | 13 ++++ src/sagemaker/image_uri_config/pytorch.json | 28 ++++++++ .../image_uri_config/tensorflow.json | 65 +++++++++++++++++++ 6 files changed, 152 insertions(+) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 3cc488c55d..0963520e02 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -26,6 +26,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -56,6 +57,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -86,6 +88,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -116,6 +119,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -146,6 +150,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", 
"me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -176,6 +181,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -217,6 +223,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -250,6 +257,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -283,6 +291,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -316,6 +325,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -349,6 +359,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -382,6 +393,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-training-compiler.json b/src/sagemaker/image_uri_config/huggingface-training-compiler.json index e771e2a548..482264b773 100644 --- a/src/sagemaker/image_uri_config/huggingface-training-compiler.json +++ b/src/sagemaker/image_uri_config/huggingface-training-compiler.json @@ -60,6 +60,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -89,6 +90,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -123,6 +125,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 317c17030a..e995c6e8ea 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -38,6 +38,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -70,6 +71,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -108,6 +110,7 @@ "eu-west-3": "763104351884", "eu-south-1": 
"692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -140,6 +143,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -180,6 +184,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -213,6 +218,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -246,6 +252,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -279,6 +286,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -320,6 +328,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -353,6 +362,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -386,6 +396,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -419,6 +430,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -458,6 +470,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -491,6 +504,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -530,6 +544,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -563,6 +578,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -602,6 +618,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -635,6 +652,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", 
"us-east-1": "763104351884", "us-east-2": "763104351884", @@ -687,6 +705,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -720,6 +739,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -753,6 +773,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -794,6 +815,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -827,6 +849,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -860,6 +883,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -893,6 +917,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -932,6 +957,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -965,6 +991,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1004,6 +1031,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1037,6 +1065,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1076,6 +1105,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1109,6 +1139,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 12bc40fccf..14bb74f6a6 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -245,6 +245,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -277,6 +278,7 @@ "eu-west-3": 
"763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -309,6 +311,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -341,6 +344,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -373,6 +377,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -632,6 +637,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -664,6 +670,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -696,6 +703,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -728,6 +736,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -760,6 +769,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -865,6 +875,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -897,6 +908,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -929,6 +941,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 3bf8016ba8..e1de6ca663 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -195,6 +195,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -230,6 +231,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -264,6 +266,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": 
"914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -298,6 +301,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -333,6 +337,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -368,6 +373,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -403,6 +409,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -438,6 +445,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -472,6 +480,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -506,6 +515,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -540,6 +550,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -574,6 +585,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -608,6 +620,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -642,6 +655,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -879,6 +893,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -914,6 +929,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -949,6 +965,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -983,6 +1000,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", 
@@ -1018,6 +1036,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1053,6 +1072,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1088,6 +1108,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1123,6 +1144,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1157,6 +1179,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1191,6 +1214,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1225,6 +1249,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1259,6 +1284,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1293,6 +1319,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1327,6 +1354,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 0122dcd3ca..bb05682f67 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -154,6 +154,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -185,6 +186,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -216,6 +218,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -247,6 +250,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -401,6 +405,7 @@ "eu-west-2": "763104351884", "eu-west-3": 
"763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -432,6 +437,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -463,6 +469,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -494,6 +501,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -525,6 +533,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -556,6 +565,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -587,6 +597,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -810,6 +821,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -841,6 +853,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -872,6 +885,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -903,6 +917,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -934,6 +949,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -965,6 +981,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -996,6 +1013,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1027,6 +1045,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1058,6 +1077,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", 
"us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1089,6 +1109,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1120,6 +1141,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1151,6 +1173,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1182,6 +1205,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1213,6 +1237,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1244,6 +1269,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1275,6 +1301,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1306,6 +1333,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1337,6 +1365,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1368,6 +1397,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1399,6 +1429,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1430,6 +1461,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1461,6 +1493,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1760,6 +1793,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1796,6 +1830,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1831,6 +1866,7 @@ 
"eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1867,6 +1903,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1903,6 +1940,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1939,6 +1977,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1975,6 +2014,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2202,6 +2242,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2237,6 +2278,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2272,6 +2314,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2306,6 +2349,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2340,6 +2384,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2375,6 +2420,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2410,6 +2456,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2444,6 +2491,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2478,6 +2526,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2512,6 +2561,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2546,6 +2596,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + 
"me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2580,6 +2631,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2614,6 +2666,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2648,6 +2701,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2682,6 +2736,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2716,6 +2771,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2750,6 +2806,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2784,6 +2841,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2818,6 +2876,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2852,6 +2911,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2886,6 +2946,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2920,6 +2981,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2954,6 +3016,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2988,6 +3051,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3022,6 +3086,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", From b68bcd9344deba8e3bedf7ccb0adb31498735b13 Mon Sep 17 00:00:00 2001 From: Brock Wade Date: Mon, 5 Dec 2022 14:11:34 -0800 Subject: [PATCH 029/526] fix: support idempotency for framework and spark 
processors (#3460) Co-authored-by: Brock Wade Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> --- src/sagemaker/processing.py | 8 +- src/sagemaker/spark/processing.py | 37 +- src/sagemaker/workflow/utilities.py | 7 +- tests/data/spark/code/java/TestJarFile.jar | Bin 0 -> 1714 bytes .../hello-java-spark/HelloJavaSparkApp.jar | Bin 0 -> 1714 bytes .../unit/sagemaker/workflow/test_pipeline.py | 8 +- .../workflow/test_processing_step.py | 277 +++++++++++++- .../sagemaker/workflow/test_training_step.py | 354 +++++++++++++++--- .../sagemaker/workflow/test_transform_step.py | 8 + .../sagemaker/workflow/test_tuning_step.py | 58 +-- 10 files changed, 661 insertions(+), 96 deletions(-) create mode 100644 tests/data/spark/code/java/TestJarFile.jar create mode 100644 tests/data/spark/code/java/hello-java-spark/HelloJavaSparkApp.jar diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 308783578d..81e3d34b1d 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -23,6 +23,7 @@ import logging from textwrap import dedent from typing import Dict, List, Optional, Union +from copy import copy import attr @@ -1830,14 +1831,17 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput # a7399455f5386d83ddc5cb15c0db00c04bd518ec/src/sagemaker/processing.py#L425-L426 if inputs is None: inputs = [] - inputs.append( + + # make a shallow copy of user inputs + patched_inputs = copy(inputs) + patched_inputs.append( ProcessingInput( input_name="code", source=s3_payload, destination="/opt/ml/processing/input/code/", ) ) - return inputs + return patched_inputs def _set_entrypoint(self, command, user_script_name): """Framework processor override for setting processing job entrypoint. diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index dc3d26a355..912bc90d80 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -30,6 +30,7 @@ from enum import Enum from io import BytesIO from urllib.parse import urlparse +from copy import copy from typing import Union, List, Dict, Optional @@ -279,6 +280,10 @@ def run( def _extend_processing_args(self, inputs, outputs, **kwargs): """Extends processing job args such as inputs.""" + # make a shallow copy of user outputs + outputs = outputs or [] + extended_outputs = copy(outputs) + if kwargs.get("spark_event_logs_s3_uri"): spark_event_logs_s3_uri = kwargs.get("spark_event_logs_s3_uri") self._validate_s3_uri(spark_event_logs_s3_uri) @@ -297,16 +302,21 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): s3_upload_mode="Continuous", ) - outputs = outputs or [] - outputs.append(output) + extended_outputs.append(output) + + # make a shallow copy of user inputs + inputs = inputs or [] + extended_inputs = copy(inputs) if kwargs.get("configuration"): configuration = kwargs.get("configuration") self._validate_configuration(configuration) - inputs = inputs or [] - inputs.append(self._stage_configuration(configuration)) + extended_inputs.append(self._stage_configuration(configuration)) - return inputs, outputs + return ( + extended_inputs if extended_inputs else None, + extended_outputs if extended_outputs else None, + ) def start_history_server(self, spark_event_logs_s3_uri=None): """Starts a Spark history server. @@ -940,9 +950,16 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): outputs: Processing outputs. kwargs: Additional keyword arguments passed to `super()`. 
""" + + if inputs is None: + inputs = [] + + # make a shallow copy of user inputs + extended_inputs = copy(inputs) + self.command = [_SparkProcessorBase._default_command] extended_inputs = self._handle_script_dependencies( - inputs, kwargs.get("submit_py_files"), FileType.PYTHON + extended_inputs, kwargs.get("submit_py_files"), FileType.PYTHON ) extended_inputs = self._handle_script_dependencies( extended_inputs, kwargs.get("submit_jars"), FileType.JAR @@ -1199,8 +1216,14 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): else: raise ValueError("submit_class is required") + if inputs is None: + inputs = [] + + # make a shallow copy of user inputs + extended_inputs = copy(inputs) + extended_inputs = self._handle_script_dependencies( - inputs, kwargs.get("submit_jars"), FileType.JAR + extended_inputs, kwargs.get("submit_jars"), FileType.JAR ) extended_inputs = self._handle_script_dependencies( extended_inputs, kwargs.get("submit_files"), FileType.FILE diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index 89d7c5dfd9..08c170d424 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -114,11 +114,12 @@ def get_code_hash(step: Entity) -> str: if isinstance(step, ProcessingStep) and step.step_args: kwargs = step.step_args.func_kwargs source_dir = kwargs.get("source_dir") + submit_class = kwargs.get("submit_class") dependencies = get_processing_dependencies( [ kwargs.get("dependencies"), kwargs.get("submit_py_files"), - kwargs.get("submit_class"), + [submit_class] if submit_class else None, kwargs.get("submit_jars"), kwargs.get("submit_files"), ] @@ -168,7 +169,7 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str] str: A hash string representing the unique code artifact(s) for the step """ - # FrameworkProcessor + # If FrameworkProcessor contains source_dir if source_dir: source_dir_url = urlparse(source_dir) if source_dir_url.scheme == "" or source_dir_url.scheme == "file": @@ -400,5 +401,5 @@ def execute_job_functions(step_args: _StepArguments): """ chained_args = step_args.func(*step_args.func_args, **step_args.func_kwargs) - if chained_args: + if isinstance(chained_args, _StepArguments): execute_job_functions(chained_args) diff --git a/tests/data/spark/code/java/TestJarFile.jar b/tests/data/spark/code/java/TestJarFile.jar new file mode 100644 index 0000000000000000000000000000000000000000..d528331d557da00908e31c46b2a0dd3dc250a2bf GIT binary patch literal 1714 zcmWIWW@Zs#;Nak32&_&EWk3R)3@i-3t|5-Po_=on|4uP5Ff#;rvvYt{FhP|C;M6Pv zQ~}rQ>*(j{<{BKL=j-;__snS@Z(Y5MyxzK6=gyqp9At3C_`%a6JuhD!Pv48Bt5`TA zUPvC1mX_4auz02x_VoEn)#uN(DxRsn&iqvLv4|1u2Dgc~Y@C2LfH24nTnr3AcNxV* zp?E+PD4UU*lasHTl~|UjTU?M>l&znfpR12si##qZiMfeY`FV-u#dtJp64qRtn4X%O zn4MaL#~6K5jDdIxw}(tfH>@PJxCHDxNU}f=RWCA4^Z><#7ce4%LGj>NP@o5nmEPT4 zha3c4fB)?=3}v}1?~$s$O;hJc(w#MhC*L&B^>k7F|4!|RkwJmw^vM@^ZbTbg%>MGD zm+}6Zogs4UvrW~{G>e_Sq%rjvyCrkm1j%#LG~x;ll{xxp9`tuS$+mI96m!?>jqc*F zmUDYt@G4ul|1M+k`QN#lmi*livHr1!m8grxm8+{7vd>(Rb%^U*cxRE#x^uBx^RN9~ zllGMzl-M^*=l9J9diW#|BN97$kjP>SlHA0+%rsz7>XlTK08{;$%cbW$b@a9cd7L|c z)%%R^np5Y!b@Z=k`}z2v^*!UKdr4c*L+8{rZBJm{`0R1^1h54%;aZXMFvtWh2HbfL zVZuQm6GsljZ3HL}BET0Q6RQ!(ITE*Fpgf5HhKvLaL(ZYNjRoaV1gIdzSXhq5Z8#{; zBEV774Tt7nL=pidSmdM(%EJgC4oo=&f*27h5a)w!z@DR#(-+8I*(j{<{BKL=j-;__snS@Z(Y5MyxzK6=gyqp9At3C_`%a6JuhD!Pv48Bt5`TA zUPvC1mX_4auz02x_VoEn)#uN(DxRsn&iqvLv4|1u2Dgc~Y@C2LfH24nTnr3AcNxV* zp?E+PD4UU*lasHTl~|UjTU?M>l&znfpR12si##qZiMfeY`FV-u#dtJp64qRtn4X%O 
zn4MaL#~6K5jDdIxw}(tfH>@PJxCHDxNU}f=RWCA4^Z><#7ce4%LGj>NP@o5nmEPT4 zha3c4fB)?=3}v}1?~$s$O;hJc(w#MhC*L&B^>k7F|4!|RkwJmw^vM@^ZbTbg%>MGD zm+}6Zogs4UvrW~{G>e_Sq%rjvyCrkm1j%#LG~x;ll{xxp9`tuS$+mI96m!?>jqc*F zmUDYt@G4ul|1M+k`QN#lmi*livHr1!m8grxm8+{7vd>(Rb%^U*cxRE#x^uBx^RN9~ zllGMzl-M^*=l9J9diW#|BN97$kjP>SlHA0+%rsz7>XlTK08{;$%cbW$b@a9cd7L|c z)%%R^np5Y!b@Z=k`}z2v^*!UKdr4c*L+8{rZBJm{`0R1^1h54%;aZXMFvtWh2HbfL zVZuQm6GsljZ3HL}BET0Q6RQ!(ITE*Fpgf5HhKvLaL(ZYNjRoaV1gIdzSXhq5Z8#{; zBEV774Tt7nL=pidSmdM(%EJgC4oo=&f*27h5a)w!z@DR#(-+8I Date: Mon, 5 Dec 2022 18:18:10 -0600 Subject: [PATCH 030/526] feature: Update registries with new region account number mappings. (#3492) --- src/sagemaker/image_uri_config/autogluon.json | 18 ++++ .../image_uri_config/huggingface-neuron.json | 3 + .../image_uri_config/huggingface.json | 39 +++++++ src/sagemaker/image_uri_config/mxnet.json | 24 +++++ src/sagemaker/image_uri_config/pytorch.json | 54 ++++++++++ .../image_uri_config/tensorflow.json | 102 ++++++++++++++++++ 6 files changed, 240 insertions(+) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 0963520e02..3a9f02142c 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -210,6 +210,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -217,11 +218,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -244,6 +247,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -251,11 +255,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -278,6 +284,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -285,11 +292,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -312,6 +321,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ 
-319,11 +329,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -346,6 +358,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -353,11 +366,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -380,6 +395,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -387,11 +403,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-neuron.json b/src/sagemaker/image_uri_config/huggingface-neuron.json index 1e2246cb11..47d6dbd1dc 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuron.json +++ b/src/sagemaker/image_uri_config/huggingface-neuron.json @@ -15,17 +15,20 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index e995c6e8ea..5b98fc0d02 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -692,6 +692,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -699,11 +700,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": 
"217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -726,6 +729,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -733,11 +737,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -760,6 +766,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -767,8 +774,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -802,6 +811,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -809,11 +819,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -836,6 +848,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -843,11 +856,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -870,6 +885,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -877,8 +893,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -904,6 +922,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -911,8 +930,10 @@ "cn-north-1": 
"727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -944,6 +965,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -951,11 +973,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -978,6 +1002,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -985,8 +1010,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1018,6 +1045,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1025,11 +1053,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -1052,6 +1082,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1059,8 +1090,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1092,6 +1125,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1099,11 +1133,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -1126,6 +1162,7 @@ "ap-northeast-2": "763104351884", 
"ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1133,8 +1170,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 14bb74f6a6..8d8733e480 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -624,6 +624,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -631,11 +632,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -657,6 +660,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -664,11 +668,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -690,6 +696,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -697,11 +704,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -723,6 +732,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -730,11 +740,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -756,6 +768,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", 
"ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -763,11 +776,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -862,6 +877,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -869,11 +885,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -895,6 +913,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -902,11 +921,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -928,6 +949,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -935,11 +957,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index e1de6ca663..18a382e591 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -17,6 +17,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -25,7 +26,9 @@ "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-north-1": "763104351884", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", + "eu-south-2": "503227376785", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-west-2": "763104351884" @@ -39,8 +42,11 @@ "registries": { "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-3": 
"907027046896", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", + "eu-south-2": "503227376785", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-west-2": "763104351884" @@ -182,6 +188,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -189,11 +196,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -218,6 +227,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -225,11 +235,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -253,6 +265,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -260,11 +273,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -288,6 +303,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -295,11 +311,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -324,6 +342,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -331,11 +350,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": 
"763104351884", @@ -360,6 +381,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -367,11 +389,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -396,6 +420,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -403,11 +428,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -432,6 +459,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -439,11 +467,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -467,6 +497,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -474,11 +505,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -502,6 +535,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -509,11 +543,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -537,6 +573,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": 
"763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -544,11 +581,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -572,6 +611,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -579,11 +619,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -607,6 +649,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -614,11 +657,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -642,6 +687,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -649,11 +695,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -677,6 +725,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -684,11 +733,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", @@ -721,6 +772,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -728,11 +780,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": 
"763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index bb05682f67..a0f2bba014 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -141,6 +141,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -148,12 +149,14 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "eu-south-2": "503227376785", "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", @@ -173,6 +176,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -180,8 +184,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -205,6 +211,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -212,8 +219,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -237,6 +246,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -244,8 +254,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -392,6 +404,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -399,8 +412,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": 
"763104351884", "eu-west-3": "763104351884", @@ -424,6 +439,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -431,8 +447,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -456,6 +474,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -463,8 +482,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -488,6 +509,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -495,8 +517,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -520,6 +544,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -527,8 +552,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -552,6 +579,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -559,8 +587,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -584,6 +614,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -591,8 +622,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -808,6 +841,7 
@@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -815,8 +849,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -840,6 +876,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -847,8 +884,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -872,6 +911,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -879,8 +919,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -904,6 +946,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -911,8 +954,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -936,6 +981,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -943,8 +989,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -968,6 +1016,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -975,8 +1024,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1000,6 +1051,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": 
"364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1007,8 +1059,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1032,6 +1086,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1039,8 +1094,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1064,6 +1121,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1071,8 +1129,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1096,6 +1156,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1103,8 +1164,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1128,6 +1191,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1135,8 +1199,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1160,6 +1226,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1167,8 +1234,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1192,6 +1261,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": 
"763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1199,8 +1269,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1224,6 +1296,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1231,8 +1304,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1256,6 +1331,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1263,8 +1339,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1288,6 +1366,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1295,8 +1374,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1320,6 +1401,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1327,8 +1409,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1352,6 +1436,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1359,8 +1444,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1384,6 +1471,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": 
"772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1391,8 +1479,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1416,6 +1506,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1423,8 +1514,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1448,6 +1541,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1455,8 +1549,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1480,6 +1576,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1487,8 +1584,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1587,6 +1686,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1594,11 +1694,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", From 767da0afc5cfb11eb96b324debf9a310abaafbcc Mon Sep 17 00:00:00 2001 From: Loki Date: Wed, 7 Dec 2022 06:06:34 +0530 Subject: [PATCH 031/526] feature: Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 (#3500) Co-authored-by: Ubuntu --- src/sagemaker/fw_utils.py | 2 +- .../pytorch-training-compiler.json | 41 ++ src/sagemaker/image_uris.py | 2 +- src/sagemaker/pytorch/__init__.py | 2 + src/sagemaker/pytorch/estimator.py | 60 +- .../pytorch/training_compiler/__init__.py | 0 .../pytorch/training_compiler/config.py | 151 +++++ tests/conftest.py | 1 + tests/data/huggingface_byoc/requirements.txt | 2 + 
tests/data/huggingface_byoc/run_glue.py | 568 ++++++++++++++++ tests/data/huggingface_byoc/train/dummy.csv | 1 + tests/integ/__init__.py | 2 +- tests/integ/test_training_compiler.py | 50 +- .../test_pytorch_compiler.py | 616 ++++++++++++++++++ 14 files changed, 1467 insertions(+), 31 deletions(-) create mode 100644 src/sagemaker/image_uri_config/pytorch-training-compiler.json create mode 100644 src/sagemaker/pytorch/training_compiler/__init__.py create mode 100644 src/sagemaker/pytorch/training_compiler/config.py create mode 100644 tests/data/huggingface_byoc/requirements.txt create mode 100644 tests/data/huggingface_byoc/run_glue.py create mode 100644 tests/data/huggingface_byoc/train/dummy.csv create mode 100644 tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index d82d3596ac..5efe530396 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -493,7 +493,7 @@ def framework_name_from_image(image_uri): # We must support both the legacy and current image name format. name_pattern = re.compile( r"""^(?:sagemaker(?:-rl)?-)? - (tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost + (tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost |huggingface-tensorflow|huggingface-pytorch |huggingface-tensorflow-trcomp|huggingface-pytorch-trcomp)(?:-)? (scriptmode|training)? diff --git a/src/sagemaker/image_uri_config/pytorch-training-compiler.json b/src/sagemaker/image_uri_config/pytorch-training-compiler.json new file mode 100644 index 0000000000..892ff4237d --- /dev/null +++ b/src/sagemaker/image_uri_config/pytorch-training-compiler.json @@ -0,0 +1,41 @@ +{ + "training": { + "processors": [ + "gpu" + ], + "version_aliases": { + "1.12": "1.12.0" + }, + "versions": { + "1.12.0": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" + } + } + } +} diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 7d1d3bd835..c42ce02188 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -146,7 +146,7 @@ def retrieve( tolerate_deprecated_model, ) - if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK): + if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): final_image_scope = image_scope config = _config_for_framework_and_scope( framework + "-training-compiler", final_image_scope, accelerator_type diff --git a/src/sagemaker/pytorch/__init__.py b/src/sagemaker/pytorch/__init__.py index cac5f94b9a..e2d14f4163 100644 --- a/src/sagemaker/pytorch/__init__.py +++ b/src/sagemaker/pytorch/__init__.py @@ -16,3 +16,5 @@ from sagemaker.pytorch.estimator import PyTorch # noqa: F401 from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor # noqa: F401 from sagemaker.pytorch.processing import 
PyTorchProcessor # noqa: F401 + +from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig # noqa: F401 diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 686de4a78c..29e254662f 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -28,6 +28,7 @@ ) from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel +from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable @@ -51,7 +52,8 @@ def __init__( hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, distribution: Optional[Dict] = None, - **kwargs + compiler_config: Optional[TrainingCompilerConfig] = None, + **kwargs, ): """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment. @@ -208,6 +210,31 @@ def __init__( To learn more, see `Training with parameter servers `_. + **To enable distributed training with + `SageMaker Training Compiler `_ + for PyTorch:** + + .. code:: python + + { + "pytorchxla": { + "enabled": True + } + } + + To learn more, see `SageMaker Training Compiler + `_ + in the *Amazon SageMaker Developer Guide*. + + .. note:: + + When you use this PyTorch XLA option for distributed training strategy, + you must add the ``compiler_config`` parameter and activate SageMaker + Training Compiler. + + compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`): + Configures SageMaker Training Compiler to accelerate training. + **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor. @@ -250,6 +277,25 @@ def __init__( self.distribution = distribution or {} + if compiler_config is not None: + if not isinstance(compiler_config, TrainingCompilerConfig): + error_string = ( + f"Expected instance of type {TrainingCompilerConfig}" + f"for argument compiler_config. " + f"Instead got {type(compiler_config)}" + ) + raise ValueError(error_string) + if compiler_config: + compiler_config.validate(self) + elif distribution is not None and "pytorchxla" in distribution: + raise ValueError( + "Distributed training through PyTorch XLA is currently only supported " + "when SageMaker Training Compiler is enabled. To learn more, " + "see Enable SageMaker Training Compiler at " + "https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html." + ) + self.compiler_config = compiler_config + def _pytorch_distribution_configuration(self, distribution): """Returns a dict of distribution config for PyTorch training @@ -289,6 +335,12 @@ def hyperparameters(self): hyperparameters.update( EstimatorBase._json_encode_hyperparameters(additional_hyperparameters) ) + if self.compiler_config: + training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict() + hyperparameters.update( + EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters) + ) + return hyperparameters def create_model( @@ -299,7 +351,7 @@ def create_model( entry_point=None, source_dir=None, dependencies=None, - **kwargs + **kwargs, ): """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``. 
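Taken together, the ``compiler_config`` parameter and the ``pytorchxla`` distribution option introduced by this patch would be used roughly as follows. This is an illustrative sketch only: the entry point, role ARN, and S3 location are placeholders, not values from the patch, while the framework version, Python version, and instance type follow the 1.12.0/py38 configuration added in pytorch-training-compiler.json.

.. code:: python

    from sagemaker.pytorch import PyTorch, TrainingCompilerConfig

    # Run a PyTorch >= 1.12 training job with SageMaker Training Compiler enabled
    # and the PyTorch XLA distribution mechanism for multi-node training.
    pytorch_estimator = PyTorch(
        entry_point="train.py",  # placeholder training script
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
        framework_version="1.12.0",
        py_version="py38",
        instance_type="ml.p4d.24xlarge",  # EFA-capable type recommended by validate()
        instance_count=2,
        distribution={"pytorchxla": {"enabled": True}},
        compiler_config=TrainingCompilerConfig(),
    )
    pytorch_estimator.fit("s3://example-bucket/train/")  # placeholder S3 input
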
@@ -350,7 +402,7 @@ def create_model( sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), dependencies=(dependencies or self.dependencies), - **kwargs + **kwargs, ) @classmethod @@ -371,6 +423,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na ) image_uri = init_params.pop("image_uri") framework, py_version, tag, _ = framework_name_from_image(image_uri) + if framework: + framework = framework.split("-")[0] if tag is None: framework_version = None diff --git a/src/sagemaker/pytorch/training_compiler/__init__.py b/src/sagemaker/pytorch/training_compiler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/pytorch/training_compiler/config.py b/src/sagemaker/pytorch/training_compiler/config.py new file mode 100644 index 0000000000..7faf8acbbd --- /dev/null +++ b/src/sagemaker/pytorch/training_compiler/config.py @@ -0,0 +1,151 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Configuration for the SageMaker Training Compiler.""" +from __future__ import absolute_import +import logging +from typing import Union +from packaging.specifiers import SpecifierSet +from packaging.version import Version + +from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig +from sagemaker.workflow.entities import PipelineVariable + +logger = logging.getLogger(__name__) + + +class TrainingCompilerConfig(BaseConfig): + """The SageMaker Training Compiler configuration class.""" + + SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"] + SUPPORTED_INSTANCE_TYPES_WITH_EFA = [ + "ml.g4dn.8xlarge", + "ml.g4dn.12xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + "ml.p4d.24xlarge", + ] + + def __init__( + self, + enabled: Union[bool, PipelineVariable] = True, + debug: Union[bool, PipelineVariable] = False, + ): + """This class initializes a ``TrainingCompilerConfig`` instance. + + `Amazon SageMaker Training Compiler + `_ + is a feature of SageMaker Training + and speeds up training jobs by optimizing model execution graphs. + + You can compile PyTorch models + by passing the object of this configuration class to the ``compiler_config`` + parameter of the :class:`~sagemaker.pytorch.PyTorch` + estimator. + + Args: + enabled (bool or PipelineVariable): Optional. Switch to enable SageMaker + Training Compiler. The default is ``True``. + debug (bool or PipelineVariable): Optional. Whether to dump detailed logs + for debugging. This comes with a potential performance slowdown. + The default is ``False``. + + **Example**: The following code shows the basic usage of the + :class:`sagemaker.pytorch.TrainingCompilerConfig()` class + to run a PyTorch training job with the compiler. + + .. code-block:: python + + from sagemaker.pytorch import PyTorch, TrainingCompilerConfig + + pytorch_estimator=PyTorch( + ... + compiler_config=TrainingCompilerConfig() + ) + + .. 
seealso:: + + For more information about how to enable SageMaker Training Compiler + for various training settings such as distributed training, + see `Enable SageMaker Training Compiler + `_ + in the `Amazon SageMaker Training Compiler developer guide + `_. + + """ + + super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug) + + @classmethod + def validate( + cls, + estimator, + ): + """Checks if SageMaker Training Compiler is configured correctly. + + Args: + estimator (:class:`sagemaker.pytorch.PyTorch`): An estimator object. + If SageMaker Training Compiler is enabled, it will validate whether + the estimator is configured to be compatible with Training Compiler. + + Raises: + ValueError: Raised if the requested configuration is not compatible + with SageMaker Training Compiler. + """ + + super(TrainingCompilerConfig, cls).validate(estimator) + + if estimator.image_uri: + error_helper_string = ( + "Overriding the image URI is currently not supported " + "for SageMaker Training Compiler." + "Specify the following parameters to run the PyTorch training job " + "with SageMaker Training Compiler enabled: " + "framework_version, and compiler_config." + ) + raise ValueError(error_helper_string) + + if estimator.distribution: + pt_xla_present = "pytorchxla" in estimator.distribution + pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False) + if pt_xla_enabled: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet("< 1.12"): + error_helper_string = ( + "Distribution mechanism 'pytorchxla' is currently only supported for " + "PyTorch >= 1.12 when SageMaker Training Compiler is enabled." + " Received framework_version={} which is unsupported." + ) + raise ValueError(error_helper_string.format(estimator.framework_version)) + if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA: + logger.warning( + "Consider using instances with EFA support when " + "training with PyTorch >= 1.12 and SageMaker Training Compiler " + "enabled. SageMaker Training Compiler leverages EFA to provide better " + "performance for distributed training." + ) + if not pt_xla_present: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet(">= 1.12"): + error_helper_string = ( + "'pytorchxla' is the only distribution mechanism currently supported " + "for PyTorch >= 1.12 when SageMaker Training Compiler is enabled." + " Received distribution={} which is unsupported." + ) + raise ValueError(error_helper_string.format(estimator.distribution)) + elif estimator.instance_count and estimator.instance_count > 1: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet(">= 1.12"): + logger.warning( + "Consider setting 'distribution' to 'pytorchxla' for distributed " + "training with PyTorch >= 1.12 and SageMaker Training Compiler enabled." 
+ ) diff --git a/tests/conftest.py b/tests/conftest.py index e92d98112b..f6682ebb8c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,6 +73,7 @@ "neo_pytorch", "neo_tensorflow", "pytorch", + "pytorch_training_compiler", "ray_pytorch", "ray_tensorflow", "sklearn", diff --git a/tests/data/huggingface_byoc/requirements.txt b/tests/data/huggingface_byoc/requirements.txt new file mode 100644 index 0000000000..462542f1c1 --- /dev/null +++ b/tests/data/huggingface_byoc/requirements.txt @@ -0,0 +1,2 @@ +transformers +datasets diff --git a/tests/data/huggingface_byoc/run_glue.py b/tests/data/huggingface_byoc/run_glue.py new file mode 100644 index 0000000000..1060398fa4 --- /dev/null +++ b/tests/data/huggingface_byoc/run_glue.py @@ -0,0 +1,568 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Finetuning the library models for sequence classification on GLUE.""" +# You can also adapt this script on your own text classification task. Pointers for this are left as comments. + +import logging +import os +import random +import sys +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +from datasets import load_dataset, load_metric + +import transformers +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + HfArgumentParser, + PretrainedConfig, + Trainer, + TrainingArguments, + default_data_collator, + set_seed, +) +from transformers.trainer_utils import get_last_checkpoint, is_main_process +from transformers.utils import check_min_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.5.0") + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +logger = logging.getLogger(__name__) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + task_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, + ) + max_seq_length: int = field( + default=128, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." 
+ }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} + ) + pad_to_max_length: bool = field( + default=True, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_val_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the training data."} + ) + validation_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the validation data."} + ) + test_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the test data."} + ) + + def __post_init__(self): + if self.task_name is not None: + self.task_name = self.task_name.lower() + if self.task_name not in task_to_keys.keys(): + raise ValueError( + "Unknown task, you should pick one in " + ",".join(task_to_keys.keys()) + ) + elif self.train_file is None or self.validation_file is None: + raise ValueError("Need either a GLUE task or a training/validation file.") + else: + train_extension = self.train_file.split(".")[-1] + assert train_extension in [ + "csv", + "json", + ], "`train_file` should be a csv or a json file." + validation_extension = self.validation_file.split(".")[-1] + assert ( + validation_extension == train_extension + ), "`validation_file` should have the same extension (csv or json) as `train_file`." + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained config name or path if not the same as model_name"}, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from huggingface.co" + }, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={ + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + }, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. 
+ # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Detecting last checkpoint. + last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the + # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named + # label if at least two columns are provided. + # + # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this + # single column. You can easily tweak this behavior (see below) + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.task_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset("glue", data_args.task_name) + else: + # Loading a dataset from your local files. + # CSV/JSON training and evaluation files are needed. 
+ data_files = {"train": data_args.train_file, "validation": data_args.validation_file} + + # Get the test dataset: you can provide your own CSV/JSON test file (see below) + # when you use `do_predict` without specifying a GLUE benchmark task. + if training_args.do_predict: + if data_args.test_file is not None: + train_extension = data_args.train_file.split(".")[-1] + test_extension = data_args.test_file.split(".")[-1] + assert ( + test_extension == train_extension + ), "`test_file` should have the same extension (csv or json) as `train_file`." + data_files["test"] = data_args.test_file + else: + raise ValueError("Need either a GLUE task or a test file for `do_predict`.") + + for key in data_files.keys(): + logger.info(f"load a local file for {key}: {data_files[key]}") + + if data_args.train_file.endswith(".csv"): + # Loading a dataset from local csv files + datasets = load_dataset("csv", data_files=data_files) + else: + # Loading a dataset from local json files + datasets = load_dataset("json", data_files=data_files) + # See more about loading any type of standard or custom dataset at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Labels + if data_args.task_name is not None: + is_regression = data_args.task_name == "stsb" + if not is_regression: + label_list = datasets["train"].features["label"].names + num_labels = len(label_list) + else: + num_labels = 1 + else: + # Trying to have good defaults here, don't hesitate to tweak to your needs. + is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"] + if is_regression: + num_labels = 1 + else: + # A useful fast method: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique + label_list = datasets["train"].unique("label") + label_list.sort() # Let's sort it for determinism + num_labels = len(label_list) + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + finetuning_task=data_args.task_name, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForSequenceClassification.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Preprocessing the datasets + if data_args.task_name is not None: + sentence1_key, sentence2_key = task_to_keys[data_args.task_name] + else: + # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 
+ non_label_column_names = [ + name for name in datasets["train"].column_names if name != "label" + ] + if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: + sentence1_key, sentence2_key = "sentence1", "sentence2" + else: + if len(non_label_column_names) >= 2: + sentence1_key, sentence2_key = non_label_column_names[:2] + else: + sentence1_key, sentence2_key = non_label_column_names[0], None + + # Padding strategy + if data_args.pad_to_max_length: + padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + padding = False + + # Some models have set the order of the labels to use, so let's make sure we do use it. + label_to_id = None + if ( + model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id + and data_args.task_name is not None + and not is_regression + ): + # Some have all caps in their config, some don't. + label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} + if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): + label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} + else: + logger.warn( + "Your model seems to have been trained with labels, but they don't match the dataset: ", + f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." + "\nIgnoring the model labels as a result.", + ) + elif data_args.task_name is None and not is_regression: + label_to_id = {v: i for i, v in enumerate(label_list)} + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 
+ ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + def preprocess_function(examples): + # Tokenize the texts + args = ( + (examples[sentence1_key],) + if sentence2_key is None + else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) + + # Map labels to IDs (not necessary for GLUE tasks) + if label_to_id is not None and "label" in examples: + result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] + return result + + datasets = datasets.map( + preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache + ) + if training_args.do_train: + if "train" not in datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + + if training_args.do_eval: + if "validation" not in datasets and "validation_matched" not in datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = datasets[ + "validation_matched" if data_args.task_name == "mnli" else "validation" + ] + if data_args.max_val_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) + + if ( + training_args.do_predict + or data_args.task_name is not None + or data_args.test_file is not None + ): + if "test" not in datasets and "test_matched" not in datasets: + raise ValueError("--do_predict requires a test dataset") + test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] + if data_args.max_test_samples is not None: + test_dataset = test_dataset.select(range(data_args.max_test_samples)) + + # Log a few random samples from the training set: + if training_args.do_train: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # Get the metric function + if data_args.task_name is not None: + metric = load_metric("glue", data_args.task_name) + # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from + # compute_metrics + + # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a + # predictions and label_ids field) and has to return a dictionary string to float. + def compute_metrics(p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) + if data_args.task_name is not None: + result = metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() + return result + elif is_regression: + return {"mse": ((preds - p.label_ids) ** 2).mean().item()} + else: + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} + + # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 
+ if data_args.pad_to_max_length: + data_collator = default_data_collator + elif training_args.fp16: + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + else: + data_collator = None + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + compute_metrics=compute_metrics, + tokenizer=tokenizer, + data_collator=data_collator, + ) + + # Training + if training_args.do_train: + checkpoint = None + if last_checkpoint is not None: + checkpoint = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + # Check the config from that potential checkpoint has the right number of labels before using it as a + # checkpoint. + if AutoConfig.from_pretrained(model_args.model_name_or_path).num_labels == num_labels: + checkpoint = model_args.model_name_or_path + + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.save_model() # Saves the tokenizer too for easy upload + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + eval_datasets = [eval_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + eval_datasets.append(datasets["validation_mismatched"]) + + for eval_dataset, task in zip(eval_datasets, tasks): + metrics = trainer.evaluate(eval_dataset=eval_dataset) + + max_val_samples = ( + data_args.max_val_samples + if data_args.max_val_samples is not None + else len(eval_dataset) + ) + metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.do_predict: + logger.info("*** Test ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + test_datasets = [test_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + test_datasets.append(datasets["test_mismatched"]) + + for test_dataset, task in zip(test_datasets, tasks): + # Removing the `label` columns because it contains -1 and Trainer won't like that. 
+ test_dataset.remove_columns_("label") + predictions = trainer.predict(test_dataset=test_dataset).predictions + predictions = ( + np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) + ) + + output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt") + if trainer.is_world_process_zero(): + with open(output_test_file, "w") as writer: + logger.info(f"***** Test results {task} *****") + writer.write("index\tprediction\n") + for index, item in enumerate(predictions): + if is_regression: + writer.write(f"{index}\t{item:3.3f}\n") + else: + item = label_list[item] + writer.write(f"{index}\t{item}\n") + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/tests/data/huggingface_byoc/train/dummy.csv b/tests/data/huggingface_byoc/train/dummy.csv new file mode 100644 index 0000000000..fb1539d552 --- /dev/null +++ b/tests/data/huggingface_byoc/train/dummy.csv @@ -0,0 +1 @@ +# dummy data \ No newline at end of file diff --git a/tests/integ/__init__.py b/tests/integ/__init__.py index 00ed09577b..9133fc8904 100644 --- a/tests/integ/__init__.py +++ b/tests/integ/__init__.py @@ -158,7 +158,7 @@ "ap-northeast-1", "eu-central-1", ] -# TODO: SM Training Compiler team to add all supported regions. + TRAINING_COMPILER_SUPPORTED_REGIONS = [ "af-south-1", "ap-east-1", diff --git a/tests/integ/test_training_compiler.py b/tests/integ/test_training_compiler.py index 67de050ed1..724cd8890c 100644 --- a/tests/integ/test_training_compiler.py +++ b/tests/integ/test_training_compiler.py @@ -20,6 +20,8 @@ from sagemaker.huggingface import TrainingCompilerConfig as HFTrainingCompilerConfig from sagemaker.tensorflow import TensorFlow from sagemaker.tensorflow import TrainingCompilerConfig as TFTrainingCompilerConfig +from sagemaker.pytorch import PyTorch +from sagemaker.pytorch import TrainingCompilerConfig as PTTrainingCompilerConfig from tests import integ from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES @@ -48,8 +50,7 @@ def imagenet_val_set(request, sagemaker_session, tmpdir_factory): key_prefix="Imagenet/TFRecords/validation", ) train_input = sagemaker_session.upload_data( - path=local_path, - key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val", + path=local_path, key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val" ) return train_input @@ -84,8 +85,8 @@ def skip_if_incompatible(gpu_instance_type, request): @pytest.mark.parametrize( "gpu_instance_type,instance_count", [ - ("ml.p3.2xlarge", 1), - ("ml.p3.16xlarge", 2), + pytest.param("ml.p3.2xlarge", 1, marks=pytest.mark.release), + pytest.param("ml.p3.16xlarge", 2), ], ) def test_huggingface_pytorch( @@ -129,27 +130,32 @@ def test_huggingface_pytorch( hf.fit(huggingface_dummy_dataset) -@pytest.mark.release -def test_huggingface_pytorch_release( +@pytest.mark.parametrize( + "gpu_instance_type,instance_count", + [ + pytest.param("ml.p3.2xlarge", 1, marks=pytest.mark.release), + pytest.param("ml.p3.16xlarge", 2), + ], +) +def test_pytorch( sagemaker_session, gpu_instance_type, - huggingface_training_compiler_latest_version, - huggingface_training_compiler_pytorch_latest_version, + instance_count, + pytorch_training_compiler_latest_version, huggingface_dummy_dataset, ): """ - Test the HuggingFace estimator with PyTorch + Test the PyTorch estimator """ with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, "huggingface") - hf = HuggingFace( + hf = PyTorch( py_version="py38", - 
entry_point=os.path.join(data_path, "run_glue.py"), + source_dir=os.path.join(DATA_DIR, "huggingface_byoc"), + entry_point="run_glue.py", role="SageMakerRole", - transformers_version=huggingface_training_compiler_latest_version, - pytorch_version=huggingface_training_compiler_pytorch_latest_version, - instance_count=1, + framework_version=pytorch_training_compiler_latest_version, + instance_count=instance_count, instance_type=gpu_instance_type, hyperparameters={ "model_name_or_path": "distilbert-base-cased", @@ -163,7 +169,8 @@ def test_huggingface_pytorch_release( }, sagemaker_session=sagemaker_session, disable_profiler=True, - compiler_config=HFTrainingCompilerConfig(), + compiler_config=PTTrainingCompilerConfig(), + distribution={"pytorchxla": {"enabled": True}} if instance_count > 1 else None, ) hf.fit(huggingface_dummy_dataset) @@ -209,10 +216,7 @@ def test_huggingface_tensorflow( @pytest.mark.release def test_tensorflow( - sagemaker_session, - gpu_instance_type, - tensorflow_training_latest_version, - imagenet_val_set, + sagemaker_session, gpu_instance_type, tensorflow_training_latest_version, imagenet_val_set ): """ Test the TensorFlow estimator @@ -264,8 +268,4 @@ def test_tensorflow( compiler_config=TFTrainingCompilerConfig(), ) - tf.fit( - inputs=imagenet_val_set, - logs=True, - wait=True, - ) + tf.fit(inputs=imagenet_val_set, logs=True, wait=True) diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py new file mode 100644 index 0000000000..0fe2402695 --- /dev/null +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -0,0 +1,616 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
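+# The tests below exercise the PyTorch Training Compiler support added in this change:
+# rejection of unsupported setups (custom image URIs, CPU and unsupported GPU instance
+# classes, Python 2, heterogeneous instance groups, and incompatible distribution
+# settings), the compiler hyperparameters emitted by fit(), and attach()/register()
+# behavior against pytorch-trcomp-training images.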
+from __future__ import absolute_import +import logging + +import json +import os + +import pytest +from mock import MagicMock, Mock, patch, ANY +from packaging.version import Version + +from sagemaker import image_uris +from sagemaker.pytorch import PyTorch, TrainingCompilerConfig +from sagemaker.pytorch.model import PyTorchModel +from sagemaker.instance_group import InstanceGroup + +from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES + + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "..", "data") +SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") +SERVING_SCRIPT_FILE = "another_dummy_script.py" +MODEL_DATA = "s3://some/data.tar.gz" +ENV = {"DUMMY_ENV_VAR": "dummy_value"} +TIMESTAMP = "2017-11-06-14:14:15.672" +TIME = 1510006209.073025 +BUCKET_NAME = "mybucket" +INSTANCE_COUNT = 1 +INSTANCE_TYPE = "ml.p3.2xlarge" +IMAGE_URI = "pytorch" +JOB_NAME = "{}-{}".format(IMAGE_URI, TIMESTAMP) +ROLE = "Dummy" +REGION = "us-east-1" +GPU = "ml.p3.2xlarge" +SUPPORTED_GPU_INSTANCE_CLASSES = {"p3", "p3dn", "g4dn", "p4d", "g5"} +UNSUPPORTED_GPU_INSTANCE_CLASSES = EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES + +LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]} + +EXPERIMENT_CONFIG = { + "ExperimentName": "exp", + "TrialName": "trial", + "TrialComponentDisplayName": "tc", +} + + +@pytest.fixture(scope="module") +def cpu_instance_type(): + return "ml.m5.xlarge" + + +@pytest.fixture(name="sagemaker_session", scope="function") +def fixture_sagemaker_session(): + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + s3_resource=None, + s3_client=None, + ) + + describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} + session.sagemaker_client.describe_training_job = Mock(return_value=describe) + session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + session.expand_role = Mock(name="expand_role", return_value=ROLE) + return session + + +def _get_full_gpu_image_uri(version, instance_type, training_compiler_config): + return image_uris.retrieve( + "pytorch-training-compiler", + REGION, + version=version, + py_version="py38", + instance_type=instance_type, + image_scope="training", + container_version=None, + training_compiler_config=training_compiler_config, + ) + + +def _create_train_job(version, instance_type, training_compiler_config, instance_count=1): + return { + "image_uri": _get_full_gpu_image_uri(version, instance_type, training_compiler_config), + "input_mode": "File", + "input_config": [ + { + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + } + }, + } + ], + "role": ROLE, + "job_name": JOB_NAME, + "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + "resource_config": { + "InstanceType": instance_type, + "InstanceCount": instance_count, + "VolumeSizeInGB": 30, + }, + "hyperparameters": { + "sagemaker_program": json.dumps("dummy_script.py"), + "sagemaker_container_log_level": str(logging.INFO), + "sagemaker_job_name": json.dumps(JOB_NAME), + "sagemaker_submit_directory": json.dumps( + "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME) + ), + "sagemaker_region": '"us-east-1"', + }, + "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + 
"tags": None, + "vpc_config": None, + "metric_definitions": None, + "environment": None, + "retry_strategy": None, + "experiment_config": EXPERIMENT_CONFIG, + "debugger_hook_config": { + "CollectionConfigurations": [], + "S3OutputPath": "s3://{}/".format(BUCKET_NAME), + }, + "profiler_rule_configs": [ + { + "RuleConfigurationName": "ProfilerReport-1510006209", + "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", + "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, + } + ], + "profiler_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + } + + +def test_unsupported_BYOC( + pytorch_training_compiler_version, +): + byoc = ( + "1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:" + "1.12.0-" + "gpu-" + "py38-cu113-ubuntu20.04" + ) + with pytest.raises(ValueError): + PyTorch( + image_uri=byoc, + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def test_unsupported_cpu_instance(cpu_instance_type, pytorch_training_compiler_version): + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=cpu_instance_type, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES) +def test_unsupported_gpu_instance( + unsupported_gpu_instance_class, pytorch_training_compiler_version +): + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=f"ml.{unsupported_gpu_instance_class}.xlarge", + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +@pytest.mark.xfail(reason="With only 1 supported version, user input is ignored.") +def test_unsupported_framework_version(): + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version="99.99.99", + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def test_unsupported_python_2( + pytorch_training_compiler_version, +): + with pytest.raises(ValueError): + PyTorch( + py_version="py27", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def test_unsupported_instance_group( + pytorch_training_compiler_version, +): + if Version(pytorch_training_compiler_version) < Version("1.12"): + pytest.skip("This test is intended for PyTorch 1.12 and above") + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_groups=[ + InstanceGroup("ml.p3dn.24xlarge", "ml.p3dn.24xlarge", 16), + InstanceGroup("ml.p4d.24xlarge", "ml.p4d.24xlarge", 16), + ], + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def 
test_unsupported_distribution( + pytorch_training_compiler_version, +): + if Version(pytorch_training_compiler_version) < Version("1.12"): + pytest.skip("This test is intended for PyTorch 1.12 and above") + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=2, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"smdistributed": {"dataparallel": {"enabled": True}}}, + ).fit() + + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=2, + instance_type=INSTANCE_TYPE, + transformers_version="4.17", + pytorch_version="1.10", + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"pytorchxla": {"enabled": True}}, + ).fit() + + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=2, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"mpi": {"enabled": True}}, + ).fit() + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) +def test_pytorchxla_distribution( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class +): + if Version(pytorch_training_compiler_version) < Version("1.12"): + pytest.skip("This test is intended for PyTorch 1.12 and above") + compiler_config = TrainingCompilerConfig() + instance_type = f"ml.{instance_class}.xlarge" + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=2, + instance_type=instance_type, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"pytorchxla": {"enabled": True}}, + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, instance_type, compiler_config, instance_count=2 + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + True + ) + expected_train_args["hyperparameters"][PyTorch.LAUNCH_PT_XLA_ENV_NAME] = json.dumps(True) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + False + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) 
+@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) +def test_default_compiler_config( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class +): + compiler_config = TrainingCompilerConfig() + instance_type = f"ml.{instance_class}.xlarge" + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=instance_type, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=compiler_config, + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, instance_type, compiler_config + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + True + ) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + False + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +def test_debug_compiler_config( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version +): + compiler_config = TrainingCompilerConfig(debug=True) + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=compiler_config, + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + True + ) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + True + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@patch("sagemaker.utils.repack_model", 
MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +def test_disable_compiler_config( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version +): + compiler_config = TrainingCompilerConfig(enabled=False) + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(enabled=False), + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + False + ) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + False + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@pytest.mark.parametrize( + ["compiler_enabled", "debug_enabled"], [(True, False), (True, True), (False, False)] +) +def test_attach(sagemaker_session, compiler_enabled, debug_enabled): + training_image = ( + "1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:" + "1.12.0-" + "gpu-" + "py38-cu113-ubuntu20.04" + ) + returned_job_description = { + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"', + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"trcomp"', + "training_steps": "100", + "sagemaker_region": '"us-east-1"', + TrainingCompilerConfig.HP_ENABLE_COMPILER: json.dumps(compiler_enabled), + TrainingCompilerConfig.HP_ENABLE_DEBUG: json.dumps(debug_enabled), + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.p3.2xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "trcomp", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/trcomp", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/trcomp"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + + estimator = PyTorch.attach(training_job_name="trcomp", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "trcomp" + assert estimator.py_version == "py38" 
+ assert estimator.framework_version == "1.12.0" + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" + assert estimator.instance_count == 1 + assert estimator.max_run == 24 * 60 * 60 + assert estimator.input_mode == "File" + assert estimator.base_job_name == "trcomp" + assert estimator.output_path == "s3://place/output/trcomp" + assert estimator.output_kms_key == "" + assert estimator.hyperparameters()["training_steps"] == "100" + assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_COMPILER] == json.dumps( + compiler_enabled + ) + assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_DEBUG] == json.dumps( + debug_enabled + ) + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "iris-dnn-classifier.py" + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +def test_register_pytorch_model_auto_infer_framework( + sagemaker_session, pytorch_training_compiler_version +): + + model_package_group_name = "test-pt-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarge"] + image_uri = "fakeimage" + + pt_model = PyTorchModel( + model_data="s3://some/data.tar.gz", + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=pytorch_training_compiler_version, + py_version="py38", + sagemaker_session=sagemaker_session, + ) + + pt_model.register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": ANY, + "Framework": "PYTORCH", + "FrameworkVersion": pytorch_training_compiler_version, + } + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": inference_instances, + "transform_instances": transform_instances, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) From d779d1b8296242eb15637e85272a1a50a7ee897b Mon Sep 17 00:00:00 2001 From: HappyAmazonian <91216626+HappyAmazonian@users.noreply.github.com> Date: Tue, 6 Dec 2022 16:37:16 -0800 Subject: [PATCH 032/526] feature: Add Neo image uri config for Pytorch 1.12 (#3507) --- .../image_uri_config/neo-pytorch.json | 36 ++++++++++++++++++- tests/data/pytorch_neo/code/inference.py | 4 +-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/image_uri_config/neo-pytorch.json b/src/sagemaker/image_uri_config/neo-pytorch.json index bd15a6450e..c46dd3de5d 100644 --- a/src/sagemaker/image_uri_config/neo-pytorch.json +++ b/src/sagemaker/image_uri_config/neo-pytorch.json @@ -11,7 +11,9 @@ "1.7.0": "1.7", "1.7.1": "1.7", "1.8.0": "1.8", - "1.8.1": "1.8" + "1.8.1": "1.8", + "1.12.0": "1.12", + "1.12.1": "1.12" }, "versions": { "1.4": { @@ -173,6 +175,38 @@ "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" + }, + "1.12": { + "py_versions": ["py3"], + "registries": { + "af-south-1": "774647643957", + "ap-east-1": "110948597952", + "ap-northeast-1": "941853720454", + "ap-northeast-2": "151534178276", + "ap-northeast-3": "925152966179", + "ap-south-1": "763008648453", + "ap-southeast-1": "324986816169", + 
"ap-southeast-2": "355873309152", + "ca-central-1": "464438896020", + "cn-north-1": "472730292857", + "cn-northwest-1": "474822919863", + "eu-central-1": "746233611703", + "eu-north-1": "601324751636", + "eu-south-1": "966458181534", + "eu-west-1": "802834080501", + "eu-west-2": "205493899709", + "eu-west-3": "254080097072", + "me-south-1": "836785723513", + "sa-east-1": "756306329178", + "us-east-1": "785573368785", + "us-east-2": "007439368137", + "us-gov-west-1": "263933020539", + "us-iso-east-1": "167761179201", + "us-isob-east-1": "406031935815", + "us-west-1": "710691900526", + "us-west-2": "301217895009" + }, + "repository": "sagemaker-inference-pytorch" } } } diff --git a/tests/data/pytorch_neo/code/inference.py b/tests/data/pytorch_neo/code/inference.py index 5b89c2bebc..79fe66d716 100644 --- a/tests/data/pytorch_neo/code/inference.py +++ b/tests/data/pytorch_neo/code/inference.py @@ -71,8 +71,8 @@ def model_fn(model_dir): logger.info("model_fn") neopytorch.config(model_dir=model_dir, neo_runtime=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # The compiled model is saved as "model.pth" - model = torch.jit.load(os.path.join(model_dir, "model.pth"), map_location=device) + # The compiled model is saved as "model.pth" or "model.pt" + model = torch.jit.load(os.path.join(model_dir, "model.pt"), map_location=device) # It is recommended to run warm-up inference during model load sample_input_path = os.path.join(model_dir, "sample_input.pkl") From 83327fb9ef5eb5f44c9fd3f8925c7791576c9a37 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 7 Dec 2022 03:20:15 +0000 Subject: [PATCH 033/526] prepare release v2.120.0 --- CHANGELOG.md | 13 +++++++++++++ VERSION | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8b3155231..71894ff29d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## v2.120.0 (2022-12-07) + +### Features + + * Add Neo image uri config for Pytorch 1.12 + * Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 + * Update registries with new region account number mappings. 
+ * Add DXB region to frameworks by DLC + +### Bug Fixes and Other Changes + + * support idempotency for framework and spark processors + ## v2.119.0 (2022-12-03) ### Features diff --git a/VERSION b/VERSION index dda4128cf2..7de9d18b4e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.119.1.dev0 +2.120.0 From 5bffb04b78e8cd6422654008511aa61ca6f66efb Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 7 Dec 2022 03:20:17 +0000 Subject: [PATCH 034/526] update development version to v2.120.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 7de9d18b4e..73c4cd6968 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.120.0 +2.120.1.dev0 From b828396c55082bc5f06092be41555729d775874a Mon Sep 17 00:00:00 2001 From: Malav Shastri <57682969+malav-shastri@users.noreply.github.com> Date: Wed, 7 Dec 2022 20:58:37 +0530 Subject: [PATCH 035/526] feature: Algorithms Region Expansion OSU/DXB (#3508) Co-authored-by: Malav Shastri --- .../image_uri_config/blazingtext.json | 2 ++ .../factorization-machines.json | 2 ++ .../image_uri_config/forecasting-deepar.json | 2 ++ .../image-classification.json | 2 ++ .../image_uri_config/ipinsights.json | 2 ++ src/sagemaker/image_uri_config/kmeans.json | 2 ++ src/sagemaker/image_uri_config/knn.json | 2 ++ .../image_uri_config/linear-learner.json | 2 ++ src/sagemaker/image_uri_config/ntm.json | 2 ++ .../image_uri_config/object-detection.json | 2 ++ .../image_uri_config/object2vec.json | 2 ++ src/sagemaker/image_uri_config/pca.json | 2 ++ .../image_uri_config/randomcutforest.json | 2 ++ .../semantic-segmentation.json | 2 ++ src/sagemaker/image_uri_config/seq2seq.json | 2 ++ src/sagemaker/image_uri_config/sklearn.json | 14 ++++++++ src/sagemaker/image_uri_config/xgboost.json | 36 +++++++++++++++++++ tests/unit/sagemaker/image_uris/test_algos.py | 4 +++ .../unit/sagemaker/image_uris/test_sklearn.py | 2 ++ .../unit/sagemaker/image_uris/test_xgboost.py | 4 +++ 20 files changed, 90 insertions(+) diff --git a/src/sagemaker/image_uri_config/blazingtext.json b/src/sagemaker/image_uri_config/blazingtext.json index c588d65c73..ae4295c59a 100644 --- a/src/sagemaker/image_uri_config/blazingtext.json +++ b/src/sagemaker/image_uri_config/blazingtext.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/factorization-machines.json b/src/sagemaker/image_uri_config/factorization-machines.json index 0f9930357f..8fb1895707 100644 --- a/src/sagemaker/image_uri_config/factorization-machines.json +++ b/src/sagemaker/image_uri_config/factorization-machines.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/forecasting-deepar.json b/src/sagemaker/image_uri_config/forecasting-deepar.json index 1acc96ed3e..e9beb7acb6 100644 --- a/src/sagemaker/image_uri_config/forecasting-deepar.json +++ 
b/src/sagemaker/image_uri_config/forecasting-deepar.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "522234722520", "us-east-2": "566113047672", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "156387875391" diff --git a/src/sagemaker/image_uri_config/image-classification.json b/src/sagemaker/image_uri_config/image-classification.json index 44ccb3f08d..61e037da08 100644 --- a/src/sagemaker/image_uri_config/image-classification.json +++ b/src/sagemaker/image_uri_config/image-classification.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/ipinsights.json b/src/sagemaker/image_uri_config/ipinsights.json index 4e56c149dc..cf3c70194f 100644 --- a/src/sagemaker/image_uri_config/ipinsights.json +++ b/src/sagemaker/image_uri_config/ipinsights.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/kmeans.json b/src/sagemaker/image_uri_config/kmeans.json index 952724ce11..e8e947f094 100644 --- a/src/sagemaker/image_uri_config/kmeans.json +++ b/src/sagemaker/image_uri_config/kmeans.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/knn.json b/src/sagemaker/image_uri_config/knn.json index 79b239966d..89e8ef6224 100644 --- a/src/sagemaker/image_uri_config/knn.json +++ b/src/sagemaker/image_uri_config/knn.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/linear-learner.json b/src/sagemaker/image_uri_config/linear-learner.json index bb027284ab..606edd3791 100644 --- a/src/sagemaker/image_uri_config/linear-learner.json +++ b/src/sagemaker/image_uri_config/linear-learner.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": 
"237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/ntm.json b/src/sagemaker/image_uri_config/ntm.json index 115264b346..16f9565405 100644 --- a/src/sagemaker/image_uri_config/ntm.json +++ b/src/sagemaker/image_uri_config/ntm.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/object-detection.json b/src/sagemaker/image_uri_config/object-detection.json index 6a7ba03695..67b60fe587 100644 --- a/src/sagemaker/image_uri_config/object-detection.json +++ b/src/sagemaker/image_uri_config/object-detection.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/object2vec.json b/src/sagemaker/image_uri_config/object2vec.json index 39614d1273..b166cc96ff 100644 --- a/src/sagemaker/image_uri_config/object2vec.json +++ b/src/sagemaker/image_uri_config/object2vec.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/pca.json b/src/sagemaker/image_uri_config/pca.json index 5f87d8528c..11982e2197 100644 --- a/src/sagemaker/image_uri_config/pca.json +++ b/src/sagemaker/image_uri_config/pca.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/randomcutforest.json b/src/sagemaker/image_uri_config/randomcutforest.json index ae7a3574be..15dc84dfc5 100644 --- a/src/sagemaker/image_uri_config/randomcutforest.json +++ b/src/sagemaker/image_uri_config/randomcutforest.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/semantic-segmentation.json b/src/sagemaker/image_uri_config/semantic-segmentation.json index 866dd606b4..f49bc43109 100644 --- a/src/sagemaker/image_uri_config/semantic-segmentation.json +++ 
b/src/sagemaker/image_uri_config/semantic-segmentation.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/seq2seq.json b/src/sagemaker/image_uri_config/seq2seq.json index bb3daf93b6..87810ad09d 100644 --- a/src/sagemaker/image_uri_config/seq2seq.json +++ b/src/sagemaker/image_uri_config/seq2seq.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/sklearn.json b/src/sagemaker/image_uri_config/sklearn.json index 7961fde282..4d093f5f62 100644 --- a/src/sagemaker/image_uri_config/sklearn.json +++ b/src/sagemaker/image_uri_config/sklearn.json @@ -24,10 +24,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -57,10 +59,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -90,10 +94,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -127,10 +133,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -160,10 +168,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -193,10 +203,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": 
"237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -230,10 +242,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" diff --git a/src/sagemaker/image_uri_config/xgboost.json b/src/sagemaker/image_uri_config/xgboost.json index a809083c4a..946e78ecc4 100644 --- a/src/sagemaker/image_uri_config/xgboost.json +++ b/src/sagemaker/image_uri_config/xgboost.json @@ -25,10 +25,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" @@ -58,10 +60,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -91,10 +95,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -124,10 +130,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -155,10 +163,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -186,10 +196,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -217,10 +229,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -248,10 +262,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": 
"801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -286,10 +302,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" @@ -319,10 +337,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -352,10 +372,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -385,10 +407,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -416,10 +440,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -447,10 +473,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -478,10 +506,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -509,10 +539,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -544,10 +576,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": 
"272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -575,10 +609,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" diff --git a/tests/unit/sagemaker/image_uris/test_algos.py b/tests/unit/sagemaker/image_uris/test_algos.py index 454d375b4b..443727094a 100644 --- a/tests/unit/sagemaker/image_uris/test_algos.py +++ b/tests/unit/sagemaker/image_uris/test_algos.py @@ -68,10 +68,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107", @@ -155,10 +157,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032", diff --git a/tests/unit/sagemaker/image_uris/test_sklearn.py b/tests/unit/sagemaker/image_uris/test_sklearn.py index d0fcbdb300..8563753e8c 100644 --- a/tests/unit/sagemaker/image_uris/test_sklearn.py +++ b/tests/unit/sagemaker/image_uris/test_sklearn.py @@ -37,10 +37,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249", diff --git a/tests/unit/sagemaker/image_uris/test_xgboost.py b/tests/unit/sagemaker/image_uris/test_xgboost.py index 78ab7e10ee..4d0f9f1dc3 100644 --- a/tests/unit/sagemaker/image_uris/test_xgboost.py +++ b/tests/unit/sagemaker/image_uris/test_xgboost.py @@ -35,10 +35,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032", @@ -67,10 +69,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249", From 357f73226c9c5fe651ea74169cafe585e1092ad0 Mon Sep 17 00:00:00 2001 From: Navin Soni Date: Wed, 7 Dec 2022 10:36:33 -0800 Subject: [PATCH 036/526] fix: Add constraints file for apache-airflow 
(#3510) --- requirements/extras/test_requirements.txt | 1 + tox.ini | 2 ++ 2 files changed, 3 insertions(+) diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index b52f394bd0..fe93fd4d0e 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -11,6 +11,7 @@ contextlib2==21.6.0 awslogs==0.14.0 black==22.3.0 stopit==1.1.2 +# Update tox.ini to have correct version of airflow constraints file apache-airflow==2.4.1 apache-airflow-providers-amazon==4.0.0 attrs==22.1.0 diff --git a/tox.ini b/tox.ini index 2d5fdf0b40..3a398ca51d 100644 --- a/tox.ini +++ b/tox.ini @@ -73,6 +73,8 @@ passenv = # Can be used to specify which tests to run, e.g.: tox -- -s commands = python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')" + pip install 'apache-airflow==2.4.1' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.4.1/constraints-3.10.txt" + pytest --cov=sagemaker --cov-append {posargs} {env:IGNORE_COVERAGE:} coverage report -i --fail-under=86 deps = .[test] From a28d1dd129ecceb612d5e8927b6be72937711722 Mon Sep 17 00:00:00 2001 From: Brock Wade Date: Wed, 7 Dec 2022 19:14:12 -0800 Subject: [PATCH 037/526] fix: FrameworkProcessor S3 uploads (#3493) Co-authored-by: Brock Wade Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> --- src/sagemaker/processing.py | 47 +++- .../data/pipeline/test_source_dir/script_1.py | 11 + .../data/pipeline/test_source_dir/script_2.py | 9 + .../pipeline/test_source_dir_2/script_2.py | 9 + .../workflow/test_processing_steps.py | 249 +++++++++++++++++- .../integ/sagemaker/workflow/test_workflow.py | 8 +- 6 files changed, 322 insertions(+), 11 deletions(-) create mode 100644 tests/data/pipeline/test_source_dir/script_1.py create mode 100644 tests/data/pipeline/test_source_dir/script_2.py create mode 100644 tests/data/pipeline/test_source_dir_2/script_2.py diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 81e3d34b1d..01d4361197 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -1741,13 +1741,7 @@ def _pack_and_upload_code( raise RuntimeError("S3 source_dir file must be named `sourcedir.tar.gz.`") script = estimator.uploaded_code.script_name - s3_runproc_sh = S3Uploader.upload_string_as_file_body( - self._generate_framework_script(script), - desired_s3_uri=entrypoint_s3_uri, - kms_key=kms_key, - sagemaker_session=self.sagemaker_session, - ) - logger.info("runproc.sh uploaded to %s", s3_runproc_sh) + s3_runproc_sh = self._create_and_upload_runproc(script, kms_key, entrypoint_s3_uri) return s3_runproc_sh, inputs, job_name @@ -1857,3 +1851,42 @@ def _set_entrypoint(self, command, user_script_name): ) ) self.entrypoint = self.framework_entrypoint_command + [user_script_location] + + def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): + """Create runproc shell script and upload to S3 bucket. + + If leveraging a pipeline session with optimized S3 artifact paths, + the runproc.sh file is hashed and uploaded to a separate S3 location. + + + Args: + user_script (str): Relative path to ```code``` in the source bundle + - e.g. 'process.py'. + kms_key (str): THe kms key used for encryption. + entrypoint_s3_uri (str): The S3 upload path for the runproc script. 
+ """ + from sagemaker.workflow.utilities import _pipeline_config, hash_object + + if _pipeline_config and _pipeline_config.pipeline_name: + runproc_file_str = self._generate_framework_script(user_script) + runproc_file_hash = hash_object(runproc_file_str) + s3_uri = ( + f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/" + f"code/{runproc_file_hash}/runproc.sh" + ) + s3_runproc_sh = S3Uploader.upload_string_as_file_body( + runproc_file_str, + desired_s3_uri=s3_uri, + kms_key=kms_key, + sagemaker_session=self.sagemaker_session, + ) + else: + s3_runproc_sh = S3Uploader.upload_string_as_file_body( + self._generate_framework_script(user_script), + desired_s3_uri=entrypoint_s3_uri, + kms_key=kms_key, + sagemaker_session=self.sagemaker_session, + ) + logger.info("runproc.sh uploaded to %s", s3_runproc_sh) + + return s3_runproc_sh diff --git a/tests/data/pipeline/test_source_dir/script_1.py b/tests/data/pipeline/test_source_dir/script_1.py new file mode 100644 index 0000000000..4a427b1898 --- /dev/null +++ b/tests/data/pipeline/test_source_dir/script_1.py @@ -0,0 +1,11 @@ +""" +Integ test file script_1.py +""" +import pathlib + +if __name__ == "__main__": + + print("writing file to /opt/ml/processing/test/test.py...") + pathlib.Path("/opt/ml/processing/test").mkdir(parents=True, exist_ok=True) + with open("/opt/ml/processing/test/test.py", "w") as f: + f.write('print("test...")') diff --git a/tests/data/pipeline/test_source_dir/script_2.py b/tests/data/pipeline/test_source_dir/script_2.py new file mode 100644 index 0000000000..6245dac987 --- /dev/null +++ b/tests/data/pipeline/test_source_dir/script_2.py @@ -0,0 +1,9 @@ +""" +Integ test file script_2.py +""" + +if __name__ == "__main__": + + print("reading file: /opt/ml/procesing/test/test.py") + with open("/opt/ml/processing/test/test.py", "r") as f: + print(f.read()) diff --git a/tests/data/pipeline/test_source_dir_2/script_2.py b/tests/data/pipeline/test_source_dir_2/script_2.py new file mode 100644 index 0000000000..6245dac987 --- /dev/null +++ b/tests/data/pipeline/test_source_dir_2/script_2.py @@ -0,0 +1,9 @@ +""" +Integ test file script_2.py +""" + +if __name__ == "__main__": + + print("reading file: /opt/ml/procesing/test/test.py") + with open("/opt/ml/processing/test/test.py", "r") as f: + print(f.read()) diff --git a/tests/integ/sagemaker/workflow/test_processing_steps.py b/tests/integ/sagemaker/workflow/test_processing_steps.py index 781bce85a7..238eff6123 100644 --- a/tests/integ/sagemaker/workflow/test_processing_steps.py +++ b/tests/integ/sagemaker/workflow/test_processing_steps.py @@ -17,15 +17,18 @@ import re import subprocess from datetime import datetime +from pathlib import Path import pytest from botocore.exceptions import WaiterError +from sagemaker.workflow.utilities import hash_files_or_dirs, hash_object from sagemaker import image_uris, get_execution_role, utils from sagemaker.dataset_definition import DatasetDefinition, AthenaDatasetDefinition -from sagemaker.processing import ProcessingInput, ProcessingOutput -from sagemaker.s3 import S3Uploader -from sagemaker.sklearn import SKLearnProcessor +from sagemaker.processing import ProcessingInput, ProcessingOutput, FrameworkProcessor +from sagemaker.s3 import S3Uploader, S3Downloader +from sagemaker.sklearn import SKLearnProcessor, SKLearn +from sagemaker.tensorflow import TensorFlow from sagemaker.workflow.parameters import ParameterInteger, ParameterString from sagemaker.workflow.pipeline import Pipeline from sagemaker.workflow.steps import 
( @@ -379,6 +382,203 @@ def test_one_step_framework_processing_pipeline( pass +def test_multi_step_framework_processing_pipeline_same_source_dir( + pipeline_session, role, pipeline_name +): + default_bucket = pipeline_session.default_bucket() + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + + SOURCE_DIR = "/pipeline/test_source_dir" + + framework_processor_tf = FrameworkProcessor( + role=role, + instance_type="ml.m5.xlarge", + instance_count=1, + estimator_cls=TensorFlow, + framework_version="2.9", + py_version="py39", + sagemaker_session=pipeline_session, + ) + + framework_processor_sk = FrameworkProcessor( + framework_version="1.0-1", + instance_type="ml.m5.xlarge", + instance_count=1, + base_job_name="my-job", + role=role, + estimator_cls=SKLearn, + sagemaker_session=pipeline_session, + ) + + step_1 = ProcessingStep( + name="Step-1", + step_args=framework_processor_tf.run( + code="script_1.py", + source_dir=DATA_DIR + SOURCE_DIR, + outputs=[ProcessingOutput(output_name="test", source="/opt/ml/processing/test")], + ), + cache_config=cache_config, + ) + + step_2 = ProcessingStep( + name="Step-2", + step_args=framework_processor_sk.run( + code="script_2.py", + source_dir=DATA_DIR + SOURCE_DIR, + inputs=[ + ProcessingInput( + source=step_1.properties.ProcessingOutputConfig.Outputs["test"].S3Output.S3Uri, + destination="/opt/ml/processing/test", + ), + ], + ), + cache_config=cache_config, + ) + + pipeline = Pipeline( + name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session + ) + try: + pipeline.create(role) + definition = json.loads(pipeline.definition()) + + source_dir_1_s3_uri, entry_point_1 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_tf, + default_bucket, + pipeline_name, + definition["Steps"][0], + SOURCE_DIR, + "script_1.py", + ) + source_dir_2_s3_uri, entry_point_2 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_sk, + default_bucket, + pipeline_name, + definition["Steps"][1], + SOURCE_DIR, + "script_2.py", + ) + + # the same local source_dirs should have the same s3 paths + assert source_dir_1_s3_uri == source_dir_2_s3_uri + + # verify different entry_point paths + assert entry_point_1 != entry_point_2 + + execution = pipeline.start(parameters={}) + try: + execution.wait(delay=540, max_attempts=3) + except WaiterError: + pass + + execution_steps = execution.list_steps() + assert len(execution_steps) == 2 + for step in execution_steps: + assert step["StepStatus"] == "Succeeded" + + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_multi_step_framework_processing_pipeline_different_source_dir( + pipeline_session, role, pipeline_name +): + default_bucket = pipeline_session.default_bucket() + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + + SOURCE_DIR_1 = "/pipeline/test_source_dir" + SOURCE_DIR_2 = "/pipeline/test_source_dir_2" + + framework_processor_tf = FrameworkProcessor( + role=role, + instance_type="ml.m5.xlarge", + instance_count=1, + estimator_cls=TensorFlow, + framework_version="2.9", + py_version="py39", + sagemaker_session=pipeline_session, + ) + + step_1 = ProcessingStep( + name="Step-1", + step_args=framework_processor_tf.run( + code="script_1.py", + source_dir=DATA_DIR + SOURCE_DIR_1, + outputs=[ProcessingOutput(output_name="test", source="/opt/ml/processing/test")], + ), + cache_config=cache_config, + ) + + step_2 = ProcessingStep( + name="Step-2", + 
step_args=framework_processor_tf.run( + code="script_2.py", + source_dir=DATA_DIR + SOURCE_DIR_2, + inputs=[ + ProcessingInput( + source=step_1.properties.ProcessingOutputConfig.Outputs["test"].S3Output.S3Uri, + destination="/opt/ml/processing/test", + ), + ], + ), + cache_config=cache_config, + ) + + pipeline = Pipeline( + name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session + ) + try: + pipeline.create(role) + definition = json.loads(pipeline.definition()) + + source_dir_1_s3_uri, entry_point_1 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_tf, + default_bucket, + pipeline_name, + definition["Steps"][0], + SOURCE_DIR_1, + "script_1.py", + ) + source_dir_2_s3_uri, entry_point_2 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_tf, + default_bucket, + pipeline_name, + definition["Steps"][1], + SOURCE_DIR_2, + "script_2.py", + ) + + # different local source_dirs should have different s3 paths + assert source_dir_1_s3_uri != source_dir_2_s3_uri + + # verify different entry_point paths + assert entry_point_1 != entry_point_2 + + execution = pipeline.start(parameters={}) + try: + execution.wait(delay=540, max_attempts=3) + except WaiterError: + pass + + execution_steps = execution.list_steps() + assert len(execution_steps) == 2 + for step in execution_steps: + assert step["StepStatus"] == "Succeeded" + + finally: + try: + pipeline.delete() + except Exception: + pass + + def test_one_step_pyspark_processing_pipeline( sagemaker_session, role, @@ -796,3 +996,46 @@ def test_two_processing_job_depends_on( pipeline.delete() except Exception: pass + + +def _verify_code_artifacts_of_framework_processing_step( + pipeline_session, processor, bucket, pipeline_name, step_definition, source_dir, entry_point +): + + source_dir_s3_uri = ( + f"s3://{bucket}/{pipeline_name}" f"/code/{hash_files_or_dirs([f'{DATA_DIR}/{source_dir}'])}" + ) + + # verify runproc.sh prefix is different from code artifact prefix + runprocs = [] + for input_obj in step_definition["Arguments"]["ProcessingInputs"]: + if input_obj["InputName"] == "entrypoint": + s3_uri = input_obj["S3Input"]["S3Uri"] + runprocs.append(s3_uri) + + assert Path(s3_uri).parent != source_dir_s3_uri + + # verify only one entrypoint generated per step + assert len(runprocs) == 1 + + expected_source_dir_tar = ( + f"{pipeline_name}" + f"/code/{hash_files_or_dirs([DATA_DIR + '/pipeline/test_source_dir'])}/sourcedir.tar.gz" + ) + + step_script = processor._generate_framework_script(entry_point) + expected_step_artifact = f"{pipeline_name}/code/{hash_object(step_script)}/runproc.sh" + + expected_prefix = f"{pipeline_name}/code" + s3_code_objects = pipeline_session.list_s3_files(bucket=bucket, key_prefix=expected_prefix) + + # verify all distinct artifacts were uploaded + assert expected_source_dir_tar in s3_code_objects + assert expected_step_artifact in s3_code_objects + + # verify runprocs contain the correct commands + step_runproc = S3Downloader.read_file( + f"s3://{bucket}/{expected_step_artifact}", pipeline_session + ) + assert f"python {entry_point}" in step_runproc + return source_dir, expected_step_artifact diff --git a/tests/integ/sagemaker/workflow/test_workflow.py b/tests/integ/sagemaker/workflow/test_workflow.py index 634ef752d6..44f4e2d26e 100644 --- a/tests/integ/sagemaker/workflow/test_workflow.py +++ b/tests/integ/sagemaker/workflow/test_workflow.py @@ -1168,7 +1168,13 @@ def walk(): def test_caching_behavior( - 
pipeline_session, role, cpu_instance_type, pipeline_name, script_dir, athena_dataset_definition + pipeline_session, + role, + cpu_instance_type, + pipeline_name, + script_dir, + athena_dataset_definition, + region_name, ): default_bucket = pipeline_session.default_bucket() data_path = os.path.join(DATA_DIR, "workflow") From 11d24754b0a8228893f6663ac1ca5048b8a6e794 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 8 Dec 2022 06:16:54 +0000 Subject: [PATCH 038/526] prepare release v2.121.0 --- CHANGELOG.md | 11 +++++++++++ VERSION | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71894ff29d..29dad5f19f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## v2.121.0 (2022-12-08) + +### Features + + * Algorithms Region Expansion OSU/DXB + +### Bug Fixes and Other Changes + + * FrameworkProcessor S3 uploads + * Add constraints file for apache-airflow + ## v2.120.0 (2022-12-07) ### Features diff --git a/VERSION b/VERSION index 73c4cd6968..7f1e14b5a9 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.120.1.dev0 +2.121.0 From 24171b5efcb9c528f159334d6252835ef10bbcb2 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 8 Dec 2022 06:16:55 +0000 Subject: [PATCH 039/526] update development version to v2.121.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 7f1e14b5a9..28b52ee8d5 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.0 +2.121.1.dev0 From d5847d5ebad840c5f47204742302d91064904be8 Mon Sep 17 00:00:00 2001 From: Loki Date: Fri, 9 Dec 2022 03:10:14 +0530 Subject: [PATCH 040/526] Fix: Differentiate SageMaker Training Compiler's PT DLCs from base PT DLC (#3515) --- src/sagemaker/image_uri_config/pytorch-training-compiler.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/image_uri_config/pytorch-training-compiler.json b/src/sagemaker/image_uri_config/pytorch-training-compiler.json index 892ff4237d..fd7df875a3 100644 --- a/src/sagemaker/image_uri_config/pytorch-training-compiler.json +++ b/src/sagemaker/image_uri_config/pytorch-training-compiler.json @@ -34,7 +34,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "pytorch-training" + "repository": "pytorch-trcomp-training" } } } From 3f6ea884a564090f826fab46270429db553c7b3b Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Thu, 8 Dec 2022 17:17:44 -0500 Subject: [PATCH 041/526] fix: Fix failing jumpstart cache unit tests (#3514) --- setup.py | 2 +- src/sagemaker/jumpstart/cache.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 4327045760..f366b147b8 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ def read_requirements(filename): "protobuf3-to-dict>=0.1.5,<1.0", "smdebug_rulesconfig==1.0.1", "importlib-metadata>=1.4.0,<5.0", - "packaging>=20.0", + "packaging==20.9", "pandas", "pathos", "schema", diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 202edff9ad..db607770a7 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -20,7 +20,7 @@ import boto3 import botocore from packaging.version import Version -from packaging.specifiers import SpecifierSet +from packaging.specifiers import SpecifierSet, InvalidSpecifier from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -371,7 +371,10 @@ def 
_select_version( return None return str(max(available_versions)) - spec = SpecifierSet(f"=={semantic_version_str}") + try: + spec = SpecifierSet(f"=={semantic_version_str}") + except InvalidSpecifier: + raise KeyError(f"Bad semantic version: {semantic_version_str}") available_versions_filtered = list(spec.filter(available_versions)) return ( str(max(available_versions_filtered)) if available_versions_filtered != [] else None From 4570aa6078e75ba0d259f8196891b7856790a435 Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Thu, 8 Dec 2022 19:00:48 -0800 Subject: [PATCH 042/526] fix: Pop out ModelPackageName from pipeline definition (#3472) Co-authored-by: Dewen Qi --- src/sagemaker/workflow/_utils.py | 12 ++ .../sagemaker/workflow/test_model_steps.py | 1 + tests/unit/sagemaker/workflow/conftest.py | 75 +++++++++ .../sagemaker/workflow/test_model_step.py | 147 +++++++----------- tests/unit/sagemaker/workflow/test_utils.py | 54 +------ 5 files changed, 150 insertions(+), 139 deletions(-) create mode 100644 tests/unit/sagemaker/workflow/conftest.py diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 8ba65f1eee..cdef9537c1 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -13,6 +13,7 @@ """Scrapper utilities to support repacking of models.""" from __future__ import absolute_import +import logging import os import shutil import tarfile @@ -37,6 +38,8 @@ if TYPE_CHECKING: from sagemaker.workflow.step_collections import StepCollection +logger = logging.getLogger(__name__) + FRAMEWORK_VERSION = "0.23-1" INSTANCE_TYPE = "ml.m5.large" REPACK_SCRIPT = "_repack_model.py" @@ -479,10 +482,19 @@ def arguments(self) -> RequestType: request_dict = get_create_model_package_request(**model_package_args) # these are not available in the workflow service and will cause rejection + warn_msg_template = ( + "Popping out '%s' from the pipeline definition " + "since it will be overridden in pipeline execution time." + ) if "CertifyForMarketplace" in request_dict: request_dict.pop("CertifyForMarketplace") + logger.warning(warn_msg_template, "CertifyForMarketplace") if "Description" in request_dict: request_dict.pop("Description") + logger.warning(warn_msg_template, "Description") + if "ModelPackageName" in request_dict: + request_dict.pop("ModelPackageName") + logger.warning(warn_msg_template, "ModelPackageName") return request_dict diff --git a/tests/integ/sagemaker/workflow/test_model_steps.py b/tests/integ/sagemaker/workflow/test_model_steps.py index 31c518b100..f25723c440 100644 --- a/tests/integ/sagemaker/workflow/test_model_steps.py +++ b/tests/integ/sagemaker/workflow/test_model_steps.py @@ -112,6 +112,7 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen inference_instances=["ml.m5.xlarge"], transform_instances=["ml.m5.xlarge"], description="test-description", + model_package_name="model-pkg-name-will-be-popped-out", ) step_model_regis = ModelStep( name="pytorch-register-model", diff --git a/tests/unit/sagemaker/workflow/conftest.py b/tests/unit/sagemaker/workflow/conftest.py new file mode 100644 index 0000000000..9ea3d0bcac --- /dev/null +++ b/tests/unit/sagemaker/workflow/conftest.py @@ -0,0 +1,75 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest.mock import Mock, PropertyMock + +import pytest + +from sagemaker import Session +from sagemaker.workflow.pipeline_context import PipelineSession + +REGION = "us-west-2" +BUCKET = "my-bucket" +ROLE = "DummyRole" +IMAGE_URI = "fakeimage" + + +@pytest.fixture(scope="module") +def client(): + """Mock client. + + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture(scope="module") +def boto_session(client): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value=ROLE) + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name=REGION) + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client + + return session_mock + + +@pytest.fixture(scope="module") +def pipeline_session(boto_session, client): + return PipelineSession( + boto_session=boto_session, + sagemaker_client=client, + default_bucket=BUCKET, + ) + + +@pytest.fixture(scope="module") +def sagemaker_session(boto_session, client): + return Session( + boto_session=boto_session, + sagemaker_client=client, + sagemaker_runtime_client=client, + default_bucket=BUCKET, + ) diff --git a/tests/unit/sagemaker/workflow/test_model_step.py b/tests/unit/sagemaker/workflow/test_model_step.py index 080e70ca62..2216299d3b 100644 --- a/tests/unit/sagemaker/workflow/test_model_step.py +++ b/tests/unit/sagemaker/workflow/test_model_step.py @@ -15,7 +15,7 @@ import json import os -from mock import Mock, PropertyMock, patch +from mock import patch import pytest @@ -43,7 +43,6 @@ ) from sagemaker.workflow.parameters import ParameterString, ParameterInteger from sagemaker.workflow.pipeline import Pipeline, PipelineGraph -from sagemaker.workflow.pipeline_context import PipelineSession from sagemaker.workflow.retry import ( StepRetryPolicy, StepExceptionTypeEnum, @@ -55,11 +54,9 @@ from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum from tests.unit import DATA_DIR from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered +from tests.unit.sagemaker.workflow.conftest import BUCKET, ROLE _IMAGE_URI = "fakeimage" -_REGION = "us-west-2" -_BUCKET = "my-bucket" -_ROLE = "DummyRole" _INSTANCE_TYPE = "ml.m4.xlarge" _SAGEMAKER_PROGRAM = SCRIPT_PARAM_NAME.upper() @@ -69,60 +66,10 @@ _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone") _TENSORFLOW_PATH = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies") _REPACK_OUTPUT_KEY_PREFIX = "code-output" -_MODEL_CODE_LOCATION = f"s3://{_BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}" +_MODEL_CODE_LOCATION = f"s3://{BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}" _MODEL_CODE_LOCATION_TRAILING_SLASH = _MODEL_CODE_LOCATION + "/" -@pytest.fixture -def client(): - """Mock client. 
- - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def boto_session(client): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=_ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=_REGION) - session_mock.resource.return_value = resource_mock - session_mock.client.return_value = client - - return session_mock - - -@pytest.fixture -def pipeline_session(boto_session, client): - return PipelineSession( - boto_session=boto_session, - sagemaker_client=client, - default_bucket=_BUCKET, - ) - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=_BUCKET, - ) - - @pytest.fixture def model_data_param(): return ParameterString(name="ModelData", default_value="s3://my-bucket/file") @@ -137,7 +84,7 @@ def model(pipeline_session, model_data_param): sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", source_dir=f"{DATA_DIR}", - role=_ROLE, + role=ROLE, ) @@ -322,13 +269,13 @@ def test_create_pipeline_model_with_runtime_repack(pipeline_session, model_data_ sparkml_model = SparkMLModel( name="MySparkMLModel", model_data=model_data_param, - role=_ROLE, + role=ROLE, sagemaker_session=pipeline_session, env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"}, ) # The model need to runtime repack ppl_model = PipelineModel( - models=[sparkml_model, model], role=_ROLE, sagemaker_session=pipeline_session + models=[sparkml_model, model], role=ROLE, sagemaker_session=pipeline_session ) step_args = ppl_model.create( instance_type="c4.4xlarge", @@ -417,7 +364,7 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat # The model no need to runtime repack, since source_dir is missing sparkml_model = SparkMLModel( model_data=model_data_param, - role=_ROLE, + role=ROLE, sagemaker_session=pipeline_session, env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"}, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", @@ -429,11 +376,11 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", source_dir=f"{DATA_DIR}", - role=_ROLE, + role=ROLE, env={"k": "v"}, ) model = PipelineModel( - models=[sparkml_model, model], role=_ROLE, sagemaker_session=pipeline_session + models=[sparkml_model, model], role=ROLE, sagemaker_session=pipeline_session ) step_args = model.register( content_types=["text/csv"], @@ -516,7 +463,7 @@ def test_register_model_without_repack(pipeline_session): model_data=model_data, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", sagemaker_session=pipeline_session, - role=_ROLE, + role=ROLE, ) step_args = model.register( content_types=["text/csv"], @@ -547,7 +494,7 @@ def test_register_model_without_repack(pipeline_session): assert containers[0]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME assert ( containers[0]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] - == f"s3://{_BUCKET}/{model_name}/sourcedir.tar.gz" + == f"s3://{BUCKET}/{model_name}/sourcedir.tar.gz" ) adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list assert ordered(adjacency_list) == 
ordered({"MyModelStep-RegisterModel": []}) @@ -560,11 +507,11 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session): model = Model( name=model_name, image_uri=_IMAGE_URI, - model_data=f"s3://{_BUCKET}/model.tar.gz", + model_data=f"s3://{BUCKET}/model.tar.gz", sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", source_dir=f"{DATA_DIR}", - role=_ROLE, + role=ROLE, ) step_args = model.create( instance_type="c4.4xlarge", @@ -582,7 +529,7 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session): arguments = step_dsl_list[0]["Arguments"] assert arguments["PrimaryContainer"]["Image"] == _IMAGE_URI assert ( - arguments["PrimaryContainer"]["ModelDataUrl"] == f"s3://{_BUCKET}/{model_name}/model.tar.gz" + arguments["PrimaryContainer"]["ModelDataUrl"] == f"s3://{BUCKET}/{model_name}/model.tar.gz" ) assert arguments["PrimaryContainer"]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME assert arguments["PrimaryContainer"]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] == _DIR_NAME @@ -700,7 +647,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, enable_network_isolation=True, code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH, ), @@ -713,7 +660,7 @@ def test_conditional_model_create_and_regis( framework_version="1.11.0", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, enable_network_isolation=False, ), 1, @@ -724,7 +671,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, framework_version="1.5.0", code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH, ), @@ -736,7 +683,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, framework_version="1.2.0", ), 1, @@ -747,7 +694,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, ), 2, ), @@ -757,7 +704,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH, ), 2, @@ -768,7 +715,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, ), 1, ), @@ -789,7 +736,7 @@ def assert_test_result(steps: list): ) else: assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == ( - f"s3://{_BUCKET}/{model.name}" + f"s3://{BUCKET}/{model.name}" ) model, expected_step_num = test_input @@ -828,7 +775,7 @@ def assert_test_result(steps: list): XGBoostModel( model_data="dummy_model_step", framework_version="1.3-1", - role=_ROLE, + role=ROLE, entry_point=os.path.join(_XGBOOST_PATH, "inference.py"), enable_network_isolation=True, ), @@ -845,7 +792,7 @@ def assert_test_result(steps: list): XGBoostModel( model_data="dummy_model_step", framework_version="1.3-1", - role=_ROLE, + role=ROLE, entry_point=os.path.join(_XGBOOST_PATH, "inference.py"), ), { @@ -861,7 +808,7 @@ def assert_test_result(steps: list): XGBoostModel( model_data="dummy_model_step", framework_version="1.3-1", - role=_ROLE, + role=ROLE, entry_point=None, 
), { @@ -876,9 +823,8 @@ def assert_test_result(steps: list): ( TensorFlowModel( model_data="dummy_model_step", - role=_ROLE, + role=ROLE, image_uri=_IMAGE_URI, - sagemaker_session=pipeline_session, entry_point=os.path.join(_TENSORFLOW_PATH, "inference.py"), ), { @@ -893,9 +839,8 @@ def assert_test_result(steps: list): ( TensorFlowModel( model_data="dummy_model_step", - role=_ROLE, + role=ROLE, image_uri=_IMAGE_URI, - sagemaker_session=pipeline_session, ), { "expected_step_num": 1, @@ -941,7 +886,7 @@ def test_request_compare_of_register_model_under_different_sessions( _verify_register_model_container_definition(regis_step_arg, expect, dict) # Get create model package request under Session - model.model_data = f"s3://{_BUCKET}" + model.model_data = f"s3://{BUCKET}" model.sagemaker_session = sagemaker_session with patch.object( Session, "_intercept_create_request", return_value=dict(ModelPackageArn="arn:aws") @@ -996,7 +941,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): model_data=lambda_step.properties.Outputs["model_artifact"], sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, ) step_create_model = ModelStep(name="mymodelstep", step_args=model.create()) @@ -1031,7 +976,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): ( Processor( image_uri=_IMAGE_URI, - role=_ROLE, + role=ROLE, instance_count=1, instance_type=_INSTANCE_TYPE, ), @@ -1052,7 +997,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): ( HyperparameterTuner( estimator=Estimator( - role=_ROLE, + role=ROLE, instance_count=1, instance_type=_INSTANCE_TYPE, image_uri=_IMAGE_URI, @@ -1064,7 +1009,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): ), ( Estimator( - role=_ROLE, + role=ROLE, instance_count=1, instance_type=_INSTANCE_TYPE, image_uri=_IMAGE_URI, @@ -1128,3 +1073,31 @@ def test_pass_in_wrong_type_of_retry_policies(pipeline_session, model): ), ) assert "SageMakerJobStepRetryPolicy is not allowed for a create/registe" in str(error.value) + + +def test_register_model_step_with_model_package_name(pipeline_session): + model = Model( + name="MyModel", + image_uri="my-image", + model_data="s3://", + sagemaker_session=pipeline_session, + ) + step_args = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_name="model-pkg-name-will-be-popped-out", + ) + regis_model_step = ModelStep( + name="MyModelStep", + step_args=step_args, + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[regis_model_step], + sagemaker_session=pipeline_session, + ) + steps = json.loads(pipeline.definition())["Steps"] + assert len(steps) == 1 + assert "ModelPackageName" not in steps[0]["Arguments"] diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index dcbf5a6421..c8d86c5866 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -18,12 +18,6 @@ import tempfile import pytest -import sagemaker - -from mock import ( - Mock, - PropertyMock, -) from sagemaker.estimator import Estimator from sagemaker.workflow._utils import ( @@ -35,51 +29,7 @@ from sagemaker.workflow.properties import Properties from tests.unit.test_utils import FakeS3, list_tar_files from tests.unit import DATA_DIR - -REGION = "us-west-2" -BUCKET = "my-bucket" -IMAGE_URI = "fakeimage" -ROLE = 
"DummyRole" - - -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. - - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=BUCKET, - ) +from tests.unit.sagemaker.workflow.conftest import ROLE, IMAGE_URI, BUCKET @pytest.fixture @@ -171,7 +121,7 @@ def test_repack_model_step(estimator): } -def test_repack_model_step_with_invalid_input(): +def test_register_model_step_with_invalid_input(): # without both step_args and any of the old required arguments with pytest.raises(ValueError) as error: _RegisterModelStep( From 959ea1a485db702f361ddebda2e80779bfd20e43 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 9 Dec 2022 06:20:46 +0000 Subject: [PATCH 043/526] prepare release v2.121.1 --- CHANGELOG.md | 7 +++++++ VERSION | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29dad5f19f..472a25feb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## v2.121.1 (2022-12-09) + +### Bug Fixes and Other Changes + + * Pop out ModelPackageName from pipeline definition + * Fix failing jumpstart cache unit tests + ## v2.121.0 (2022-12-08) ### Features diff --git a/VERSION b/VERSION index 28b52ee8d5..f73c7f057e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.1.dev0 +2.121.1 From b2e8b66016c09a3898123725bf1c01d1a87b05d0 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 9 Dec 2022 06:20:47 +0000 Subject: [PATCH 044/526] update development version to v2.121.2.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index f73c7f057e..d866b235cc 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.1 +2.121.2.dev0 From 355975d4d2d45088eeb13681f8d99e48a00909c9 Mon Sep 17 00:00:00 2001 From: amzn-choeric <105388439+amzn-choeric@users.noreply.github.com> Date: Fri, 9 Dec 2022 13:53:28 -0500 Subject: [PATCH 045/526] fix: Skip Bad Transform Test (#3521) --- tests/integ/test_inference_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py index 53d966fe9b..a26d8c9101 100644 --- a/tests/integ/test_inference_pipeline.py +++ b/tests/integ/test_inference_pipeline.py @@ -50,6 +50,7 @@ ) +@pytest.mark.skip(reason="Test has likely been failing for a while. 
Suspected bad XGB model.") def test_inference_pipeline_batch_transform(sagemaker_session, cpu_instance_type): sparkml_model_data = sagemaker_session.upload_data( path=os.path.join(SPARKML_DATA_PATH, "mleap_model.tar.gz"), From fadc817c7557f5fea5e414d51b500a6b7cd02065 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Fri, 9 Dec 2022 12:07:32 -0800 Subject: [PATCH 046/526] fix: Revert "fix: type hint of PySparkProcessor __init__" (#3524) From c5fc93feea798df1713db6707737a2f24738c4c7 Mon Sep 17 00:00:00 2001 From: hballuru <113142824+hballuru@users.noreply.github.com> Date: Fri, 9 Dec 2022 16:36:12 -0600 Subject: [PATCH 047/526] change: Update for Tensorflow Serving 2.11 inference DLCs (#3509) --- .../image_uri_config/tensorflow.json | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index a0f2bba014..aaca927ba4 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -303,7 +303,8 @@ "2.7": "2.7.0", "2.8": "2.8.0", "2.9": "2.9.2", - "2.10": "2.10.0" + "2.10": "2.10.0", + "2.11": "2.11.0" }, "versions": { "1.10.0": { @@ -1611,6 +1612,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1618,8 +1620,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1642,6 +1646,41 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, + "2.11.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1649,8 +1688,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", From ec8da98a9a7cae848e8bf1af06bdaaabd1ebb382 Mon Sep 17 
00:00:00 2001 From: ci Date: Mon, 12 Dec 2022 18:18:58 +0000 Subject: [PATCH 048/526] prepare release v2.121.2 --- CHANGELOG.md | 8 ++++++++ VERSION | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 472a25feb8..8b66e85f54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## v2.121.2 (2022-12-12) + +### Bug Fixes and Other Changes + + * Update for Tensorflow Serving 2.11 inference DLCs + * Revert "fix: type hint of PySparkProcessor __init__" + * Skip Bad Transform Test + ## v2.121.1 (2022-12-09) ### Bug Fixes and Other Changes diff --git a/VERSION b/VERSION index d866b235cc..3b02379cd3 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.2.dev0 +2.121.2 From 03521222d324ed752174038309828ed8183c5aea Mon Sep 17 00:00:00 2001 From: ci Date: Mon, 12 Dec 2022 18:19:00 +0000 Subject: [PATCH 049/526] update development version to v2.121.3.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 3b02379cd3..8fde5e282f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.2 +2.121.3.dev0 From d6c021404586d4df601a6115add87fcbf75b6d65 Mon Sep 17 00:00:00 2001 From: Kristopher Siman Date: Mon, 12 Dec 2022 17:21:49 -0500 Subject: [PATCH 050/526] feature: Add OSU region to frameworks for DLC (#3532) --- src/sagemaker/image_uri_config/autogluon.json | 12 ++++ .../image_uri_config/huggingface-neuron.json | 1 + .../image_uri_config/huggingface.json | 31 ++++++++ src/sagemaker/image_uri_config/mxnet.json | 13 ++++ .../image_uri_config/pytorch-neuron.json | 1 + src/sagemaker/image_uri_config/pytorch.json | 31 ++++++++ .../image_uri_config/tensorflow.json | 70 +++++++++++++++++++ 7 files changed, 159 insertions(+) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 3a9f02142c..590b6e5f82 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -30,6 +30,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -61,6 +62,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -92,6 +94,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -123,6 +126,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -154,6 +158,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -185,6 +190,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -230,6 +236,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", 
"us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -267,6 +274,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -304,6 +312,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -341,6 +350,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -378,6 +388,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -415,6 +426,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-neuron.json b/src/sagemaker/image_uri_config/huggingface-neuron.json index 47d6dbd1dc..980dceed17 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuron.json +++ b/src/sagemaker/image_uri_config/huggingface-neuron.json @@ -33,6 +33,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 5b98fc0d02..a0caa59a55 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -42,6 +42,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -75,6 +76,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -114,6 +116,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -147,6 +150,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -188,6 +192,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -222,6 +227,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -256,6 +262,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", 
"us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -290,6 +297,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -332,6 +340,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -366,6 +375,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -400,6 +410,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -434,6 +445,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -474,6 +486,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -508,6 +521,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -548,6 +562,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -582,6 +597,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -622,6 +638,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -656,6 +673,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -712,6 +730,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -749,6 +768,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -786,6 +806,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -831,6 +852,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", 
"us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -868,6 +890,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -905,6 +928,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -942,6 +966,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -985,6 +1010,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1022,6 +1048,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1065,6 +1092,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1102,6 +1130,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1145,6 +1174,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1182,6 +1212,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 8d8733e480..588a03a76e 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -249,6 +249,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -282,6 +283,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -315,6 +317,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -348,6 +351,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -381,6 +385,7 @@ "sa-east-1": "763104351884", "us-east-1": 
"763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -644,6 +649,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -680,6 +686,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -716,6 +723,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -752,6 +760,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -788,6 +797,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -897,6 +907,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -933,6 +944,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -969,6 +981,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch-neuron.json b/src/sagemaker/image_uri_config/pytorch-neuron.json index b116a8a36b..5b29406955 100644 --- a/src/sagemaker/image_uri_config/pytorch-neuron.json +++ b/src/sagemaker/image_uri_config/pytorch-neuron.json @@ -28,6 +28,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 18a382e591..85681a3423 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -208,6 +208,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -247,6 +248,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -285,6 +287,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -323,6 +326,7 @@ 
"sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -362,6 +366,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -401,6 +406,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -440,6 +446,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -479,6 +486,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -517,6 +525,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -555,6 +564,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -593,6 +603,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -631,6 +642,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -669,6 +681,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -707,6 +720,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -744,6 +758,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -791,6 +806,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -951,6 +967,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -987,6 +1004,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1023,6 +1041,7 
@@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1058,6 +1077,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1094,6 +1114,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1130,6 +1151,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1166,6 +1188,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1202,6 +1225,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1237,6 +1261,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1272,6 +1297,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1307,6 +1333,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1342,6 +1369,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1377,6 +1405,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1412,6 +1441,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1446,6 +1476,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index aaca927ba4..a900aa4fe5 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -161,6 +161,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", 
"us-west-1": "763104351884", @@ -196,6 +197,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -231,6 +233,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -266,6 +269,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -425,6 +429,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -460,6 +465,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -495,6 +501,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -530,6 +537,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -565,6 +573,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -600,6 +609,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -635,6 +645,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -862,6 +873,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -897,6 +909,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -932,6 +945,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -967,6 +981,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1002,6 +1017,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": 
"886529160074", "us-west-1": "763104351884", @@ -1037,6 +1053,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1072,6 +1089,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1107,6 +1125,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1142,6 +1161,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1177,6 +1197,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1212,6 +1233,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1247,6 +1269,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1282,6 +1305,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1317,6 +1341,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1352,6 +1377,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1387,6 +1413,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1422,6 +1449,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1457,6 +1485,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1492,6 +1521,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1527,6 +1557,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": 
"442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1562,6 +1593,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1597,6 +1629,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1631,6 +1664,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1665,6 +1699,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1699,6 +1734,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1746,6 +1782,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1940,6 +1977,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1977,6 +2015,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2013,6 +2052,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2050,6 +2090,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2087,6 +2128,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2124,6 +2166,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2161,6 +2204,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2389,6 +2433,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2425,6 +2470,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": 
"446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2461,6 +2507,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2496,6 +2543,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2531,6 +2579,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2567,6 +2616,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2603,6 +2653,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2638,6 +2689,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2673,6 +2725,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2708,6 +2761,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2743,6 +2797,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2778,6 +2833,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2813,6 +2869,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2848,6 +2905,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2883,6 +2941,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2918,6 +2977,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2953,6 +3013,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": 
"763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2988,6 +3049,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3023,6 +3085,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3058,6 +3121,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3093,6 +3157,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3128,6 +3193,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3163,6 +3229,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3198,6 +3265,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3233,6 +3301,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3267,6 +3336,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", From 5af4feb57d950358dcf5dd15aad7f7d59ae11b31 Mon Sep 17 00:00:00 2001 From: Xiaoguang Chen <68292680+xgchena@users.noreply.github.com> Date: Mon, 12 Dec 2022 15:59:33 -0800 Subject: [PATCH 051/526] fix: Remove content type image/jpg from analysis configuration schema (#3530) Currently the analysis configuration schema of SageMaker Clarify API allows the content_type configuration "image/jpeg" and "image/jpg", but the service side validation only accepts the former which is the registered MIME type for JPEG (see rfc3745 and JPEG specification). The commit removes the latter from the schema to avoid confusion and enable early API validation. 
--- src/sagemaker/clarify.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 4765630ce8..f082679401 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -282,7 +282,6 @@ "text/csv", "application/jsonlines", "image/jpeg", - "image/jpg", "image/png", "application/x-npy", ), From 438984754a8f44b34d70154197a3bbeb0272f052 Mon Sep 17 00:00:00 2001 From: Clayton Parnell <42805768+claytonparnell@users.noreply.github.com> Date: Mon, 12 Dec 2022 22:37:35 -0500 Subject: [PATCH 052/526] fix: unpin packaging version (#3533) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f366b147b8..4327045760 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ def read_requirements(filename): "protobuf3-to-dict>=0.1.5,<1.0", "smdebug_rulesconfig==1.0.1", "importlib-metadata>=1.4.0,<5.0", - "packaging==20.9", + "packaging>=20.0", "pandas", "pathos", "schema", From a3efddf6d6a4e89861f2ae1eca9d7fd7712a691b Mon Sep 17 00:00:00 2001 From: Anton Repushko Date: Tue, 13 Dec 2022 20:45:06 +0100 Subject: [PATCH 053/526] fix: the Hyperband support fix for the HPO (#3516) Co-authored-by: Anton Repushko --- src/sagemaker/session.py | 9 +++++++ src/sagemaker/tuner.py | 14 +++++------ tests/unit/test_session.py | 48 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 00797c9ea0..3fc4fc1256 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2121,6 +2121,7 @@ def tune( # noqa: C901 stop_condition, tags, warm_start_config, + strategy_config=None, enable_network_isolation=False, image_uri=None, algorithm_arn=None, @@ -2136,6 +2137,8 @@ def tune( # noqa: C901 Args: job_name (str): Name of the tuning job being created. strategy (str): Strategy to be used for hyperparameter estimations. + strategy_config (dict): A configuration for the hyperparameter tuning + job optimisation strategy. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize'. objective_metric_name (str): Name of the metric for evaluating training jobs. @@ -2220,6 +2223,7 @@ def tune( # noqa: C901 objective_metric_name=objective_metric_name, parameter_ranges=parameter_ranges, early_stopping_type=early_stopping_type, + strategy_config=strategy_config, ), "TrainingJobDefinition": self._map_training_config( static_hyperparameters=static_hyperparameters, @@ -2375,6 +2379,7 @@ def _map_tuning_config( objective_type=None, objective_metric_name=None, parameter_ranges=None, + strategy_config=None, ): """Construct tuning job configuration dictionary. @@ -2392,6 +2397,8 @@ def _map_tuning_config( objective_metric_name (str): Name of the metric for evaluating training jobs. parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be one of three types: Continuous, Integer, or Categorical. + strategy_config (dict): A configuration for the hyperparameter tuning job optimisation + strategy. Returns: A dictionary of tuning job configuration. 
For format details, please refer to @@ -2415,6 +2422,8 @@ def _map_tuning_config( if parameter_ranges is not None: tuning_config["ParameterRanges"] = parameter_ranges + if strategy_config is not None: + tuning_config["StrategyConfig"] = strategy_config return tuning_config @classmethod diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 52b9d81d0d..9a694cbec9 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -282,8 +282,8 @@ def from_job_desc(cls, hyperband_strategy_config): Returns: sagemaker.tuner.HyperbandStrategyConfig: De-serialized instance of - HyperbandStrategyConfig containing the max_resource and min_resource provided as part of - ``hyperband_strategy_config``. + ``HyperbandStrategyConfig`` containing the max_resource + and min_resource provided as part of ``hyperband_strategy_config``. """ return cls( min_resource=hyperband_strategy_config[HYPERBAND_MIN_RESOURCE], @@ -306,7 +306,7 @@ def to_input_req(self): Returns: dict: Containing the "MaxResource" and - "MinResource" as the first class fields. + "MinResource" as the first class fields. """ return { HYPERBAND_MIN_RESOURCE: self.min_resource, @@ -330,7 +330,7 @@ def __init__( Args: hyperband_strategy_config (sagemaker.tuner.HyperbandStrategyConfig): The configuration - for the object that specifies the Hyperband strategy. + for the object that specifies the Hyperband strategy. This parameter is only supported for the Hyperband selection for Strategy within the HyperParameterTuningJobConfig. """ @@ -461,7 +461,7 @@ def __init__( ``WarmStartConfig`` object that has been initialized with the configuration defining the nature of warm start tuning job. strategy_config (sagemaker.tuner.StrategyConfig): A configuration for "Hyperparameter" - tuning job optimisation strategy. + tuning job optimisation strategy. early_stopping_type (str or PipelineVariable): Specifies whether early stopping is enabled for the job. Can be either 'Auto' or 'Off' (default: 'Off'). If set to 'Off', early stopping will not be attempted. @@ -1569,7 +1569,7 @@ def create( strategy (str): Strategy to be used for hyperparameter estimations (default: 'Bayesian'). strategy_config (dict): The configuration for a training job launched by a - hyperparameter tuning job. + hyperparameter tuning job. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize' (default: 'Maximize'). 
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter @@ -1776,7 +1776,7 @@ def _get_tuner_args(cls, tuner, inputs): } if tuner.strategy_config is not None: - tuning_config["strategy_config"] = tuner.strategy_config + tuning_config["strategy_config"] = tuner.strategy_config.to_input_req() if tuner.objective_metric_name is not None: tuning_config["objective_type"] = tuner.objective_type diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8958210092..bf81283177 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -941,6 +941,13 @@ def test_train_pack_to_request(sagemaker_session): ], } +SAMPLE_HYPERBAND_STRATEGY_CONFIG = { + "HyperbandStrategyConfig": { + "MinResource": 1, + "MaxResource": 10, + } +} + @pytest.mark.parametrize( "warm_start_type, parents", @@ -1167,6 +1174,47 @@ def assert_create_tuning_job_request(**kwrags): ) +def test_tune_with_strategy_config(sagemaker_session): + def assert_create_tuning_job_request(**kwrags): + assert ( + kwrags["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][ + "MinResource" + ] + == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MinResource"] + ) + assert ( + kwrags["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][ + "MaxResource" + ] + == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MaxResource"] + ) + + sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = ( + assert_create_tuning_job_request + ) + sagemaker_session.tune( + job_name="dummy-tuning-1", + strategy="Bayesian", + objective_type="Maximize", + objective_metric_name="val-score", + max_jobs=100, + max_parallel_jobs=5, + parameter_ranges=SAMPLE_PARAM_RANGES, + static_hyperparameters=STATIC_HPs, + image_uri="dummy-image-1", + input_mode="File", + metric_definitions=SAMPLE_METRIC_DEF, + role=EXPANDED_ROLE, + input_config=SAMPLE_INPUT, + output_config=SAMPLE_OUTPUT, + resource_config=RESOURCE_CONFIG, + stop_condition=SAMPLE_STOPPING_CONDITION, + tags=None, + warm_start_config=None, + strategy_config=SAMPLE_HYPERBAND_STRATEGY_CONFIG, + ) + + def test_tune_with_encryption_flag(sagemaker_session): def assert_create_tuning_job_request(**kwrags): assert ( From bd96ec5c585217bdec31951d632247f4b0d9f91b Mon Sep 17 00:00:00 2001 From: Md Mizanur Rahman <105268921+mizanfiu@users.noreply.github.com> Date: Tue, 13 Dec 2022 16:06:08 -0800 Subject: [PATCH 054/526] feature: Feature Store dataset builder, delete_record, get_record, list_feature_group (#3534) Co-authored-by: Eric Zou Co-authored-by: Yiming Zou Co-authored-by: Brandon Chatham Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com> --- .../feature_store/dataset_builder.py | 990 ++++++++++++++++++ src/sagemaker/feature_store/feature_group.py | 45 +- src/sagemaker/feature_store/feature_store.py | 130 +++ src/sagemaker/session.py | 94 +- tests/integ/test_feature_store.py | 400 +++++++ .../feature_store/test_dataset_builder.py | 612 +++++++++++ .../feature_store/test_feature_group.py | 580 ++++++++++ .../feature_store/test_feature_store.py | 687 ++---------- tests/unit/test_session.py | 29 + 9 files changed, 2979 insertions(+), 588 deletions(-) create mode 100644 src/sagemaker/feature_store/dataset_builder.py create mode 100644 src/sagemaker/feature_store/feature_store.py create mode 100644 tests/unit/sagemaker/feature_store/test_dataset_builder.py create mode 100644 tests/unit/sagemaker/feature_store/test_feature_group.py diff --git 
a/src/sagemaker/feature_store/dataset_builder.py b/src/sagemaker/feature_store/dataset_builder.py new file mode 100644 index 0000000000..fc82997379 --- /dev/null +++ b/src/sagemaker/feature_store/dataset_builder.py @@ -0,0 +1,990 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Dataset Builder + +A Dataset Builder is a builder class for generating a dataset by providing conditions. +""" +from __future__ import absolute_import + +import datetime +from enum import Enum +import os +from typing import Any, Dict, List, Tuple, Union + +import attr +import pandas as pd + +from sagemaker import Session, s3, utils +from sagemaker.feature_store.feature_group import FeatureDefinition, FeatureGroup, FeatureTypeEnum + + +_DEFAULT_CATALOG = "AwsDataCatalog" +_DEFAULT_DATABASE = "sagemaker_featurestore" + + +@attr.s +class TableType(Enum): + """Enum of Table types. + + The data type of a table can be FeatureGroup or DataFrame. + """ + + FEATURE_GROUP = "FeatureGroup" + DATA_FRAME = "DataFrame" + + +@attr.s +class FeatureGroupToBeMerged: + """FeatureGroup metadata which will be used for SQL join. + + This class instantiates a FeatureGroupToBeMerged object that comprises a list of feature names, + a list of feature names which will be included in the SQL query, a database, an Athena table name, + a feature name of record identifier, a feature name of event time identifier and a feature name + of base which is the target join key. + + Attributes: + features (List[str]): A list of strings representing feature names of this FeatureGroup. + included_feature_names (List[str]): A list of strings representing features to be + included in the SQL join. + projected_feature_names (List[str]): A list of strings representing features to be + included for final projection in output. + catalog (str): A string representing the catalog. + database (str): A string representing the database. + table_name (str): A string representing the Athena table name of this FeatureGroup. + record_identifier_feature_name (str): A string representing the record identifier feature. + event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the + event time identifier feature. + target_feature_name_in_base (str): A string representing the feature name in base which will + be used as target join key (default: None). + table_type (TableType): A TableType representing the type of table if it is Feature Group or + pandas DataFrame (default: None).
+ """ + + features: List[str] = attr.ib() + included_feature_names: List[str] = attr.ib() + projected_feature_names: List[str] = attr.ib() + catalog: str = attr.ib() + database: str = attr.ib() + table_name: str = attr.ib() + record_identifier_feature_name: str = attr.ib() + event_time_identifier_feature: FeatureDefinition = attr.ib() + target_feature_name_in_base: str = attr.ib(default=None) + table_type: TableType = attr.ib(default=None) + + +def construct_feature_group_to_be_merged( + feature_group: FeatureGroup, + included_feature_names: List[str], + target_feature_name_in_base: str = None, +) -> FeatureGroupToBeMerged: + """Construct a FeatureGroupToBeMerged object by provided parameters. + + Args: + feature_group (FeatureGroup): A FeatureGroup object. + included_feature_names (List[str]): A list of strings representing features to be + included in the output. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as target join key (default: None). + Returns: + A FeatureGroupToBeMerged object. + + Raises: + ValueError: Invalid feature name(s) in included_feature_names. + """ + feature_group_metadata = feature_group.describe() + data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", {}).get( + "DataCatalogConfig", None + ) + if not data_catalog_config: + raise RuntimeError(f"No metastore is configured with FeatureGroup {feature_group.name}.") + + record_identifier_feature_name = feature_group_metadata.get("RecordIdentifierFeatureName", None) + feature_definitions = feature_group_metadata.get("FeatureDefinitions", []) + event_time_identifier_feature_name = feature_group_metadata.get("EventTimeFeatureName", None) + event_time_identifier_feature_type = FeatureTypeEnum( + next( + filter( + lambda f: f.get("FeatureName", None) == event_time_identifier_feature_name, + feature_definitions, + ), + {}, + ).get("FeatureType", None) + ) + table_name = data_catalog_config.get("TableName", None) + database = data_catalog_config.get("Database", None) + disable_glue = feature_group_metadata.get("DisableGlueTableCreation", False) + catalog = data_catalog_config.get("Catalog", None) if disable_glue else _DEFAULT_CATALOG + features = [feature.get("FeatureName", None) for feature in feature_definitions] + + for included_feature in included_feature_names or []: + if included_feature not in features: + raise ValueError( + f"Feature {included_feature} not found in FeatureGroup {feature_group.name}" + ) + if not included_feature_names: + included_feature_names = features + projected_feature_names = features.copy() + else: + projected_feature_names = included_feature_names.copy() + if record_identifier_feature_name not in included_feature_names: + included_feature_names.append(record_identifier_feature_name) + if event_time_identifier_feature_name not in included_feature_names: + included_feature_names.append(event_time_identifier_feature_name) + return FeatureGroupToBeMerged( + features, + included_feature_names, + projected_feature_names, + catalog, + database, + table_name, + record_identifier_feature_name, + FeatureDefinition(event_time_identifier_feature_name, event_time_identifier_feature_type), + target_feature_name_in_base, + TableType.FEATURE_GROUP, + ) + + +@attr.s +class DatasetBuilder: + """DatasetBuilder definition. + + This class instantiates a DatasetBuilder object that comprises a base, a list of feature names, + an output path and a KMS key ID. 
+
+    Attributes:
+        _sagemaker_session (Session): Session instance to perform boto calls.
+        _base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
+            pandas.DataFrame and will be used to merge other FeatureGroups and generate a Dataset.
+        _output_path (str): An S3 URI which stores the output .csv file.
+        _record_identifier_feature_name (str): A string representing the record identifier feature
+            if base is a DataFrame (default: None).
+        _event_time_identifier_feature_name (str): A string representing the event time identifier
+            feature if base is a DataFrame (default: None).
+        _included_feature_names (List[str]): A list of strings representing features to be
+            included in the output (default: None).
+        _kms_key_id (str): A KMS key ID. If set, it will be used to encrypt the result file
+            (default: None).
+        _point_in_time_accurate_join (bool): Whether a point-in-time accurate join is performed
+            (default: False).
+        _include_duplicated_records (bool): Whether duplicated records are included in the output
+            (default: False).
+        _include_deleted_records (bool): Whether deleted records are included in the output
+            (default: False).
+        _number_of_recent_records (int): The number of most recent records returned for each
+            record identifier (default: 1).
+        _number_of_records (int): The maximum number of records returned from the query results
+            (default: None).
+        _write_time_ending_timestamp (datetime.datetime): A datetime; only records whose write
+            time is at or before it are included in the dataset (default: None).
+        _event_time_starting_timestamp (datetime.datetime): A datetime; only records whose event
+            time is at or after it are included in the dataset (default: None).
+        _event_time_ending_timestamp (datetime.datetime): A datetime; only records whose event
+            time is at or before it are included in the dataset (default: None).
+        _feature_groups_to_be_merged (List[FeatureGroupToBeMerged]): A list of
+            FeatureGroupToBeMerged which will be joined to base (default: []).
+        _event_time_identifier_feature_type (FeatureTypeEnum): A FeatureTypeEnum representing the
+            type of the event time identifier feature (default: None).
+ """ + + _sagemaker_session: Session = attr.ib() + _base: Union[FeatureGroup, pd.DataFrame] = attr.ib() + _output_path: str = attr.ib() + _record_identifier_feature_name: str = attr.ib(default=None) + _event_time_identifier_feature_name: str = attr.ib(default=None) + _included_feature_names: List[str] = attr.ib(default=None) + _kms_key_id: str = attr.ib(default=None) + + _point_in_time_accurate_join: bool = attr.ib(init=False, default=False) + _include_duplicated_records: bool = attr.ib(init=False, default=False) + _include_deleted_records: bool = attr.ib(init=False, default=False) + _number_of_recent_records: int = attr.ib(init=False, default=None) + _number_of_records: int = attr.ib(init=False, default=None) + _write_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None) + _event_time_starting_timestamp: datetime.datetime = attr.ib(init=False, default=None) + _event_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None) + _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = attr.ib(init=False, factory=list) + _event_time_identifier_feature_type: FeatureTypeEnum = attr.ib(default=None) + + _DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP = { + "object": "STRING", + "int64": "INT", + "float64": "DOUBLE", + "bool": "BOOLEAN", + "datetime64[ns]": "TIMESTAMP", + } + + def with_feature_group( + self, + feature_group: FeatureGroup, + target_feature_name_in_base: str = None, + included_feature_names: List[str] = None, + ): + """Join FeatureGroup with base. + + Args: + feature_group (FeatureGroup): A FeatureGroup which will be joined to base. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as target join key (default: None). + included_feature_names (List[str]): A list of strings representing features to be + included in the output (default: None). + Returns: + This DatasetBuilder object. + """ + self._feature_groups_to_be_merged.append( + construct_feature_group_to_be_merged( + feature_group, included_feature_names, target_feature_name_in_base + ) + ) + return self + + def point_in_time_accurate_join(self): + """Set join type as point in time accurate join. + + Returns: + This DatasetBuilder object. + """ + self._point_in_time_accurate_join = True + return self + + def include_duplicated_records(self): + """Include duplicated records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_duplicated_records = True + return self + + def include_deleted_records(self): + """Include deleted records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_deleted_records = True + return self + + def with_number_of_recent_records_by_record_identifier(self, number_of_recent_records: int): + """Set number_of_recent_records field with provided input. + + Args: + number_of_recent_records (int): An int that how many recent records will be returned for + each record identifier. + Returns: + This DatasetBuilder object. + """ + self._number_of_recent_records = number_of_recent_records + return self + + def with_number_of_records_from_query_results(self, number_of_records: int): + """Set number_of_records field with provided input. + + Args: + number_of_records (int): An int that how many records will be returned. + Returns: + This DatasetBuilder object. + """ + self._number_of_records = number_of_records + return self + + def as_of(self, timestamp: datetime.datetime): + """Set write_time_ending_timestamp field with provided input. 
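The setters above each mutate the builder and return `self`, so they can be chained. A rough usage sketch, going through the `FeatureStore.create_dataset()` entry point added later in this patch; the group names, join key, and S3 bucket are hypothetical:

```python
import datetime

from sagemaker import Session
from sagemaker.feature_store.feature_group import FeatureGroup
from sagemaker.feature_store.feature_store import FeatureStore

session = Session()
orders_fg = FeatureGroup(name="orders", sagemaker_session=session)        # hypothetical
customers_fg = FeatureGroup(name="customers", sagemaker_session=session)  # hypothetical

# Chain the fluent setters to describe the dataset; nothing is executed until
# to_csv_file() or to_dataframe() runs the generated Athena query.
builder = (
    FeatureStore(sagemaker_session=session)
    .create_dataset(base=orders_fg, output_path="s3://my-bucket/dataset-output")  # hypothetical
    .with_feature_group(customers_fg, target_feature_name_in_base="customer_id")
    .point_in_time_accurate_join()
    .with_number_of_recent_records_by_record_identifier(2)
    .as_of(datetime.datetime(2022, 6, 1))
)
```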
+ + Args: + timestamp (datetime.datetime): A datetime that all records' write time in dataset will + be before it. + Returns: + This DatasetBuilder object. + """ + self._write_time_ending_timestamp = timestamp + return self + + def with_event_time_range( + self, + starting_timestamp: datetime.datetime = None, + ending_timestamp: datetime.datetime = None, + ): + """Set event_time_starting_timestamp and event_time_ending_timestamp with provided inputs. + + Args: + starting_timestamp (datetime.datetime): A datetime that all records' event time in + dataset will be after it (default: None). + ending_timestamp (datetime.datetime): A datetime that all records' event time in dataset + will be before it (default: None). + Returns: + This DatasetBuilder object. + """ + self._event_time_starting_timestamp = starting_timestamp + self._event_time_ending_timestamp = ending_timestamp + return self + + def to_csv_file(self) -> Tuple[str, str]: + """Get query string and result in .csv format file + + Returns: + The S3 path of the .csv file. + The query string executed. + """ + if isinstance(self._base, pd.DataFrame): + temp_id = utils.unique_name_from_base("dataframe-base") + local_file_name = f"{temp_id}.csv" + desired_s3_folder = f"{self._output_path}/{temp_id}" + self._base.to_csv(local_file_name, index=False, header=False) + s3.S3Uploader.upload( + local_path=local_file_name, + desired_s3_uri=desired_s3_folder, + sagemaker_session=self._sagemaker_session, + kms_key=self._kms_key_id, + ) + os.remove(local_file_name) + temp_table_name = f'dataframe_{temp_id.replace("-", "_")}' + self._create_temp_table(temp_table_name, desired_s3_folder) + base_features = list(self._base.columns) + event_time_identifier_feature_dtype = self._base[ + self._event_time_identifier_feature_name + ].dtypes + self._event_time_identifier_feature_type = ( + FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get( + str(event_time_identifier_feature_dtype), None + ) + ) + query_string = self._construct_query_string( + FeatureGroupToBeMerged( + base_features, + self._included_feature_names if self._included_feature_names else base_features, + self._included_feature_names if self._included_feature_names else base_features, + _DEFAULT_CATALOG, + _DEFAULT_DATABASE, + temp_table_name, + self._record_identifier_feature_name, + FeatureDefinition( + self._event_time_identifier_feature_name, + self._event_time_identifier_feature_type, + ), + None, + TableType.DATA_FRAME, + ) + ) + query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + # TODO: cleanup temp table, need more clarification, keep it for now + return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get( + "OutputLocation", None + ), query_result.get("QueryExecution", {}).get("Query", None) + if isinstance(self._base, FeatureGroup): + base_feature_group = construct_feature_group_to_be_merged( + self._base, self._included_feature_names + ) + self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name + self._event_time_identifier_feature_name = ( + base_feature_group.event_time_identifier_feature.feature_name + ) + self._event_time_identifier_feature_type = ( + base_feature_group.event_time_identifier_feature.feature_type + ) + query_string = self._construct_query_string(base_feature_group) + query_result = self._run_query( + query_string, + base_feature_group.catalog, + base_feature_group.database, + ) + return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get( + "OutputLocation", 
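When the base is a pandas.DataFrame rather than a FeatureGroup, `to_csv_file()` first stages the frame in S3 and registers a temporary Athena table before querying it, as the code above shows. A hedged sketch of that path; the bucket and column names are made up:

```python
import pandas as pd

from sagemaker import Session
from sagemaker.feature_store.feature_store import FeatureStore

session = Session()
base_df = pd.DataFrame(
    {"customer_id": [1, 2, 3], "event_time": [1654000000.0, 1654000001.0, 1654000002.0]}
)

# A DataFrame base must name its record identifier and event-time columns,
# because there is no FeatureGroup metadata to read them from.
csv_s3_uri, query = (
    FeatureStore(sagemaker_session=session)
    .create_dataset(
        base=base_df,
        output_path="s3://my-bucket/dataset-output",  # hypothetical bucket
        record_identifier_feature_name="customer_id",
        event_time_identifier_feature_name="event_time",
    )
    .to_csv_file()
)
print(csv_s3_uri)
```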
None + ), query_result.get("QueryExecution", {}).get("Query", None) + raise ValueError("Base must be either a FeatureGroup or a DataFrame.") + + def to_dataframe(self) -> Tuple[pd.DataFrame, str]: + """Get query string and result in pandas.Dataframe + + Returns: + The pandas.DataFrame object. + The query string executed. + """ + csv_file, query_string = self.to_csv_file() + s3.S3Downloader.download( + s3_uri=csv_file, + local_path="./", + kms_key=self._kms_key_id, + sagemaker_session=self._sagemaker_session, + ) + local_file_name = csv_file.split("/")[-1] + df = pd.read_csv(local_file_name) + os.remove(local_file_name) + + local_metadata_file_name = local_file_name + ".metadata" + if os.path.exists(local_metadata_file_name): + os.remove(local_file_name + ".metadata") + + if "row_recent" in df: + df = df.drop("row_recent", axis="columns") + return df, query_string + + def _construct_event_time_conditions( + self, + table_name: str, + event_time_identifier_feature: FeatureDefinition, + ) -> List[str]: + """Internal method for constructing event time range sql range as string. + + Args: + table_name (str): name of the table. + event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the + event time identifier feature. + Returns: + The list of query strings. + """ + event_time_conditions = [] + timestamp_cast_function_name = "from_unixtime" + if event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING: + timestamp_cast_function_name = "from_iso8601_timestamp" + if self._event_time_starting_timestamp: + event_time_conditions.append( + f"{timestamp_cast_function_name}({table_name}." + + f'"{event_time_identifier_feature.feature_name}") >= ' + + f"from_unixtime({self._event_time_starting_timestamp.timestamp()})" + ) + if self._event_time_ending_timestamp: + event_time_conditions.append( + f"{timestamp_cast_function_name}({table_name}." + + f'"{event_time_identifier_feature.feature_name}") <= ' + + f"from_unixtime({self._event_time_ending_timestamp.timestamp()})" + ) + return event_time_conditions + + def _construct_write_time_condition( + self, + table_name: str, + ) -> str: + """Internal method for constructing write time condition. + + Args: + table_name (str): name of the table. + Returns: + string of write time condition. + """ + write_time_condition = ( + f'{table_name}."write_time" <= ' + f"to_timestamp('{self._write_time_ending_timestamp.replace(microsecond=0)}', " + f"'yyyy-mm-dd hh24:mi:ss')" + ) + return write_time_condition + + def _construct_where_query_string( + self, + suffix: str, + event_time_identifier_feature: FeatureDefinition, + where_conditions: List[str], + ) -> str: + """Internal method for constructing SQL WHERE query string by parameters. + + Args: + suffix (str): A temp identifier of the FeatureGroup. + event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the + event time identifier feature. + where_conditions (List[str]): A list of strings representing existing where clauses. + Returns: + The WHERE query string. + + Raises: + ValueError: FeatureGroup not provided while using as_of(). Only found pandas.DataFrame. + """ + if self._number_of_recent_records: + if self._number_of_recent_records < 0: + raise ValueError( + "Please provide non-negative integer for number_of_recent_records." 
+ ) + if self._number_of_records: + if self._number_of_records < 0: + raise ValueError("Please provide non-negative integer for number_of_records.") + if self._include_deleted_records: + if isinstance(self._base, pd.DataFrame): + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "include_deleted_records() only works for FeatureGroup," + " if there is no join operation." + ) + if self._include_duplicated_records: + if isinstance(self._base, pd.DataFrame): + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "include_duplicated_records() only works for FeatureGroup," + " if there is no join operation." + ) + if self._point_in_time_accurate_join: + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "point_in_time_accurate_join() this operation only works when there is " + "more than one feature group to join." + ) + if self._write_time_ending_timestamp: + if isinstance(self._base, pd.DataFrame): + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "as_of() only works for FeatureGroup," " if there is no join operation." + ) + if isinstance(self._base, FeatureGroup): + if self._write_time_ending_timestamp: + where_conditions.append(self._construct_write_time_condition(f"table_{suffix}")) + + event_time_conditions = self._construct_event_time_conditions( + f"table_{suffix}", event_time_identifier_feature + ) + where_conditions.extend(event_time_conditions) + + if len(where_conditions) == 0: + return "" + return "WHERE " + "\nAND ".join(where_conditions) + + def _construct_dedup_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing removing duplicate records SQL query string. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The SQL query string. 
+ """ + record_feature_name = feature_group.record_identifier_feature_name + event_time_identifier_feature = feature_group.event_time_identifier_feature + event_time_feature_name = feature_group.event_time_identifier_feature.feature_name + rank_query_string = "" + where_conditions = [] + where_conditions_str = "" + is_dedup_enabled = False + + if feature_group.table_type is TableType.FEATURE_GROUP: + is_dedup_enabled = True + rank_query_string = ( + f'ORDER BY origin_{suffix}."api_invocation_time" DESC, ' + + f'origin_{suffix}."write_time" DESC\n' + ) + + if self._write_time_ending_timestamp: + where_conditions.append(self._construct_write_time_condition(f"origin_{suffix}")) + + event_time_conditions = self._construct_event_time_conditions( + f"origin_{suffix}", event_time_identifier_feature + ) + where_conditions.extend(event_time_conditions) + + if len(where_conditions) != 0: + where_conditions_str = "WHERE " + "\nAND ".join(where_conditions) + "\n" + + dedup_where_clause = f"WHERE dedup_row_{suffix} = 1\n" if is_dedup_enabled else "" + return ( + f"table_{suffix} AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + f'PARTITION BY origin_{suffix}."{record_feature_name}", ' + + f'origin_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + f") AS dedup_row_{suffix}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n' + + where_conditions_str + + ")\n" + + dedup_where_clause + + ")" + ) + + def _construct_deleted_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing removing deleted records SQL query string. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The SQL query string. 
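These internal helpers emit the WITH clauses that deduplicate writes and honor deletions. A sketch that prints both CTEs for an imaginary table, mirroring the fixtures used by the unit tests later in this patch; every name (database, table, features, bucket) is made up:

```python
from mock import Mock

from sagemaker.feature_store.dataset_builder import (
    DatasetBuilder,
    FeatureGroupToBeMerged,
    TableType,
)
from sagemaker.feature_store.feature_group import FeatureDefinition, FeatureTypeEnum

fg_meta = FeatureGroupToBeMerged(
    ["customer_id", "event_time", "loyalty_score"],
    ["customer_id", "event_time", "loyalty_score"],
    ["customer_id", "event_time", "loyalty_score"],
    "AwsDataCatalog",
    "sagemaker_featurestore",
    "customers-table",
    "customer_id",
    FeatureDefinition("event_time", FeatureTypeEnum.STRING),
    None,
    TableType.FEATURE_GROUP,
)
builder = DatasetBuilder(sagemaker_session=Mock(), base=Mock(), output_path="s3://my-bucket/out")

# Keeps only the newest write per (record identifier, event time) pair.
print(builder._construct_dedup_query(fg_meta, "0"))
# Finds the latest tombstone per record identifier so deletions can be honored.
print(builder._construct_deleted_query(fg_meta, "0"))
```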
+ """ + record_feature_name = feature_group.record_identifier_feature_name + event_time_identifier_feature = feature_group.event_time_identifier_feature + event_time_feature_name = feature_group.event_time_identifier_feature.feature_name + rank_query_string = f'ORDER BY origin_{suffix}."{event_time_feature_name}" DESC' + write_time_condition = "\n" + event_time_starting_condition = "" + event_time_ending_condition = "" + + if feature_group.table_type is TableType.FEATURE_GROUP: + rank_query_string += ( + f', origin_{suffix}."api_invocation_time" DESC, ' + + f'origin_{suffix}."write_time" DESC\n' + ) + + if self._write_time_ending_timestamp: + write_time_condition += " AND " + write_time_condition += self._construct_write_time_condition(f"origin_{suffix}") + write_time_condition += "\n" + + if self._event_time_starting_timestamp and self._event_time_ending_timestamp: + event_time_conditions = self._construct_event_time_conditions( + f"origin_{suffix}", event_time_identifier_feature + ) + event_time_starting_condition = "AND " + event_time_conditions[0] + "\n" + event_time_ending_condition = "AND " + event_time_conditions[1] + "\n" + + return ( + f"deleted_{suffix} AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + f'PARTITION BY origin_{suffix}."{record_feature_name}"\n' + + rank_query_string + + f") AS deleted_row_{suffix}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n' + + "WHERE is_deleted" + + write_time_condition + + event_time_starting_condition + + event_time_ending_condition + + ")\n" + + f"WHERE deleted_row_{suffix} = 1\n" + + ")" + ) + + def _construct_table_included_features( + self, feature_group: FeatureGroupToBeMerged, suffix: str + ) -> str: + """Internal method for constructing included features string of table. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object + which has the metadata. + suffix (str): A temp identifier of the table. + Returns: + The string that includes all feature to be included of table. + """ + + included_features = ", ".join( + [ + f'table_{suffix}."{include_feature_name}"' + for include_feature_name in feature_group.included_feature_names + ] + ) + return included_features + + def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing SQL query string by parameters. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The query string. 
+ """ + included_features = self._construct_table_included_features(feature_group, suffix) + + # If base is a FeatureGroup then included_features_write_time will have a write_time column + # Or included_features_write_time is same as included_features + included_features_write_time = included_features + + if feature_group.table_type is TableType.FEATURE_GROUP: + included_features_write_time += f', table_{suffix}."write_time"' + record_feature_name = feature_group.record_identifier_feature_name + event_time_feature_name = feature_group.event_time_identifier_feature.feature_name + if self._include_duplicated_records and self._include_deleted_records: + return ( + f"SELECT {included_features}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" table_{suffix}\n' + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, ["NOT is_deleted"] + ) + ) + if feature_group.table_type is TableType.FEATURE_GROUP and self._include_deleted_records: + rank_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + rank_query_string = ( + f'ORDER BY origin_{suffix}."api_invocation_time" DESC, ' + + f'origin_{suffix}."write_time" DESC\n' + ) + return ( + f"SELECT {included_features}\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + f'PARTITION BY origin_{suffix}."{record_feature_name}", ' + + f'origin_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + f") AS row_{suffix}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n' + + "WHERE NOT is_deleted" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, + feature_group.event_time_identifier_feature, + [f"row_{suffix} = 1"], + ) + ) + rank_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + rank_query_string = ( + f'OR (table_{suffix}."{event_time_feature_name}" = ' + + f'deleted_{suffix}."{event_time_feature_name}" ' + + f'AND table_{suffix}."api_invocation_time" > ' + + f'deleted_{suffix}."api_invocation_time")\n' + + f'OR (table_{suffix}."{event_time_feature_name}" = ' + + f'deleted_{suffix}."{event_time_feature_name}" ' + + f'AND table_{suffix}."api_invocation_time" = ' + + f'deleted_{suffix}."api_invocation_time" ' + + f'AND table_{suffix}."write_time" > deleted_{suffix}."write_time")\n' + ) + + final_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + if self._include_duplicated_records: + final_query_string = ( + f"WITH {self._construct_deleted_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}"' + + f" table_{suffix}\n" + + f"LEFT JOIN deleted_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + f'WHERE deleted_{suffix}."{record_feature_name}" IS NULL\n' + + "UNION ALL\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM deleted_{suffix}\n" + + f'JOIN "{feature_group.database}"."{feature_group.table_name}"' + + f" table_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + "AND (\n" + + f'table_{suffix}."{event_time_feature_name}" > ' + + f'deleted_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + ")\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + else: + 
final_query_string = ( + f"WITH {self._construct_dedup_query(feature_group, suffix)},\n" + + f"{self._construct_deleted_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM table_{suffix}\n" + + f"LEFT JOIN deleted_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + f'WHERE deleted_{suffix}."{record_feature_name}" IS NULL\n' + + "UNION ALL\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM deleted_{suffix}\n" + + f"JOIN table_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + "AND (\n" + + f'table_{suffix}."{event_time_feature_name}" > ' + + f'deleted_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + ")\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + else: + final_query_string = ( + f"WITH {self._construct_dedup_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM table_{suffix}\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + return final_query_string + + def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str: + """Internal method for constructing SQL query string by parameters. + + Args: + base (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the metadata. + Returns: + The query string. + + Raises: + ValueError: target_feature_name_in_base is an invalid feature name. + """ + base_table_query_string = self._construct_table_query(base, "base") + query_string = f"WITH fg_base AS ({base_table_query_string})" + if len(self._feature_groups_to_be_merged) > 0: + with_subquery_string = "".join( + [ + f",\nfg_{i} AS ({self._construct_table_query(feature_group, str(i))})" + for i, feature_group in enumerate(self._feature_groups_to_be_merged) + ] + ) + query_string += with_subquery_string + + selected_features = "" + selected_features += ", ".join(map("fg_base.{0}".format, base.projected_feature_names)) + if len(self._feature_groups_to_be_merged) > 0: + for i, feature_group in enumerate(self._feature_groups_to_be_merged): + selected_features += ", " + selected_features += ", ".join( + [ + f'fg_{i}."{feature_name}" as "{feature_name}.{(i+1)}"' + for feature_name in feature_group.projected_feature_names + ] + ) + + selected_features_final = "" + selected_features_final += ", ".join(base.projected_feature_names) + if len(self._feature_groups_to_be_merged) > 0: + for i, feature_group in enumerate(self._feature_groups_to_be_merged): + selected_features_final += ", " + selected_features_final += ", ".join( + [ + '"{0}.{1}"'.format(feature_name, (i + 1)) + for feature_name in feature_group.projected_feature_names + ] + ) + + query_string += ( + f"\nSELECT {selected_features_final}\n" + + "FROM (\n" + + f"SELECT {selected_features}, row_number() OVER (\n" + + f'PARTITION BY fg_base."{base.record_identifier_feature_name}"\n' + + f'ORDER BY fg_base."{base.event_time_identifier_feature.feature_name}" DESC' + ) + + recent_record_where_clause = "" + if self._number_of_recent_records is not None and self._number_of_recent_records >= 0: + recent_record_where_clause = f"WHERE row_recent <= {self._number_of_recent_records}" + + join_subquery_strings = [] + if 
len(self._feature_groups_to_be_merged) > 0: + for i, feature_group in enumerate(self._feature_groups_to_be_merged): + if not feature_group.target_feature_name_in_base: + feature_group.target_feature_name_in_base = self._record_identifier_feature_name + else: + if feature_group.target_feature_name_in_base not in base.features: + raise ValueError( + f"Feature {feature_group.target_feature_name_in_base} not found in base" + ) + query_string += ( + f', fg_{i}."{feature_group.event_time_identifier_feature.feature_name}" DESC' + ) + join_subquery_strings.append(self._construct_join_condition(feature_group, str(i))) + + query_string += ( + "\n) AS row_recent\n" + + "FROM fg_base" + + "".join(join_subquery_strings) + + "\n)\n" + + f"{recent_record_where_clause}" + ) + + if self._number_of_records is not None and self._number_of_records >= 0: + query_string += f"\nLIMIT {self._number_of_records}" + return query_string + + def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing SQL JOIN query string by parameters. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The JOIN query string. + """ + join_condition_string = ( + f"\nJOIN fg_{suffix}\n" + + f'ON fg_base."{feature_group.target_feature_name_in_base}" = ' + + f'fg_{suffix}."{feature_group.record_identifier_feature_name}"' + ) + base_timestamp_cast_function_name = "from_unixtime" + if self._event_time_identifier_feature_type == FeatureTypeEnum.STRING: + base_timestamp_cast_function_name = "from_iso8601_timestamp" + timestamp_cast_function_name = "from_unixtime" + if feature_group.event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING: + timestamp_cast_function_name = "from_iso8601_timestamp" + if self._point_in_time_accurate_join: + join_condition_string += ( + f"\nAND {base_timestamp_cast_function_name}(fg_base." + + f'"{self._event_time_identifier_feature_name}") >= ' + + f"{timestamp_cast_function_name}(fg_{suffix}." + + f'"{feature_group.event_time_identifier_feature.feature_name}")' + ) + return join_condition_string + + def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str): + """Internal method for creating a temp Athena table for the base pandas.Dataframe. + + Args: + temp_table_name (str): The Athena table name of base pandas.DataFrame. + desired_s3_folder (str): The S3 URI of the folder of the data. + """ + columns_string = ", ".join( + [self._construct_athena_table_column_string(column) for column in self._base.columns] + ) + serde_properties = '"separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\"' + query_string = ( + f"CREATE EXTERNAL TABLE {temp_table_name} ({columns_string}) " + + "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' " + + f"WITH SERDEPROPERTIES ({serde_properties}) " + + f"LOCATION '{desired_s3_folder}';" + ) + self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + + def _construct_athena_table_column_string(self, column: str) -> str: + """Internal method for constructing string of Athena column. + + Args: + column (str): The column name from pandas.Dataframe. + Returns: + The Athena column string. + + Raises: + RuntimeError: The type of pandas.Dataframe column is not support yet. 
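The temp-table path above maps pandas dtypes to Athena column types through `_DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP`. A small sketch of the resulting column list for a hypothetical DataFrame base (no AWS calls; the session is mocked):

```python
import pandas as pd
from mock import Mock

from sagemaker.feature_store.dataset_builder import DatasetBuilder

# Hypothetical base frame: int64 -> INT, float64 -> DOUBLE, object -> STRING.
base_df = pd.DataFrame(
    {"customer_id": [1, 2], "event_time": [1654000000.0, 1654000001.0], "segment": ["a", "b"]}
)
builder = DatasetBuilder(
    sagemaker_session=Mock(),
    base=base_df,
    output_path="s3://my-bucket/out",  # hypothetical
)
columns = ", ".join(
    builder._construct_athena_table_column_string(column) for column in base_df.columns
)
print(columns)  # customer_id INT, event_time DOUBLE, segment STRING
```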
+ """ + dataframe_type = self._base[column].dtypes + if str(dataframe_type) not in self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.keys(): + raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.") + return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}" + + def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]: + """Internal method for execute Athena query, wait for query finish and get query result. + + Args: + query_string (str): The SQL query statements to be executed. + catalog (str): The name of the data catalog used in the query execution. + database (str): The name of the database used in the query execution. + Returns: + The query result. + + Raises: + RuntimeError: Athena query failed. + """ + query = self._sagemaker_session.start_query_execution( + catalog=catalog, + database=database, + query_string=query_string, + output_location=self._output_path, + kms_key=self._kms_key_id, + ) + query_id = query.get("QueryExecutionId", None) + self._sagemaker_session.wait_for_athena_query(query_execution_id=query_id) + query_result = self._sagemaker_session.get_query_execution(query_execution_id=query_id) + query_state = query_result.get("QueryExecution", {}).get("Status", {}).get("State", None) + + if query_state != "SUCCEEDED": + raise RuntimeError(f"Failed to execute query {query_id}.") + return query_result diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index d486ab8a01..855e11488f 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -435,13 +435,14 @@ class FeatureGroup: "uint64", ] _FLOAT_TYPES = ["float_", "float16", "float32", "float64"] - _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = { + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = { type: FeatureTypeEnum.INTEGRAL for type in _INTEGER_TYPES } - _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update( + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update( {type: FeatureTypeEnum.FRACTIONAL for type in _FLOAT_TYPES} ) - _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["object"] = FeatureTypeEnum.STRING _FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP = { FeatureTypeEnum.INTEGRAL.value: "INT", @@ -629,7 +630,7 @@ def load_feature_definitions( """ feature_definitions = [] for column in data_frame: - feature_type = self._DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get( + feature_type = self.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get( str(data_frame[column].dtype), None ) if feature_type: @@ -644,6 +645,23 @@ def load_feature_definitions( self.feature_definitions = feature_definitions return self.feature_definitions + def get_record( + self, record_identifier_value_as_string: str, feature_names: Sequence[str] = None + ) -> Sequence[Dict[str, str]]: + """Get a single record in a FeatureGroup + + Args: + record_identifier_value_as_string (String): + a String representing the value of the record identifier. + feature_names (Sequence[String]): + a list of Strings representing feature names. + """ + return self.sagemaker_session.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + feature_group_name=self.name, + feature_names=feature_names, + ).get("Record") + def put_record(self, record: Sequence[FeatureValue]): """Put a single record in the FeatureGroup. 
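A brief usage sketch for the new `FeatureGroup.get_record()` shown above, assuming an existing FeatureGroup with an online store enabled; the group name, identifier value, and feature name are hypothetical:

```python
from sagemaker import Session
from sagemaker.feature_store.feature_group import FeatureGroup

session = Session()
# "customers" is a hypothetical FeatureGroup with an online store enabled.
customers_fg = FeatureGroup(name="customers", sagemaker_session=session)

# Fetch the full record, then only a subset of features, for identifier "42".
full_record = customers_fg.get_record(record_identifier_value_as_string="42")
partial_record = customers_fg.get_record(
    record_identifier_value_as_string="42",
    feature_names=["loyalty_score"],
)
print(full_record, partial_record)
```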
@@ -654,6 +672,25 @@ def put_record(self, record: Sequence[FeatureValue]): feature_group_name=self.name, record=[value.to_dict() for value in record] ) + def delete_record( + self, + record_identifier_value_as_string: str, + event_time: str, + ): + """Delete a single record from a FeatureGroup. + + Args: + record_identifier_value_as_string (String): + a String representing the value of the record identifier. + event_time (String): + a timestamp format String indicating when the deletion event occurred. + """ + return self.sagemaker_session.delete_record( + feature_group_name=self.name, + record_identifier_value_as_string=record_identifier_value_as_string, + event_time=event_time, + ) + def ingest( self, data_frame: DataFrame, diff --git a/src/sagemaker/feature_store/feature_store.py b/src/sagemaker/feature_store/feature_store.py new file mode 100644 index 0000000000..def8b2b2da --- /dev/null +++ b/src/sagemaker/feature_store/feature_store.py @@ -0,0 +1,130 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Feature Store. + +Amazon SageMaker Feature Store is a fully managed, purpose-built repository to store, share, and +manage features for machine learning (ML) models. +""" +from __future__ import absolute_import + +import datetime +from typing import Any, Dict, Sequence, Union + +import attr +import pandas as pd + +from sagemaker import Session +from sagemaker.feature_store.dataset_builder import DatasetBuilder +from sagemaker.feature_store.feature_group import FeatureGroup + + +@attr.s +class FeatureStore: + """FeatureStore definition. + + This class instantiates a FeatureStore object that comprises a SageMaker session instance. + + Attributes: + sagemaker_session (Session): session instance to perform boto calls. + """ + + sagemaker_session: Session = attr.ib(default=Session) + + def create_dataset( + self, + base: Union[FeatureGroup, pd.DataFrame], + output_path: str, + record_identifier_feature_name: str = None, + event_time_identifier_feature_name: str = None, + included_feature_names: Sequence[str] = None, + kms_key_id: str = None, + ) -> DatasetBuilder: + """Create a Dataset Builder for generating a Dataset. + + Args: + base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a + pandas.DataFrame and will be used to merge other FeatureGroups and generate a + Dataset. + output_path (str): An S3 URI which stores the output .csv file. + record_identifier_feature_name (str): A string representing the record identifier + feature if base is a DataFrame (default: None). + event_time_identifier_feature_name (str): A string representing the event time + identifier feature if base is a DataFrame (default: None). + included_feature_names (List[str]): A list of features to be included in the output + (default: None). + kms_key_id (str): An KMS key id. If set, will be used to encrypt the result file + (default: None). 
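A short sketch of removing a record with the new `delete_record()` API and confirming the deletion with `get_record()`, mirroring the integration test added later in this patch; the group name and identifier value are hypothetical:

```python
import datetime

from sagemaker import Session
from sagemaker.feature_store.feature_group import FeatureGroup

session = Session()
customers_fg = FeatureGroup(name="customers", sagemaker_session=session)  # hypothetical

# The deletion event time is an ISO-8601 timestamp string.
event_time = datetime.datetime.utcnow().replace(microsecond=0).isoformat() + "Z"
customers_fg.delete_record(
    record_identifier_value_as_string="42",
    event_time=event_time,
)

# After the delete, the online store no longer returns the record.
assert customers_fg.get_record(record_identifier_value_as_string="42") is None
```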
+ + Raises: + ValueError: Base is a Pandas DataFrame but no record identifier feature name nor event + time identifier feature name is provided. + """ + if isinstance(base, pd.DataFrame): + if record_identifier_feature_name is None or event_time_identifier_feature_name is None: + raise ValueError( + "You must provide a record identifier feature name and an event time " + + "identifier feature name if specify DataFrame as base." + ) + return DatasetBuilder( + self.sagemaker_session, + base, + output_path, + record_identifier_feature_name, + event_time_identifier_feature_name, + included_feature_names, + kms_key_id, + ) + + def list_feature_groups( + self, + name_contains: str = None, + feature_group_status_equals: str = None, + offline_store_status_equals: str = None, + creation_time_after: datetime.datetime = None, + creation_time_before: datetime.datetime = None, + sort_order: str = None, + sort_by: str = None, + max_results: int = None, + next_token: str = None, + ) -> Dict[str, Any]: + """List all FeatureGroups satisfying given filters. + + Args: + name_contains (str): A string that partially matches one or more FeatureGroups' names. + Filters FeatureGroups by name. + feature_group_status_equals (str): A FeatureGroup status. + Filters FeatureGroups by FeatureGroup status. + offline_store_status_equals (str): An OfflineStore status. + Filters FeatureGroups by OfflineStore status. + creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups + created after a specific date and time. + creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups + created before a specific date and time. + sort_order (str): The order in which FeatureGroups are listed. + sort_by (str): The value on which the FeatureGroup list is sorted. + max_results (int): The maximum number of results returned by ListFeatureGroups. + next_token (str): A token to resume pagination of ListFeatureGroups results. + Returns: + Response dict from service. + """ + return self.sagemaker_session.list_feature_groups( + name_contains=name_contains, + feature_group_status_equals=feature_group_status_equals, + offline_store_status_equals=offline_store_status_equals, + creation_time_after=creation_time_after, + creation_time_before=creation_time_before, + sort_order=sort_order, + sort_by=sort_by, + max_results=max_results, + next_token=next_token, + ) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 3fc4fc1256..72df570496 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -312,7 +312,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): # For each object key, create the directory on the local machine if needed, and then # download the file. for key in keys: - tail_s3_uri_path = os.path.basename(key_prefix) + tail_s3_uri_path = os.path.basename(key) if not os.path.splitext(key_prefix)[1]: tail_s3_uri_path = os.path.relpath(key, key_prefix) destination_path = os.path.join(path, tail_s3_uri_path) @@ -4341,6 +4341,56 @@ def update_feature_group( FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions ) + def list_feature_groups( + self, + name_contains, + feature_group_status_equals, + offline_store_status_equals, + creation_time_after, + creation_time_before, + sort_order, + sort_by, + max_results, + next_token, + ) -> Dict[str, Any]: + """List all FeatureGroups satisfying given filters. + + Args: + name_contains (str): A string that partially matches one or more FeatureGroups' names. 
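A usage sketch for `FeatureStore.list_feature_groups()` as defined above; only non-None filters are forwarded to the service. The name fragment is hypothetical:

```python
import datetime

from sagemaker import Session
from sagemaker.feature_store.feature_store import FeatureStore

feature_store = FeatureStore(sagemaker_session=Session())

response = feature_store.list_feature_groups(
    name_contains="customers",  # hypothetical name fragment
    creation_time_after=datetime.datetime(2022, 1, 1),
    sort_by="CreationTime",
    sort_order="Descending",
    max_results=10,
)
for summary in response.get("FeatureGroupSummaries", []):
    print(summary["FeatureGroupName"])
```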
+ Filters FeatureGroups by name. + feature_group_status_equals (str): A FeatureGroup status. + Filters FeatureGroups by FeatureGroup status. + offline_store_status_equals (str): An OfflineStore status. + Filters FeatureGroups by OfflineStore status. + creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups + created after a specific date and time. + creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups + created before a specific date and time. + sort_order (str): The order in which FeatureGroups are listed. + sort_by (str): The value on which the FeatureGroup list is sorted. + max_results (int): The maximum number of results returned by ListFeatureGroups. + next_token (str): A token to resume pagination of ListFeatureGroups results. + Returns: + Response dict from service. + """ + list_feature_groups_args = {} + + def check_object(key, value): + if value is not None: + list_feature_groups_args[key] = value + + check_object("NameContains", name_contains) + check_object("FeatureGroupStatusEquals", feature_group_status_equals) + check_object("OfflineStoreStatusEquals", offline_store_status_equals) + check_object("CreationTimeAfter", creation_time_after) + check_object("CreationTimeBefore", creation_time_before) + check_object("SortOrder", sort_order) + check_object("SortBy", sort_by) + check_object("MaxResults", max_results) + check_object("NextToken", next_token) + + return self.sagemaker_client.list_feature_groups(**list_feature_groups_args) + def update_feature_metadata( self, feature_group_name: str, @@ -4408,6 +4458,48 @@ def put_record( Record=record, ) + def delete_record( + self, + feature_group_name: str, + record_identifier_value_as_string: str, + event_time: str, + ): + """Deletes a single record from the FeatureGroup. + + Args: + feature_group_name (str): name of the FeatureGroup. + record_identifier_value_as_string (str): name of the record identifier. + event_time (str): a timestamp indicating when the deletion event occurred. + """ + return self.sagemaker_featurestore_runtime_client.delete_record( + FeatureGroupName=feature_group_name, + RecordIdentifierValueAsString=record_identifier_value_as_string, + EventTime=event_time, + ) + + def get_record( + self, + record_identifier_value_as_string: str, + feature_group_name: str, + feature_names: Sequence[str], + ) -> Dict[str, Sequence[Dict[str, str]]]: + """Gets a single record in the FeatureGroup. + + Args: + record_identifier_value_as_string (str): name of the record identifier. + feature_group_name (str): name of the FeatureGroup. + feature_names (Sequence[str]): list of feature names. 
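For context, the Session methods above are thin wrappers over the SageMaker Feature Store runtime client; a rough sketch of the equivalent direct boto3 calls (feature group name, identifier, and event time are hypothetical):

```python
import boto3

runtime = boto3.client("sagemaker-featurestore-runtime")

# Equivalent of Session.get_record(): returns {"Record": [...]} when found.
record = runtime.get_record(
    FeatureGroupName="customers",              # hypothetical FeatureGroup
    RecordIdentifierValueAsString="42",
)
# Equivalent of Session.delete_record(): writes a deletion event.
runtime.delete_record(
    FeatureGroupName="customers",
    RecordIdentifierValueAsString="42",
    EventTime="2022-06-16T00:00:00Z",
)
print(record.get("Record"))
```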
+ """ + get_record_args = { + "FeatureGroupName": feature_group_name, + "RecordIdentifierValueAsString": record_identifier_value_as_string, + } + + if feature_names: + get_record_args["FeatureNames"] = feature_names + + return self.sagemaker_featurestore_runtime_client.get_record(**get_record_args) + def start_query_execution( self, catalog: str, diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index c1b84117c3..e19cebdca4 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -14,6 +14,7 @@ import json import time +import datetime from contextlib import contextmanager import boto3 @@ -24,6 +25,7 @@ from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition from sagemaker.feature_store.feature_group import FeatureGroup +from sagemaker.feature_store.feature_store import FeatureStore from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter, TableFormatEnum from sagemaker.session import get_execution_role, Session from tests.integ.timeout import timeout @@ -80,6 +82,11 @@ def feature_group_name(): return f"my-feature-group-{int(time.time() * 10**7)}" +@pytest.fixture +def base_name(): + return f"my-base-{int(time.time() * 10**7)}" + + @pytest.fixture def offline_store_s3_uri(feature_store_session, region_name): bucket = f"sagemaker-test-featurestore-{region_name}-{feature_store_session.account_id()}" @@ -107,6 +114,32 @@ def pandas_data_frame(): return df +@pytest.fixture +def base_dataframe(): + base_data = [ + [1, 187512346.0, 123, 128], + [2, 187512347.0, 168, 258], + [3, 187512348.0, 125, 184], + [1, 187512349.0, 195, 206], + ] + return pd.DataFrame( + base_data, columns=["base_id", "base_time", "base_feature_1", "base_feature_2"] + ) + + +@pytest.fixture +def feature_group_dataframe(): + feature_group_data = [ + [1, 187512246.0, 456, 325], + [2, 187512247.0, 729, 693], + [3, 187512348.0, 129, 901], + [1, 187512449.0, 289, 286], + ] + return pd.DataFrame( + feature_group_data, columns=["fg_id", "fg_time", "fg_feature_1", "fg_feature_2"] + ) + + @pytest.fixture def pandas_data_frame_without_string(): df = pd.DataFrame( @@ -288,6 +321,92 @@ def test_create_feature_group_glue_table_format( assert table_format == "Glue" +def test_get_record( + feature_store_session, + role, + feature_group_name, + pandas_data_frame, + record, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + record_identifier_value_as_string = record[0].value_as_string + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + # Ingest data + feature_group.put_record(record=record) + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + ) + record_names = list(map(lambda r: r.feature_name, record)) + assert len(retrieved_record) == len(record_names) + for feature in retrieved_record: + assert feature["FeatureName"] in record_names + removed_feature_name = record_names.pop() + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=record_names, + ) + assert len(retrieved_record) == len(record_names) + for feature in 
retrieved_record: + assert feature["FeatureName"] in record_names + assert feature["FeatureName"] is not removed_feature_name + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string="1.0", + ) + assert retrieved_record is None + + +def test_delete_record( + feature_store_session, + role, + feature_group_name, + pandas_data_frame, + record, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + record_identifier_value_as_string = record[0].value_as_string + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + # Ingest data + feature_group.put_record(record=record) + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + ) + assert retrieved_record is not None + # Delete data + feature_group.delete_record( + record_identifier_value_as_string=record_identifier_value_as_string, + event_time=datetime.datetime.now().replace(microsecond=0).isoformat() + "Z", + ) + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + ) + assert retrieved_record is None + + def test_update_feature_group( feature_store_session, role, @@ -316,6 +435,25 @@ def test_update_feature_group( assert any([True for elem in feature_definitions if new_feature_name in elem.values()]) +def test_list_feature_groups(feature_store_session, role, feature_group_name, pandas_data_frame): + feature_store = FeatureStore(sagemaker_session=feature_store_session) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + output = feature_store.list_feature_groups(name_contains=feature_group_name) + + assert output["FeatureGroupSummaries"][0]["FeatureGroupName"] == feature_group_name + + def test_feature_metadata( feature_store_session, role, @@ -420,6 +558,242 @@ def test_ingest_multi_process( assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}") +def test_create_dataset_with_feature_group_base( + feature_store_session, + region_name, + role, + base_name, + feature_group_name, + offline_store_s3_uri, + base_dataframe, + feature_group_dataframe, +): + base = FeatureGroup(name=base_name, sagemaker_session=feature_store_session) + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + with cleanup_feature_group(base), cleanup_feature_group(feature_group): + _create_feature_group_and_ingest_data( + base, base_dataframe, offline_store_s3_uri, "base_id", "base_time", role + ) + _create_feature_group_and_ingest_data( + feature_group, feature_group_dataframe, offline_store_s3_uri, "fg_id", "fg_time", role + ) + base_table_name = _get_athena_table_name_after_data_replication( + feature_store_session, base, offline_store_s3_uri + ) + feature_group_table_name = 
_get_athena_table_name_after_data_replication( + feature_store_session, feature_group, offline_store_s3_uri + ) + + with timeout(minutes=10) and cleanup_offline_store( + base_table_name, feature_store_session + ) and cleanup_offline_store(feature_group_table_name, feature_store_session): + feature_store = FeatureStore(sagemaker_session=feature_store_session) + df, query_string = ( + feature_store.create_dataset(base=base, output_path=offline_store_s3_uri) + .with_number_of_recent_records_by_record_identifier(4) + .with_feature_group(feature_group) + .to_dataframe() + ) + sorted_df = df.sort_values(by=list(df.columns)).reset_index(drop=True) + merged_df = base_dataframe.merge( + feature_group_dataframe, left_on="base_id", right_on="fg_id" + ) + + expect_df = merged_df.sort_values(by=list(merged_df.columns)).reset_index(drop=True) + + expect_df.rename( + columns={ + "fg_id": "fg_id.1", + "fg_time": "fg_time.1", + "fg_feature_1": "fg_feature_1.1", + "fg_feature_2": "fg_feature_2.1", + }, + inplace=True, + ) + + assert sorted_df.equals(expect_df) + assert ( + query_string + == "WITH fg_base AS (WITH table_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."base_id", origin_base."base_time"\n' + + 'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n' + + ") AS dedup_row_base\n" + + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n' + + ")\n" + + "WHERE dedup_row_base = 1\n" + + "),\n" + + "deleted_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."base_id"\n' + + 'ORDER BY origin_base."base_time" DESC,' + ' origin_base."api_invocation_time" DESC,' + ' origin_base."write_time" DESC\n' + + ") AS deleted_row_base\n" + + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_base = 1\n" + + ")\n" + + 'SELECT table_base."base_id", table_base."base_time",' + ' table_base."base_feature_1", table_base."base_feature_2"\n' + + "FROM (\n" + + 'SELECT table_base."base_id", table_base."base_time",' + ' table_base."base_feature_1", table_base."base_feature_2",' + ' table_base."write_time"\n' + + "FROM table_base\n" + + "LEFT JOIN deleted_base\n" + + 'ON table_base."base_id" = deleted_base."base_id"\n' + + 'WHERE deleted_base."base_id" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_base."base_id", table_base."base_time",' + ' table_base."base_feature_1", table_base."base_feature_2",' + ' table_base."write_time"\n' + + "FROM deleted_base\n" + + "JOIN table_base\n" + + 'ON table_base."base_id" = deleted_base."base_id"\n' + + "AND (\n" + + 'table_base."base_time" > deleted_base."base_time"\n' + + 'OR (table_base."base_time" = deleted_base."base_time" AND' + ' table_base."api_invocation_time" >' + ' deleted_base."api_invocation_time")\n' + + 'OR (table_base."base_time" = deleted_base."base_time" AND' + ' table_base."api_invocation_time" =' + ' deleted_base."api_invocation_time" AND' + ' table_base."write_time" > deleted_base."write_time")\n' + + ")\n" + + ") AS table_base\n" + + "),\n" + + "fg_0 AS (WITH table_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."fg_id", origin_0."fg_time"\n' + + 'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n' + + ") AS dedup_row_0\n" + + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n' + + ")\n" + + "WHERE dedup_row_0 = 1\n" + + "),\n" + 
+ "deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."fg_id"\n' + + 'ORDER BY origin_0."fg_time" DESC, origin_0."api_invocation_time" DESC,' + ' origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."fg_id", table_0."fg_time", table_0."fg_feature_1",' + ' table_0."fg_feature_2"\n' + + "FROM (\n" + + 'SELECT table_0."fg_id", table_0."fg_time",' + ' table_0."fg_feature_1", table_0."fg_feature_2",' + ' table_0."write_time"\n' + + "FROM table_0\n" + + "LEFT JOIN deleted_0\n" + + 'ON table_0."fg_id" = deleted_0."fg_id"\n' + + 'WHERE deleted_0."fg_id" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_0."fg_id", table_0."fg_time",' + ' table_0."fg_feature_1", table_0."fg_feature_2",' + ' table_0."write_time"\n' + + "FROM deleted_0\n" + + "JOIN table_0\n" + + 'ON table_0."fg_id" = deleted_0."fg_id"\n' + + "AND (\n" + + 'table_0."fg_time" > deleted_0."fg_time"\n' + + 'OR (table_0."fg_time" = deleted_0."fg_time" AND' + ' table_0."api_invocation_time" >' + ' deleted_0."api_invocation_time")\n' + + 'OR (table_0."fg_time" = deleted_0."fg_time" AND' + ' table_0."api_invocation_time" =' + ' deleted_0."api_invocation_time" AND table_0."write_time" >' + ' deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + "SELECT base_id, base_time, base_feature_1, base_feature_2," + ' "fg_id.1", "fg_time.1", "fg_feature_1.1",' + ' "fg_feature_2.1"\n' + "FROM (\n" + "SELECT fg_base.base_id, fg_base.base_time," + " fg_base.base_feature_1, fg_base.base_feature_2," + ' fg_0."fg_id" as "fg_id.1", fg_0."fg_time" as "fg_time.1",' + ' fg_0."fg_feature_1" as "fg_feature_1.1",' + ' fg_0."fg_feature_2" as "fg_feature_2.1", row_number()' + " OVER (\n" + + 'PARTITION BY fg_base."base_id"\n' + + 'ORDER BY fg_base."base_time" DESC, fg_0."fg_time" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."base_id" = fg_0."fg_id"\n' + + ")\n" + + "WHERE row_recent <= 4" + ) + + +def _create_feature_group_and_ingest_data( + feature_group: FeatureGroup, + dataframe: DataFrame, + offline_store_s3_uri: str, + record_identifier_name: str, + event_time_name: str, + role: str, +): + feature_group.load_feature_definitions(data_frame=dataframe) + feature_group.create( + s3_uri=offline_store_s3_uri, + record_identifier_name=record_identifier_name, + event_time_feature_name=event_time_name, + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + + ingestion_manager = feature_group.ingest(data_frame=dataframe, max_workers=3, wait=False) + ingestion_manager.wait() + assert 0 == len(ingestion_manager.failed_rows) + + +def _get_athena_table_name_after_data_replication( + feature_store_session, feature_group: FeatureGroup, offline_store_s3_uri +): + feature_group_metadata = feature_group.describe() + resolved_output_s3_uri = ( + feature_group_metadata.get("OfflineStoreConfig", None) + .get("S3StorageConfig", None) + .get("ResolvedOutputS3Uri", None) + ) + s3_prefix = resolved_output_s3_uri.replace(f"{offline_store_s3_uri}/", "") + region_name = feature_store_session.boto_session.region_name + s3_client = feature_store_session.boto_session.client( + service_name="s3", region_name=region_name + ) + while True: + objects_in_bucket = s3_client.list_objects( + Bucket=offline_store_s3_uri.replace("s3://", ""), Prefix=s3_prefix + ) + 
if "Contents" in objects_in_bucket and len(objects_in_bucket["Contents"]) > 1: + break + else: + print(f"Waiting for {feature_group.name} data in offline store...") + time.sleep(60) + print(f"{feature_group.name} data available.") + return ( + feature_group_metadata.get("OfflineStoreConfig", None) + .get("DataCatalogConfig", None) + .get("TableName", None) + ) + + def _wait_for_feature_group_create(feature_group: FeatureGroup): status = feature_group.describe().get("FeatureGroupStatus") while status == "Creating": @@ -451,5 +825,31 @@ def cleanup_feature_group(feature_group: FeatureGroup): finally: try: feature_group.delete() + print(f"{feature_group.name} is deleted") except Exception: raise RuntimeError(f"Failed to delete feature group with name {feature_group.name}") + + +@contextmanager +def cleanup_offline_store(table_name: str, feature_store_session: Session): + try: + yield + finally: + try: + region_name = feature_store_session.boto_session.region_name + s3_client = feature_store_session.boto_session.client( + service_name="s3", region_name=region_name + ) + account_id = feature_store_session.account_id() + bucket_name = f"sagemaker-test-featurestore-{region_name}-{account_id}" + response = s3_client.list_objects_v2( + Bucket=bucket_name, + Prefix=f"{account_id}/sagemaker/{region_name}/offline-store/{table_name}/", + ) + files_in_folder = response["Contents"] + files_to_delete = [] + for f in files_in_folder: + files_to_delete.append({"Key": f["Key"]}) + s3_client.delete_objects(Bucket=bucket_name, Delete={"Objects": files_to_delete}) + except Exception: + raise RuntimeError(f"Failed to delete data under {table_name}") diff --git a/tests/unit/sagemaker/feature_store/test_dataset_builder.py b/tests/unit/sagemaker/feature_store/test_dataset_builder.py new file mode 100644 index 0000000000..0e55b86bd0 --- /dev/null +++ b/tests/unit/sagemaker/feature_store/test_dataset_builder.py @@ -0,0 +1,612 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import datetime + +import pandas as pd +import pytest +import os +from mock import Mock, patch + +from sagemaker.feature_store.dataset_builder import ( + DatasetBuilder, + FeatureGroupToBeMerged, + TableType, +) +from sagemaker.feature_store.feature_group import ( + FeatureDefinition, + FeatureGroup, + FeatureTypeEnum, +) + + +@pytest.fixture +def sagemaker_session_mock(): + return Mock() + + +@pytest.fixture +def feature_group_mock(): + return Mock() + + +@pytest.fixture +def read_csv_mock(): + return Mock() + + +@pytest.fixture +def to_csv_file_mock(): + return Mock() + + +@pytest.fixture +def remove_mock(): + return Mock() + + +BASE = FeatureGroupToBeMerged( + ["target-feature", "other-feature"], + ["target-feature", "other-feature"], + ["target-feature", "other-feature"], + "catalog", + "database", + "base-table", + "target-feature", + FeatureDefinition("other-feature", FeatureTypeEnum.STRING), + None, + TableType.FEATURE_GROUP, +) +FEATURE_GROUP = FeatureGroupToBeMerged( + ["feature-1", "feature-2"], + ["feature-1", "feature-2"], + ["feature-1", "feature-2"], + "catalog", + "database", + "table-name", + "feature-1", + FeatureDefinition("feature-2", FeatureTypeEnum.FRACTIONAL), + "target-feature", + TableType.FEATURE_GROUP, +) + + +def test_with_feature_group_throw_runtime_error(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + sagemaker_session_mock.describe_feature_group.return_value = {"OfflineStoreConfig": {}} + with pytest.raises(RuntimeError) as error: + dataset_builder.with_feature_group( + feature_group, "target-feature", ["feature-1", "feature-2"] + ) + assert "No metastore is configured with FeatureGroup MyFeatureGroup." 
in str(error) + + +def test_with_feature_group(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]}) + feature_group.load_feature_definitions(dataframe) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + record_identifier_feature_name="target-feature", + ) + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}}, + "RecordIdentifierFeatureName": "feature-1", + "EventTimeFeatureName": "feature-2", + "FeatureDefinitions": [ + {"FeatureName": "feature-1", "FeatureType": "String"}, + {"FeatureName": "feature-2", "FeatureType": "String"}, + ], + } + dataset_builder.with_feature_group(feature_group, "target-feature", ["feature-1", "feature-2"]) + assert len(dataset_builder._feature_groups_to_be_merged) == 1 + assert dataset_builder._feature_groups_to_be_merged[0].features == [ + "feature-1", + "feature-2", + ] + assert dataset_builder._feature_groups_to_be_merged[0].included_feature_names == [ + "feature-1", + "feature-2", + ] + assert dataset_builder._feature_groups_to_be_merged[0].database == "database" + assert dataset_builder._feature_groups_to_be_merged[0].table_name == "table" + assert ( + dataset_builder._feature_groups_to_be_merged[0].record_identifier_feature_name + == "feature-1" + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].event_time_identifier_feature.feature_name + == "feature-2" + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].event_time_identifier_feature.feature_type + == FeatureTypeEnum.STRING + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].target_feature_name_in_base + == "target-feature" + ) + + +def test_point_in_time_accurate_join(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.point_in_time_accurate_join() + assert dataset_builder._point_in_time_accurate_join + + +def test_include_duplicated_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.include_duplicated_records() + assert dataset_builder._include_duplicated_records + + +def test_include_deleted_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.include_deleted_records() + assert dataset_builder._include_deleted_records + + +def test_with_number_of_recent_records_by_record_identifier( + sagemaker_session_mock, feature_group_mock +): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.with_number_of_recent_records_by_record_identifier(5) + assert dataset_builder._number_of_recent_records == 5 + + +def test_with_number_of_records_from_query_results(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + 
dataset_builder.with_number_of_records_from_query_results(100) + assert dataset_builder._number_of_records == 100 + + +def test_with_event_time_range(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + start = datetime.datetime.now() + end = start + datetime.timedelta(minutes=1) + dataset_builder.with_event_time_range(start, end) + assert dataset_builder._event_time_starting_timestamp == start + assert dataset_builder._event_time_ending_timestamp == end + + +def test_to_csv_file_not_support_base_type(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + with pytest.raises(ValueError) as error: + dataset_builder.to_csv_file() + assert "Base must be either a FeatureGroup or a DataFrame." in str(error) + + +def test_to_csv_file_with_feature_group(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}}, + "RecordIdentifierFeatureName": "feature-1", + "EventTimeFeatureName": "feature-2", + "FeatureDefinitions": [ + {"FeatureName": "feature-1", "FeatureType": "String"}, + {"FeatureName": "feature-2", "FeatureType": "String"}, + ], + } + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": { + "Status": {"State": "SUCCEEDED"}, + "ResultConfiguration": {"OutputLocation": "s3-file-path"}, + "Query": "query-string", + } + } + file_path, query_string = dataset_builder.to_csv_file() + assert file_path == "s3-file-path" + assert query_string == "query-string" + + +@patch("pandas.DataFrame.to_csv") +@patch("pandas.read_csv") +@patch("os.remove") +def test_to_dataframe_with_dataframe( + remove_mock, read_csv_mock, to_csv_file_mock, sagemaker_session_mock +): + dataframe = pd.DataFrame({"feature-1": [420, 380.0, 390], "feature-2": [50, 40.0, 45]}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="s3://file/to/path", + event_time_identifier_feature_name="feature-2", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": { + "Status": {"State": "SUCCEEDED"}, + "ResultConfiguration": {"OutputLocation": "s3://s3-file-path"}, + "Query": "query-string", + } + } + to_csv_file_mock.return_value = None + read_csv_mock.return_value = dataframe + os.remove.return_value = None + df, query_string = dataset_builder.to_dataframe() + assert df.equals(dataframe) + assert query_string == "query-string" + + +def test_construct_where_query_string(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + time = datetime.datetime.now().replace(microsecond=0) + start = time + datetime.timedelta(minutes=1) + 
end = start + datetime.timedelta(minutes=1) + dataset_builder._write_time_ending_timestamp = time + dataset_builder._event_time_starting_timestamp = start + dataset_builder._event_time_ending_timestamp = end + query_string = dataset_builder._construct_where_query_string( + "suffix", + FeatureDefinition("event-time", FeatureTypeEnum.STRING), + ["NOT is_deleted"], + ) + assert ( + query_string + == "WHERE NOT is_deleted\n" + + f"AND table_suffix.\"write_time\" <= to_timestamp('{time}', " + + "'yyyy-mm-dd hh24:mi:ss')\n" + + 'AND from_iso8601_timestamp(table_suffix."event-time") >= ' + + f"from_unixtime({start.timestamp()})\n" + + 'AND from_iso8601_timestamp(table_suffix."event-time") <= ' + + f"from_unixtime({end.timestamp()})" + ) + + +def test_construct_query_string_with_duplicated_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder._include_duplicated_records = True + + dataset_builder._feature_groups_to_be_merged = [FEATURE_GROUP] + query_string = dataset_builder._construct_query_string(BASE) + assert ( + query_string + == "WITH fg_base AS (WITH deleted_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature"\n' + + 'ORDER BY origin_base."other-feature" DESC, origin_base."api_invocation_time" DESC, ' + + 'origin_base."write_time" DESC\n' + + ") AS deleted_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_base = 1\n" + + ")\n" + + 'SELECT table_base."target-feature", table_base."other-feature"\n' + + "FROM (\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + 'FROM "database"."base-table" table_base\n' + + "LEFT JOIN deleted_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + 'WHERE deleted_base."target-feature" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM deleted_base\n" + + 'JOIN "database"."base-table" table_base\n' + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + "AND (\n" + + 'table_base."other-feature" > deleted_base."other-feature"\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" > deleted_base."api_invocation_time")\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" = deleted_base."api_invocation_time" AND ' + + 'table_base."write_time" > deleted_base."write_time")\n' + + ")\n" + + ") AS table_base\n" + + "),\n" + + "fg_0 AS (WITH deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1"\n' + + 'ORDER BY origin_0."feature-2" DESC, origin_0."api_invocation_time" DESC, ' + + 'origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."feature-1", table_0."feature-2"\n' + + "FROM (\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + 'FROM "database"."table-name" table_0\n' + + "LEFT JOIN deleted_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + 'WHERE deleted_0."feature-1" IS NULL\n' + + "UNION ALL\n" 
+ + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM deleted_0\n" + + 'JOIN "database"."table-name" table_0\n' + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + "AND (\n" + + 'table_0."feature-2" > deleted_0."feature-2"\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND table_0."api_invocation_time" > ' + + 'deleted_0."api_invocation_time")\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND table_0."api_invocation_time" = ' + + 'deleted_0."api_invocation_time" AND table_0."write_time" > deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + 'SELECT target-feature, other-feature, "feature-1.1", "feature-2.1"\n' + + "FROM (\n" + + 'SELECT fg_base.target-feature, fg_base.other-feature, fg_0."feature-1" as ' + + '"feature-1.1", fg_0."feature-2" as "feature-2.1", row_number() OVER (\n' + + 'PARTITION BY fg_base."target-feature"\n' + + 'ORDER BY fg_base."other-feature" DESC, fg_0."feature-2" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."target-feature" = fg_0."feature-1"\n' + + ")\n" + ) + + +def test_construct_query_string(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + dataset_builder._point_in_time_accurate_join = True + dataset_builder._event_time_identifier_feature_name = "target-feature" + dataset_builder._feature_groups_to_be_merged = [FEATURE_GROUP] + query_string = dataset_builder._construct_query_string(BASE) + assert ( + query_string + == "WITH fg_base AS (WITH table_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature", origin_base."other-feature"\n' + + 'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n' + + ") AS dedup_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + ")\n" + + "WHERE dedup_row_base = 1\n" + + "),\n" + + "deleted_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature"\n' + + 'ORDER BY origin_base."other-feature" DESC, origin_base."api_invocation_time" ' + + 'DESC, origin_base."write_time" DESC\n' + + ") AS deleted_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_base = 1\n" + + ")\n" + + 'SELECT table_base."target-feature", table_base."other-feature"\n' + + "FROM (\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM table_base\n" + + "LEFT JOIN deleted_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + 'WHERE deleted_base."target-feature" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM deleted_base\n" + + "JOIN table_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + "AND (\n" + + 'table_base."other-feature" > deleted_base."other-feature"\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" > deleted_base."api_invocation_time")\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" = deleted_base."api_invocation_time" AND ' + + 
'table_base."write_time" > deleted_base."write_time")\n' + + ")\n" + + ") AS table_base\n" + + "),\n" + + "fg_0 AS (WITH table_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1", origin_0."feature-2"\n' + + 'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n' + + ") AS dedup_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + ")\n" + + "WHERE dedup_row_0 = 1\n" + + "),\n" + + "deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1"\n' + + 'ORDER BY origin_0."feature-2" DESC, origin_0."api_invocation_time" DESC, ' + + 'origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."feature-1", table_0."feature-2"\n' + + "FROM (\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM table_0\n" + + "LEFT JOIN deleted_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + 'WHERE deleted_0."feature-1" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM deleted_0\n" + + "JOIN table_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + "AND (\n" + + 'table_0."feature-2" > deleted_0."feature-2"\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND ' + + 'table_0."api_invocation_time" > deleted_0."api_invocation_time")\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND ' + + 'table_0."api_invocation_time" = deleted_0."api_invocation_time" AND ' + + 'table_0."write_time" > deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + 'SELECT target-feature, other-feature, "feature-1.1", "feature-2.1"\n' + + "FROM (\n" + + 'SELECT fg_base.target-feature, fg_base.other-feature, fg_0."feature-1" as ' + + '"feature-1.1", fg_0."feature-2" as "feature-2.1", row_number() OVER (\n' + + 'PARTITION BY fg_base."target-feature"\n' + + 'ORDER BY fg_base."other-feature" DESC, fg_0."feature-2" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."target-feature" = fg_0."feature-1"\n' + + 'AND from_unixtime(fg_base."target-feature") >= from_unixtime(fg_0."feature-2")\n' + + ")\n" + ) + + +def test_create_temp_table(sagemaker_session_mock): + dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "SUCCEEDED"}} + } + dataset_builder._create_temp_table("table-name", "s3-folder") + assert sagemaker_session_mock.start_query_execution.call_count == 1 + sagemaker_session_mock.start_query_execution.assert_called_once_with( + catalog="AwsDataCatalog", + database="sagemaker_featurestore", + query_string="CREATE EXTERNAL TABLE table-name (feature-1 INT, feature-2 INT) " + + "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' " + + 'WITH SERDEPROPERTIES ("separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\") ' + + "LOCATION 's3-folder';", + output_location="file/to/path", + kms_key=None, + ) + + +@pytest.mark.parametrize( + "column, expected", + [ + ("feature-1", "feature-1 
STRING"), + ("feature-2", "feature-2 INT"), + ("feature-3", "feature-3 DOUBLE"), + ("feature-4", "feature-4 BOOLEAN"), + ("feature-5", "feature-5 TIMESTAMP"), + ], +) +def test_construct_athena_table_column_string(column, expected, sagemaker_session_mock): + dataframe = pd.DataFrame( + { + "feature-1": ["420"], + "feature-2": [50], + "feature-3": [5.0], + "feature-4": [True], + "feature-5": [pd.Timestamp(1513393355)], + } + ) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + query_string = dataset_builder._construct_athena_table_column_string(column) + assert query_string == expected + + +def test_construct_athena_table_column_string_not_support_column_type( + sagemaker_session_mock, +): + dataframe = pd.DataFrame({"feature": pd.Series([1] * 3, dtype="int8")}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + with pytest.raises(RuntimeError) as error: + dataset_builder._construct_athena_table_column_string("feature") + assert "The dataframe type int8 is not supported yet." in str(error) + + +def test_run_query_throw_runtime_error(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "FAILED"}} + } + with pytest.raises(RuntimeError) as error: + dataset_builder._run_query("query-string", "catalog", "database") + assert "Failed to execute query query-id." in str(error) diff --git a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py new file mode 100644 index 0000000000..dce38fe426 --- /dev/null +++ b/tests/unit/sagemaker/feature_store/test_feature_group.py @@ -0,0 +1,580 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + + +import pandas as pd +import pytest +from mock import Mock, patch, MagicMock +from botocore.exceptions import ProfileNotFound + +from sagemaker.feature_store.feature_definition import ( + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + FeatureTypeEnum, +) +from sagemaker.feature_store.feature_group import ( + FeatureGroup, + IngestionManagerPandas, + AthenaQuery, + IngestionError, +) +from sagemaker.feature_store.inputs import FeatureParameter + + +class PicklableMock(Mock): + def __reduce__(self): + return (Mock, ()) + + +@pytest.fixture +def role_arn(): + return "arn:role" + + +@pytest.fixture +def s3_uri(): + return "s3://some/uri" + + +@pytest.fixture +def sagemaker_session_mock(): + return Mock() + + +@pytest.fixture +def fs_runtime_client_config_mock(): + return PicklableMock() + + +@pytest.fixture +def feature_group_dummy_definitions(): + return [ + FractionalFeatureDefinition(feature_name="feature1"), + IntegralFeatureDefinition(feature_name="feature2"), + StringFeatureDefinition(feature_name="feature3"), + ] + + +@pytest.fixture +def create_table_ddl(): + return ( + "CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n" + " feature1 FLOAT\n" + " feature2 INT\n" + " feature3 STRING\n" + " write_time TIMESTAMP\n" + " event_time TIMESTAMP\n" + " is_deleted BOOLEAN\n" + ")\n" + "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n" + " STORED AS\n" + " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n" + " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n" + "LOCATION 's3://resolved_output_s3_uri'" + ) + + +def test_feature_store_create( + sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri +): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + feature_group.create( + s3_uri=s3_uri, + record_identifier_name="feature1", + event_time_feature_name="feature2", + role_arn=role_arn, + enable_online_store=True, + ) + sagemaker_session_mock.create_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], + role_arn=role_arn, + description=None, + tags=None, + online_store_config={"EnableOnlineStore": True}, + offline_store_config={ + "DisableGlueTableCreation": False, + "S3StorageConfig": {"S3Uri": s3_uri}, + }, + ) + + +def test_feature_store_create_online_only( + sagemaker_session_mock, role_arn, feature_group_dummy_definitions +): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature2", + role_arn=role_arn, + enable_online_store=True, + ) + sagemaker_session_mock.create_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], + role_arn=role_arn, + description=None, + tags=None, + online_store_config={"EnableOnlineStore": True}, + ) + + +def test_feature_store_delete(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", 
sagemaker_session=sagemaker_session_mock) + feature_group.delete() + sagemaker_session_mock.delete_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup" + ) + + +def test_feature_store_describe(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.describe() + sagemaker_session_mock.describe_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", next_token=None + ) + + +def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.update(feature_group_dummy_definitions) + sagemaker_session_mock.update_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions], + ) + + +def test_feature_metadata_update(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + + parameter_additions = [FeatureParameter(key="key1", value="value1")] + parameter_removals = ["key2"] + + feature_group.update_feature_metadata( + feature_name="Feature1", + description="TestDescription", + parameter_additions=parameter_additions, + parameter_removals=parameter_removals, + ) + sagemaker_session_mock.update_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_name="Feature1", + description="TestDescription", + parameter_additions=[pa.to_dict() for pa in parameter_additions], + parameter_removals=parameter_removals, + ) + feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription") + sagemaker_session_mock.update_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_name="Feature1", + description="TestDescription", + parameter_additions=[], + parameter_removals=[], + ) + + +def test_feature_metadata_describe(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.describe_feature_metadata(feature_name="Feature1") + sagemaker_session_mock.describe_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", feature_name="Feature1" + ) + + +def test_get_record(sagemaker_session_mock): + feature_group_name = "MyFeatureGroup" + feature_names = ["MyFeature1", "MyFeature2"] + record_identifier_value_as_string = "1.0" + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session_mock) + feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=feature_names, + ) + sagemaker_session_mock.get_record.assert_called_with( + feature_group_name=feature_group_name, + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=feature_names, + ) + + +def test_put_record(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.put_record(record=[]) + sagemaker_session_mock.put_record.assert_called_with( + feature_group_name="MyFeatureGroup", record=[] + ) + + +def test_delete_record(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + record_identifier_value_as_string = "1.0" + event_time = "2022-09-14" + feature_group.delete_record( + 
record_identifier_value_as_string=record_identifier_value_as_string, + event_time=event_time, + ) + sagemaker_session_mock.delete_record.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_value_as_string=record_identifier_value_as_string, + event_time=event_time, + ) + + +def test_load_feature_definition(sagemaker_session_mock): + feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame( + { + "float": pd.Series([2.0], dtype="float64"), + "int": pd.Series([2], dtype="int64"), + "string": pd.Series(["f1"], dtype="string"), + } + ) + feature_definitions = feature_group.load_feature_definitions(data_frame=df) + names = [fd.feature_name for fd in feature_definitions] + types = [fd.feature_type for fd in feature_definitions] + assert names == ["float", "int", "string"] + assert types == [ + FeatureTypeEnum.FRACTIONAL, + FeatureTypeEnum.INTEGRAL, + FeatureTypeEnum.STRING, + ] + + +def test_load_feature_definition_unsupported_types(sagemaker_session_mock): + feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame( + { + "float": pd.Series([2.0], dtype="float64"), + "int": pd.Series([2], dtype="int64"), + "bool": pd.Series([True], dtype="bool"), + } + ) + with pytest.raises(ValueError) as error: + feature_group.load_feature_definitions(data_frame=df) + assert "Failed to infer Feature type based on dtype bool for column bool." in str(error) + + +def test_ingest_zero_processes(): + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = Mock() + with pytest.raises(RuntimeError) as error: + feature_group.ingest(data_frame=df, max_workers=1, max_processes=0) + + assert "max_processes must be greater than 0." in str(error) + + +def test_ingest_zero_workers(): + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = Mock() + with pytest.raises(RuntimeError) as error: + feature_group.ingest(data_frame=df, max_workers=0, max_processes=1) + + assert "max_workers must be greater than 0." 
in str(error) + + +@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") +def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock): + sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( + fs_runtime_client_config_mock + ) + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) + + mock_ingestion_manager_instance = Mock() + ingestion_manager_init.return_value = mock_ingestion_manager_instance + feature_group.ingest(data_frame=df, max_workers=10) + + ingestion_manager_init.assert_called_once_with( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + max_processes=1, + profile_name=None, + ) + mock_ingestion_manager_instance.run.assert_called_once_with( + data_frame=df, wait=True, timeout=None + ) + + +@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") +def test_ingest_with_profile_name( + ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock +): + sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( + fs_runtime_client_config_mock + ) + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) + + mock_ingestion_manager_instance = Mock() + ingestion_manager_init.return_value = mock_ingestion_manager_instance + feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name") + + ingestion_manager_init.assert_called_once_with( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + max_processes=1, + profile_name="profile_name", + ) + mock_ingestion_manager_instance.run.assert_called_once_with( + data_frame=df, wait=True, timeout=None + ) + + +def test_as_hive_ddl_with_default_values( + create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock +): + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://some-bucket", + "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", + } + } + } + sagemaker_session_mock.account_id.return_value = "1234" + sagemaker_session_mock.boto_session.region_name = "us-west-2" + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + assert ( + create_table_ddl.format( + database="sagemaker_featurestore", + table_name="MyGroup", + account="1234", + region="us-west-2", + feature_group_name="MyGroup", + ) + == feature_group.as_hive_ddl() + ) + + +def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock): + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://some-bucket", + "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", + } + } + } + sagemaker_session_mock.account_id.return_value = "1234" + sagemaker_session_mock.boto_session.region_name = "us-west-2" + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + assert create_table_ddl.format( + database="MyDatabase", + table_name="MyTable", + account="1234", + 
region="us-west-2", + feature_group_name="MyGroup", + ) == feature_group.as_hive_ddl(database="MyDatabase", table_name="MyTable") + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_process", + MagicMock(), +) +def test_ingestion_manager_run_success(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + ) + manager.run(df) + + manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None) + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_threaded", + PicklableMock(return_value=[]), +) +def test_ingestion_manager_run_multi_process_with_multi_thread_success( + fs_runtime_client_config_mock, +): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=2, + max_processes=2, + ) + manager.run(df) + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", + MagicMock(return_value=[1]), +) +def test_ingestion_manager_run_failure(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=1, + ) + + with pytest.raises(IngestionError) as error: + manager.run(df) + + assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) + assert error.value.failed_rows == [1] + assert manager.failed_rows == [1] + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", + MagicMock(side_effect=ProfileNotFound(profile="non_exist")), +) +def test_ingestion_manager_with_profile_name_run_failure(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=1, + profile_name="non_exist", + ) + + try: + manager.run(df) + except Exception as e: + assert "The config profile (non_exist) could not be found" in str(e) + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", + PicklableMock(return_value=[1]), +) +def test_ingestion_manager_run_multi_process_failure(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=None, + max_workers=2, + max_processes=2, + ) + + with pytest.raises(IngestionError) as error: + manager.run(df) + + assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) + assert error.value.failed_rows == [1, 1, 1, 1] + assert manager.failed_rows == [1, 1, 1, 1] + + +@pytest.fixture +def query(sagemaker_session_mock): + return AthenaQuery( + catalog="catalog", + database="database", + table_name="table_name", + sagemaker_session=sagemaker_session_mock, + ) + + +def test_athena_query_run(sagemaker_session_mock, query): + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"} + query.run( + query_string="query", output_location="s3://some-bucket/some-path", workgroup="workgroup" + ) + sagemaker_session_mock.start_query_execution.assert_called_with( + 
catalog="catalog", + database="database", + query_string="query", + output_location="s3://some-bucket/some-path", + kms_key=None, + workgroup="workgroup", + ) + assert "some-bucket" == query._result_bucket + assert "some-path" == query._result_file_prefix + assert "query_id" == query._current_query_execution_id + + +def test_athena_query_wait(sagemaker_session_mock, query): + query._current_query_execution_id = "query_id" + query.wait() + sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id") + + +def test_athena_query_get_query_execution(sagemaker_session_mock, query): + query._current_query_execution_id = "query_id" + query.get_query_execution() + sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id") + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +@patch("pandas.read_csv") +def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "SUCCEEDED"}} + } + query._current_query_execution_id = "query_id" + query._result_bucket = "bucket" + query._result_file_prefix = "prefix" + query.as_dataframe() + sagemaker_session_mock.download_athena_query_result.assert_called_with( + bucket="bucket", + prefix="prefix", + query_execution_id="query_id", + filename="tmp/query_id.csv", + ) + read_csv.assert_called_with("tmp/query_id.csv", delimiter=",") + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_failed(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "FAILED"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Failed to execute query query_id" in str(error) + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_queued(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "QUEUED"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Current query query_id is still being executed" in str(error) + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_running(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "RUNNING"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Current query query_id is still being executed" in str(error) diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py index 92ba35573c..073daca9ea 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_store.py +++ b/tests/unit/sagemaker/feature_store/test_feature_store.py @@ -10,46 +10,17 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -# language governing permissions and limitations under the License. 
from __future__ import absolute_import +import datetime import pandas as pd import pytest -from mock import Mock, patch, MagicMock -from botocore.exceptions import ProfileNotFound - -from sagemaker.feature_store.feature_definition import ( - FractionalFeatureDefinition, - IntegralFeatureDefinition, - StringFeatureDefinition, - FeatureTypeEnum, -) -from sagemaker.feature_store.feature_group import ( - FeatureGroup, - IngestionManagerPandas, - AthenaQuery, - IngestionError, -) -from sagemaker.feature_store.inputs import ( - FeatureParameter, - TableFormatEnum, -) - +from mock import Mock -class PicklableMock(Mock): - def __reduce__(self): - return (Mock, ()) +from sagemaker.feature_store.feature_store import FeatureStore - -@pytest.fixture -def role_arn(): - return "arn:role" - - -@pytest.fixture -def s3_uri(): - return "s3://some/uri" +DATAFRAME = pd.DataFrame({"feature_1": [420, 380, 390], "feature_2": [50, 40, 45]}) @pytest.fixture @@ -58,558 +29,108 @@ def sagemaker_session_mock(): @pytest.fixture -def fs_runtime_client_config_mock(): - return PicklableMock() - - -@pytest.fixture -def feature_group_dummy_definitions(): - return [ - FractionalFeatureDefinition(feature_name="feature1"), - IntegralFeatureDefinition(feature_name="feature2"), - StringFeatureDefinition(feature_name="feature3"), - ] - - -@pytest.fixture -def create_table_ddl(): - return ( - "CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n" - " feature1 FLOAT\n" - " feature2 INT\n" - " feature3 STRING\n" - " write_time TIMESTAMP\n" - " event_time TIMESTAMP\n" - " is_deleted BOOLEAN\n" - ")\n" - "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n" - " STORED AS\n" - " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n" - " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n" - "LOCATION 's3://resolved_output_s3_uri'" - ) - - -def test_feature_store_create( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_iceberg_table_format( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - disable_glue_table_creation=False, - table_format=TableFormatEnum.ICEBERG, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - 
feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "TableFormat": "Iceberg", - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_glue_table_format( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - disable_glue_table_creation=False, - table_format=TableFormatEnum.GLUE, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "TableFormat": "Glue", - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_online_only( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=False, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - ) - - -def test_feature_store_delete(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.delete() - sagemaker_session_mock.delete_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup" - ) - - -def test_feature_store_describe(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.describe() - sagemaker_session_mock.describe_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", next_token=None - ) - - -def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.update(feature_group_dummy_definitions) - sagemaker_session_mock.update_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions], - ) - - -def test_feature_metadata_update(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - - parameter_additions = [FeatureParameter(key="key1", value="value1")] - parameter_removals = ["key2"] - - feature_group.update_feature_metadata( - 
feature_name="Feature1", - description="TestDescription", - parameter_additions=parameter_additions, - parameter_removals=parameter_removals, - ) - sagemaker_session_mock.update_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_name="Feature1", - description="TestDescription", - parameter_additions=[pa.to_dict() for pa in parameter_additions], - parameter_removals=parameter_removals, - ) - feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription") - sagemaker_session_mock.update_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_name="Feature1", - description="TestDescription", - parameter_additions=[], - parameter_removals=[], - ) - - -def test_feature_metadata_describe(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.describe_feature_metadata(feature_name="Feature1") - sagemaker_session_mock.describe_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", feature_name="Feature1" - ) - - -def test_put_record(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.put_record(record=[]) - sagemaker_session_mock.put_record.assert_called_with( - feature_group_name="MyFeatureGroup", record=[] - ) - - -def test_load_feature_definition(sagemaker_session_mock): - feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame( - { - "float": pd.Series([2.0], dtype="float64"), - "int": pd.Series([2], dtype="int64"), - "string": pd.Series(["f1"], dtype="string"), - } - ) - feature_definitions = feature_group.load_feature_definitions(data_frame=df) - names = [fd.feature_name for fd in feature_definitions] - types = [fd.feature_type for fd in feature_definitions] - assert names == ["float", "int", "string"] - assert types == [ - FeatureTypeEnum.FRACTIONAL, - FeatureTypeEnum.INTEGRAL, - FeatureTypeEnum.STRING, - ] +def feature_group_mock(): + return Mock() -def test_load_feature_definition_unsupported_types(sagemaker_session_mock): - feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame( - { - "float": pd.Series([2.0], dtype="float64"), - "int": pd.Series([2], dtype="int64"), - "object": pd.Series(["f1"], dtype="object"), - } - ) +def test_minimal_create_dataset(sagemaker_session_mock, feature_group_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + dataset_builder = feature_store.create_dataset( + base=feature_group_mock, + output_path="file/to/path", + ) + assert dataset_builder._sagemaker_session == sagemaker_session_mock + assert dataset_builder._base == feature_group_mock + assert dataset_builder._output_path == "file/to/path" + + +def test_complete_create_dataset(sagemaker_session_mock, feature_group_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + dataset_builder = feature_store.create_dataset( + base=feature_group_mock, + included_feature_names=["feature_1", "feature_2"], + output_path="file/to/path", + kms_key_id="kms-key-id", + ) + assert dataset_builder._sagemaker_session == sagemaker_session_mock + assert dataset_builder._base == feature_group_mock + assert dataset_builder._included_feature_names == ["feature_1", "feature_2"] + assert dataset_builder._output_path == "file/to/path" + assert dataset_builder._kms_key_id == "kms-key-id" + + 
+def test_create_dataset_with_dataframe(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + dataset_builder = feature_store.create_dataset( + base=DATAFRAME, + record_identifier_feature_name="feature_1", + event_time_identifier_feature_name="feature_2", + included_feature_names=["feature_1", "feature_2"], + output_path="file/to/path", + kms_key_id="kms-key-id", + ) + assert dataset_builder._sagemaker_session == sagemaker_session_mock + assert dataset_builder._base.equals(DATAFRAME) + assert dataset_builder._record_identifier_feature_name == "feature_1" + assert dataset_builder._event_time_identifier_feature_name == "feature_2" + assert dataset_builder._included_feature_names == ["feature_1", "feature_2"] + assert dataset_builder._output_path == "file/to/path" + assert dataset_builder._kms_key_id == "kms-key-id" + + +def test_create_dataset_with_dataframe_value_error(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) with pytest.raises(ValueError) as error: - feature_group.load_feature_definitions(data_frame=df) - assert "Failed to infer Feature type based on dtype object for column object." in str(error) - - -def test_ingest_zero_processes(): - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = Mock() - with pytest.raises(RuntimeError) as error: - feature_group.ingest(data_frame=df, max_workers=1, max_processes=0) - - assert "max_processes must be greater than 0." in str(error) - - -def test_ingest_zero_workers(): - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = Mock() - with pytest.raises(RuntimeError) as error: - feature_group.ingest(data_frame=df, max_workers=0, max_processes=1) - - assert "max_workers must be greater than 0." 
in str(error) - - -@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") -def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock): - sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( - fs_runtime_client_config_mock - ) - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) - - mock_ingestion_manager_instance = Mock() - ingestion_manager_init.return_value = mock_ingestion_manager_instance - feature_group.ingest(data_frame=df, max_workers=10) - - ingestion_manager_init.assert_called_once_with( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=10, - max_processes=1, - profile_name=None, - ) - mock_ingestion_manager_instance.run.assert_called_once_with( - data_frame=df, wait=True, timeout=None - ) - - -@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") -def test_ingest_with_profile_name( - ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock -): - sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( - fs_runtime_client_config_mock - ) - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) - - mock_ingestion_manager_instance = Mock() - ingestion_manager_init.return_value = mock_ingestion_manager_instance - feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name") - - ingestion_manager_init.assert_called_once_with( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=10, - max_processes=1, - profile_name="profile_name", - ) - mock_ingestion_manager_instance.run.assert_called_once_with( - data_frame=df, wait=True, timeout=None - ) - - -def test_as_hive_ddl_with_default_values( - create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock -): - sagemaker_session_mock.describe_feature_group.return_value = { - "OfflineStoreConfig": { - "S3StorageConfig": { - "S3Uri": "s3://some-bucket", - "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", - } - } - } - sagemaker_session_mock.account_id.return_value = "1234" - sagemaker_session_mock.boto_session.region_name = "us-west-2" - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - assert ( - create_table_ddl.format( - database="sagemaker_featurestore", - table_name="MyGroup", - account="1234", - region="us-west-2", - feature_group_name="MyGroup", + feature_store.create_dataset( + base=DATAFRAME, + included_feature_names=["feature_1", "feature_2"], + output_path="file/to/path", + kms_key_id="kms-key-id", ) - == feature_group.as_hive_ddl() - ) - - -def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock): - sagemaker_session_mock.describe_feature_group.return_value = { - "OfflineStoreConfig": { - "S3StorageConfig": { - "S3Uri": "s3://some-bucket", - "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", - } - } - } - sagemaker_session_mock.account_id.return_value = "1234" - sagemaker_session_mock.boto_session.region_name = "us-west-2" - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - 
feature_group.feature_definitions = feature_group_dummy_definitions - assert create_table_ddl.format( - database="MyDatabase", - table_name="MyTable", - account="1234", - region="us-west-2", - feature_group_name="MyGroup", - ) == feature_group.as_hive_ddl(database="MyDatabase", table_name="MyTable") - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_process", - MagicMock(), -) -def test_ingestion_manager_run_success(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=10, - ) - manager.run(df) - - manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None) - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_threaded", - PicklableMock(return_value=[]), -) -def test_ingestion_manager_run_multi_process_with_multi_thread_success( - fs_runtime_client_config_mock, -): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=2, - max_processes=2, - ) - manager.run(df) - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", - MagicMock(return_value=[1]), -) -def test_ingestion_manager_run_failure(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=1, - ) - - with pytest.raises(IngestionError) as error: - manager.run(df) - - assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) - assert error.value.failed_rows == [1] - assert manager.failed_rows == [1] - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", - MagicMock(side_effect=ProfileNotFound(profile="non_exist")), -) -def test_ingestion_manager_with_profile_name_run_failure(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=1, - profile_name="non_exist", - ) - - try: - manager.run(df) - except Exception as e: - assert "The config profile (non_exist) could not be found" in str(e) - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", - PicklableMock(return_value=[1]), -) -def test_ingestion_manager_run_multi_process_failure(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=None, - max_workers=2, - max_processes=2, - ) - - with pytest.raises(IngestionError) as error: - manager.run(df) - - assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) - assert error.value.failed_rows == [1, 1, 1, 1] - assert manager.failed_rows == [1, 1, 1, 1] - - -@pytest.fixture -def query(sagemaker_session_mock): - return AthenaQuery( - catalog="catalog", - database="database", - table_name="table_name", - sagemaker_session=sagemaker_session_mock, - ) - - -def test_athena_query_run(sagemaker_session_mock, query): - WORKGROUP = "workgroup" - sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": 
"query_id"} - query.run( - query_string="query", output_location="s3://some-bucket/some-path", workgroup=WORKGROUP - ) - sagemaker_session_mock.start_query_execution.assert_called_with( - catalog="catalog", - database="database", - query_string="query", - output_location="s3://some-bucket/some-path", - kms_key=None, - workgroup=WORKGROUP, - ) - assert "some-bucket" == query._result_bucket - assert "some-path" == query._result_file_prefix - assert "query_id" == query._current_query_execution_id - - -def test_athena_query_wait(sagemaker_session_mock, query): - query._current_query_execution_id = "query_id" - query.wait() - sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id") - - -def test_athena_query_get_query_execution(sagemaker_session_mock, query): - query._current_query_execution_id = "query_id" - query.get_query_execution() - sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id") - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -@patch("pandas.read_csv") -def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "SUCCEEDED"}} - } - query._current_query_execution_id = "query_id" - query._result_bucket = "bucket" - query._result_file_prefix = "prefix" - query.as_dataframe() - sagemaker_session_mock.download_athena_query_result.assert_called_with( - bucket="bucket", - prefix="prefix", - query_execution_id="query_id", - filename="tmp/query_id.csv", + assert ( + "You must provide a record identifier feature name and an event time identifier feature " + + "name if specify DataFrame as base." + in str(error) + ) + + +def test_list_feature_groups_with_no_filter(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + feature_store.list_feature_groups() + sagemaker_session_mock.list_feature_groups.assert_called_with( + name_contains=None, + feature_group_status_equals=None, + offline_store_status_equals=None, + creation_time_after=None, + creation_time_before=None, + sort_order=None, + sort_by=None, + max_results=None, + next_token=None, + ) + + +def test_list_feature_groups_with_all_filters(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + feature_store.list_feature_groups( + name_contains="MyFeatureGroup", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after=datetime.datetime(2020, 12, 1), + creation_time_before=datetime.datetime(2022, 7, 1), + sort_order="Ascending", + sort_by="Name", + max_results=50, + next_token="token", + ) + sagemaker_session_mock.list_feature_groups.assert_called_with( + name_contains="MyFeatureGroup", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after=datetime.datetime(2020, 12, 1), + creation_time_before=datetime.datetime(2022, 7, 1), + sort_order="Ascending", + sort_by="Name", + max_results=50, + next_token="token", ) - read_csv.assert_called_with("tmp/query_id.csv", delimiter=",") - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_failed(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "FAILED"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert 
"Failed to execute query query_id" in str(error) - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_queued(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "QUEUED"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert "Current query query_id is still being executed" in str(error) - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_running(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "RUNNING"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert "Current query query_id is still being executed" in str(error) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index bf81283177..d7c94470f5 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2787,6 +2787,35 @@ def test_feature_metadata_describe(sagemaker_session): ) +def test_list_feature_groups(sagemaker_session): + expected_list_feature_groups_args = { + "NameContains": "MyFeatureGroup", + "FeatureGroupStatusEquals": "Created", + "OfflineStoreStatusEquals": "Active", + "CreationTimeAfter": datetime.datetime(2020, 12, 1), + "CreationTimeBefore": datetime.datetime(2022, 7, 1), + "SortOrder": "Ascending", + "SortBy": "Name", + "MaxResults": 50, + "NextToken": "token", + } + sagemaker_session.list_feature_groups( + name_contains="MyFeatureGroup", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after=datetime.datetime(2020, 12, 1), + creation_time_before=datetime.datetime(2022, 7, 1), + sort_order="Ascending", + sort_by="Name", + max_results=50, + next_token="token", + ) + assert sagemaker_session.sagemaker_client.list_feature_groups.called_once() + assert sagemaker_session.sagemaker_client.list_feature_groups.called_with( + **expected_list_feature_groups_args + ) + + def test_start_query_execution(sagemaker_session): athena_mock = Mock() sagemaker_session.boto_session.client( From fb3880f804854d8456682c4aa17de321cb5a89f9 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 14 Dec 2022 03:40:14 +0000 Subject: [PATCH 055/526] prepare release v2.122.0 --- CHANGELOG.md | 13 +++++++++++++ VERSION | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b66e85f54..de20a8a0df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## v2.122.0 (2022-12-14) + +### Features + + * Feature Store dataset builder, delete_record, get_record, list_feature_group + * Add OSU region to frameworks for DLC + +### Bug Fixes and Other Changes + + * the Hyperband support fix for the HPO + * unpin packaging version + * Remove content type image/jpg from analysis configuration schema + ## v2.121.2 (2022-12-12) ### Bug Fixes and Other Changes diff --git a/VERSION b/VERSION index 8fde5e282f..202f672bab 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.3.dev0 +2.122.0 From a584ea5ff73ea5b6df8eec749069ec86adf2e8fc Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 14 Dec 2022 03:40:15 +0000 Subject: [PATCH 056/526] update development version to v2.122.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 202f672bab..6d7f044fa2 100644 --- a/VERSION 
+++ b/VERSION @@ -1 +1 @@ -2.122.0 +2.122.1.dev0 From 93a846670f57f444b590551b2d67a3c6a95302aa Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Wed, 14 Dec 2022 09:09:46 -0800 Subject: [PATCH 057/526] feature: Add SageMaker Experiment (#3536) * feature: Add experiment plus Run class (#691) * feature: Add Experiment helper classes (#646) * feature: Add Experiment helper classes feature: Add helper class _RunEnvironment * change: Change sleep retry to backoff retry for get TC * minor fixes in backoff retry Co-authored-by: Dewen Qi * feature: Add helper classes and methods for Run class (#660) * feature: Add helper classes and methods for Run class * Add Parent class to address comment * fix docstyle check * Add arg docstrings in _helper Co-authored-by: Dewen Qi * feature: Add Experiment Run class (#651) Co-authored-by: Dewen Qi * change: Add integ tests for Run (#673) Co-authored-by: Dewen Qi * Update run log metric to use MetricsManager (#678) * Update run.log_metric to use _MetricsManager * fix several metrics issues * Add doc strings to metrics.py Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: Dewen Qi Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> * change: Simplify exp plus integ test configuration (#694) Co-authored-by: Dewen Qi * feature: add RunName to expeirment_config (#696) * change: Update Run init and add Run load and _RunContext (#707) * change: Update Run init and add Run load Add exp name and run group name to load and address comments * Address nit comments Co-authored-by: Dewen Qi * fix: Fix run name uniqueness issue (#730) Co-authored-by: Dewen Qi * change: Update integ tests for Exp Plus M1 changes (#741) Co-authored-by: Dewen Qi * add metrics client to session object (#745) Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> * change: Add integ test for using Run in Transform Job (#749) Co-authored-by: Dewen Qi * Add async metrics sink (#739) Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> * use metrics client provided by session (#754) * fix flaky metrics test (#753) * change: Change Run.init and Run.load to constructor and module method respectively (#752) Co-authored-by: Dewen Qi * feature: Add latest metric service model (#757) Co-authored-by: Dewen Qi Co-authored-by: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> * fix: lowercase run name (#767) * Change: Minimize use of lower case tc name (#769) * change: Clean up test resources to remove model files (#756) * change: Clean up test resources to remove model files * fix: Change experiment enums to upper case * change: Upgrade boto3 and update test to validate mixed case name * fix: Update as per latest botocore release and backend change Co-authored-by: Dewen Qi * lowercase trial component name (#776) * change: Expose sagemaker experiment doc strings * fix: Fix exp name mixed case in issue Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: Yifei Zhu <66866419+yzhu0@users.noreply.github.com> --- .gitignore | 5 +- 
doc/experiments/index.rst | 10 + doc/experiments/sagemaker.experiments.rst | 20 + doc/index.rst | 10 + requirements/extras/test_requirements.txt | 1 + setup.py | 2 +- src/sagemaker/amazon/amazon_estimator.py | 7 +- src/sagemaker/apiutils/_base_types.py | 6 +- src/sagemaker/apiutils/_boto_functions.py | 4 +- src/sagemaker/dataset_definition/inputs.py | 6 +- src/sagemaker/estimator.py | 16 +- src/sagemaker/experiments/__init__.py | 20 + src/sagemaker/experiments/_api_types.py | 251 +++++ src/sagemaker/experiments/_environment.py | 132 +++ src/sagemaker/experiments/_helper.py | 266 +++++ src/sagemaker/experiments/_metrics.py | 413 ++++++++ src/sagemaker/experiments/_run_context.py | 58 ++ src/sagemaker/experiments/_utils.py | 218 ++++ src/sagemaker/experiments/experiment.py | 237 +++++ src/sagemaker/experiments/run.py | 882 ++++++++++++++++ src/sagemaker/experiments/trial.py | 289 ++++++ src/sagemaker/experiments/trial_component.py | 341 +++++++ src/sagemaker/lineage/_utils.py | 17 - src/sagemaker/lineage/artifact.py | 3 +- src/sagemaker/processing.py | 9 +- src/sagemaker/session.py | 23 +- src/sagemaker/transformer.py | 7 +- src/sagemaker/utilities/search_expression.py | 133 +++ src/sagemaker/utils.py | 66 ++ tests/data/experiment/inference.py | 85 ++ .../process_job_script_for_run_clz.py | 37 + .../train_job_script_for_run_clz.py | 71 ++ .../transform_job_materials/data.csv | 1 + .../transform_job_materials/xgb_model.tar.gz | Bin 0 -> 35946 bytes tests/integ/sagemaker/experiments/__init__.py | 0 tests/integ/sagemaker/experiments/conftest.py | 177 ++++ tests/integ/sagemaker/experiments/helpers.py | 42 + .../sagemaker/experiments/test_experiment.py | 56 ++ .../sagemaker/experiments/test_metrics.py | 39 + tests/integ/sagemaker/experiments/test_run.py | 662 ++++++++++++ .../integ/sagemaker/experiments/test_trial.py | 75 ++ .../experiments/test_trial_component.py | 144 +++ tests/integ/sagemaker/lineage/conftest.py | 5 +- tests/integ/sagemaker/lineage/helpers.py | 14 - .../integ/sagemaker/lineage/test_artifact.py | 4 +- tests/integ/sagemaker/utilities/__init__.py | 0 .../utilities/test_search_expression.py | 67 ++ tests/integ/test_marketplace.py | 4 +- tests/integ/test_multidatamodel.py | 21 +- tests/integ/utils.py | 20 + tests/unit/conftest.py | 66 ++ tests/unit/sagemaker/experiments/__init__.py | 0 tests/unit/sagemaker/experiments/conftest.py | 86 ++ tests/unit/sagemaker/experiments/helpers.py | 44 + .../sagemaker/experiments/test_environment.py | 107 ++ .../sagemaker/experiments/test_experiment.py | 306 ++++++ .../unit/sagemaker/experiments/test_helper.py | 195 ++++ .../sagemaker/experiments/test_metrics.py | 178 ++++ tests/unit/sagemaker/experiments/test_run.py | 941 ++++++++++++++++++ .../sagemaker/experiments/test_run_context.py | 191 ++++ .../unit/sagemaker/experiments/test_trial.py | 276 +++++ .../experiments/test_trial_component.py | 384 +++++++ .../unit/sagemaker/experiments/test_utils.py | 36 + .../sagemaker/huggingface/test_estimator.py | 1 + .../sagemaker/tensorflow/test_estimator.py | 1 + .../test_huggingface_pytorch_compiler.py | 1 + .../test_huggingface_tensorflow_compiler.py | 1 + .../test_tensorflow_compiler.py | 1 + .../utilities/test_search_expression.py | 80 ++ .../workflow/test_clarify_check_step.py | 44 - .../unit/sagemaker/workflow/test_entities.py | 43 - .../workflow/test_quality_check_step.py | 46 - tests/unit/sagemaker/workflow/test_steps.py | 47 +- tests/unit/test_amazon_estimator.py | 13 +- tests/unit/test_estimator.py | 9 +- tests/unit/test_mxnet.py | 1 + 
tests/unit/test_pytorch.py | 1 + tests/unit/test_rl.py | 1 + tests/unit/test_session.py | 15 + tests/unit/test_sklearn.py | 1 + tests/unit/test_utils.py | 64 +- tests/unit/test_xgboost.py | 1 + 82 files changed, 7894 insertions(+), 263 deletions(-) create mode 100644 doc/experiments/index.rst create mode 100644 doc/experiments/sagemaker.experiments.rst create mode 100644 src/sagemaker/experiments/__init__.py create mode 100644 src/sagemaker/experiments/_api_types.py create mode 100644 src/sagemaker/experiments/_environment.py create mode 100644 src/sagemaker/experiments/_helper.py create mode 100644 src/sagemaker/experiments/_metrics.py create mode 100644 src/sagemaker/experiments/_run_context.py create mode 100644 src/sagemaker/experiments/_utils.py create mode 100644 src/sagemaker/experiments/experiment.py create mode 100644 src/sagemaker/experiments/run.py create mode 100644 src/sagemaker/experiments/trial.py create mode 100644 src/sagemaker/experiments/trial_component.py create mode 100644 src/sagemaker/utilities/search_expression.py create mode 100644 tests/data/experiment/inference.py create mode 100644 tests/data/experiment/process_job_script_for_run_clz.py create mode 100644 tests/data/experiment/train_job_script_for_run_clz.py create mode 100644 tests/data/experiment/transform_job_materials/data.csv create mode 100644 tests/data/experiment/transform_job_materials/xgb_model.tar.gz create mode 100644 tests/integ/sagemaker/experiments/__init__.py create mode 100644 tests/integ/sagemaker/experiments/conftest.py create mode 100644 tests/integ/sagemaker/experiments/helpers.py create mode 100644 tests/integ/sagemaker/experiments/test_experiment.py create mode 100644 tests/integ/sagemaker/experiments/test_metrics.py create mode 100644 tests/integ/sagemaker/experiments/test_run.py create mode 100644 tests/integ/sagemaker/experiments/test_trial.py create mode 100644 tests/integ/sagemaker/experiments/test_trial_component.py create mode 100644 tests/integ/sagemaker/utilities/__init__.py create mode 100644 tests/integ/sagemaker/utilities/test_search_expression.py create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/sagemaker/experiments/__init__.py create mode 100644 tests/unit/sagemaker/experiments/conftest.py create mode 100644 tests/unit/sagemaker/experiments/helpers.py create mode 100644 tests/unit/sagemaker/experiments/test_environment.py create mode 100644 tests/unit/sagemaker/experiments/test_experiment.py create mode 100644 tests/unit/sagemaker/experiments/test_helper.py create mode 100644 tests/unit/sagemaker/experiments/test_metrics.py create mode 100644 tests/unit/sagemaker/experiments/test_run.py create mode 100644 tests/unit/sagemaker/experiments/test_run_context.py create mode 100644 tests/unit/sagemaker/experiments/test_trial.py create mode 100644 tests/unit/sagemaker/experiments/test_trial_component.py create mode 100644 tests/unit/sagemaker/experiments/test_utils.py create mode 100644 tests/unit/sagemaker/utilities/test_search_expression.py diff --git a/.gitignore b/.gitignore index 9829ed9781..cae8f890ea 100644 --- a/.gitignore +++ b/.gitignore @@ -30,5 +30,6 @@ env/ .vscode/ **/tmp .python-version -**/_repack_model.py -**/_repack_script_launcher.sh \ No newline at end of file +**/_repack_script_launcher.sh +tests/data/**/_repack_model.py +tests/data/experiment/sagemaker-dev-1.0.tar.gz diff --git a/doc/experiments/index.rst b/doc/experiments/index.rst new file mode 100644 index 0000000000..8c12f30edc --- /dev/null +++ b/doc/experiments/index.rst @@ -0,0 
+1,10 @@ +############################ +Amazon SageMaker Experiments +############################ + +The SageMaker Python SDK supports to track and organize your machine learning workflow across SageMaker with jobs, such as Processing, Training and Transform, or locally. + +.. toctree:: + :maxdepth: 2 + + sagemaker.experiments diff --git a/doc/experiments/sagemaker.experiments.rst b/doc/experiments/sagemaker.experiments.rst new file mode 100644 index 0000000000..f0776ec43b --- /dev/null +++ b/doc/experiments/sagemaker.experiments.rst @@ -0,0 +1,20 @@ +Experiments +============ + +Run +------------- + +.. autoclass:: sagemaker.experiments.Run + :members: + +.. automethod:: sagemaker.experiments.load_run + +.. automethod:: sagemaker.experiments.list_runs + +.. autoclass:: sagemaker.experiments.SortByType + :members: + :undoc-members: + +.. autoclass:: sagemaker.experiments.SortOrderType + :members: + :undoc-members: diff --git a/doc/index.rst b/doc/index.rst index 2d4ebe32c1..69038056b0 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -60,6 +60,16 @@ Orchestrate your SageMaker training and inference workflows with Airflow and Kub workflows/index +**************************** +Amazon SageMaker Experiments +**************************** +You can use Amazon SageMaker Experiments to track machine learning experiments. + +.. toctree:: + :maxdepth: 2 + + experiments/index + ************************* Amazon SageMaker Debugger ************************* diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index fe93fd4d0e..494b6dca11 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -20,3 +20,4 @@ requests==2.27.1 sagemaker-experiments==0.1.35 Jinja2==3.0.3 pandas>=1.3.5,<1.5 +scikit-learn==1.0.2 diff --git a/setup.py b/setup.py index 4327045760..e2adb6b433 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ def read_requirements(filename): # Declare minimal set for installation required_packages = [ "attrs>=20.3.0,<23", - "boto3>=1.26.20,<2.0", + "boto3>=1.26.28,<2.0", "google-pasta", "numpy>=1.9.0,<2.0", "protobuf>=3.1,<4.0", diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index b156f2e65f..1abea5e48c 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -27,7 +27,7 @@ from sagemaker.deprecations import renamed_warning from sagemaker.estimator import EstimatorBase, _TrainingJob from sagemaker.inputs import FileSystemInput, TrainingInput -from sagemaker.utils import sagemaker_timestamp +from sagemaker.utils import sagemaker_timestamp, check_and_get_run_experiment_config from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline from sagemaker.workflow import is_pipeline_variable @@ -242,8 +242,8 @@ def fit( generates a default job name, based on the training image name and current timestamp. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. 
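For context, this is roughly what the expanded `experiment_config` looks like from the caller's side. The sketch below is not part of the patch: the image URI, role, S3 path and experiment/trial/run names are placeholders, and only the shape of the dict (the three existing keys plus the new `RunName`) follows the docstring change above.

from sagemaker.estimator import Estimator

estimator = Estimator(
    image_uri="<training-image-uri>",   # placeholder
    role="<execution-role-arn>",        # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

estimator.fit(
    inputs="s3://<bucket>/train",       # placeholder
    experiment_config={
        "ExperimentName": "my-experiment",
        "TrialName": "my-trial",
        "TrialComponentDisplayName": "Training",
        "RunName": "my-run",            # new key introduced by this change
    },
)
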
@@ -255,6 +255,7 @@ def fit( """ self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_training_job = _TrainingJob.start_new( self, records, experiment_config=experiment_config ) diff --git a/src/sagemaker/apiutils/_base_types.py b/src/sagemaker/apiutils/_base_types.py index e920797b18..9a7359e12b 100644 --- a/src/sagemaker/apiutils/_base_types.py +++ b/src/sagemaker/apiutils/_base_types.py @@ -173,8 +173,10 @@ def _search( search_items = search_method_response.get("Results", []) next_token = search_method_response.get(boto_next_token_name) for item in search_items: - if cls.__name__ in item: - yield search_item_factory(item[cls.__name__]) + # _TrialComponent class in experiments module is not public currently + class_name = cls.__name__.lstrip("_") + if class_name in item: + yield search_item_factory(item[class_name]) if not next_token: break except StopIteration: diff --git a/src/sagemaker/apiutils/_boto_functions.py b/src/sagemaker/apiutils/_boto_functions.py index 1e29f2ebea..a227d30ca8 100644 --- a/src/sagemaker/apiutils/_boto_functions.py +++ b/src/sagemaker/apiutils/_boto_functions.py @@ -68,7 +68,9 @@ def from_boto(boto_dict, boto_name_to_member_name, member_name_to_type): api_type, is_collection = member_name_to_type[member_name] if is_collection: if isinstance(boto_value, dict): - member_value = api_type.from_boto(boto_value) + member_value = { + key: api_type.from_boto(value) for key, value in boto_value.items() + } else: member_value = [api_type.from_boto(item) for item in boto_value] else: diff --git a/src/sagemaker/dataset_definition/inputs.py b/src/sagemaker/dataset_definition/inputs.py index 90a272c4d7..468be22ac3 100644 --- a/src/sagemaker/dataset_definition/inputs.py +++ b/src/sagemaker/dataset_definition/inputs.py @@ -124,8 +124,10 @@ class DatasetDefinition(ApiObject): """DatasetDefinition input.""" _custom_boto_types = { - "redshift_dataset_definition": (RedshiftDatasetDefinition, True), - "athena_dataset_definition": (AthenaDatasetDefinition, True), + # RedshiftDatasetDefinition and AthenaDatasetDefinition are not collection + # Instead they are singleton objects. Thus, set the is_collection flag to False. + "redshift_dataset_definition": (RedshiftDatasetDefinition, False), + "athena_dataset_definition": (AthenaDatasetDefinition, False), } def __init__( diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 6f729267de..e3b06950aa 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -79,6 +79,7 @@ get_config_value, name_from_base, to_string, + check_and_get_run_experiment_config, ) from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable @@ -1103,8 +1104,8 @@ def fit( job_name (str): Training job name. If not specified, the estimator generates a default job name based on the training image name and current timestamp. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. 
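The same four-key description lands in `EstimatorBase.fit` here, and the hunk that follows routes `experiment_config` through `check_and_get_run_experiment_config`, so a job can pick up an active experiment run when the caller passes no config at all. A minimal sketch of that flow, assuming the `Run` context manager exported by the new `sagemaker.experiments` package later in this patch, and its `log_parameter`/`log_metric` helpers, behave as the surrounding plumbing suggests; all names are placeholders:

from sagemaker.experiments import Run

# Reuses the `estimator` from the previous sketch; note that fit() receives no
# explicit experiment_config this time.
with Run(experiment_name="my-experiment", run_name="my-run") as run:
    run.log_parameter("learning_rate", 0.1)                  # assumed Run helper
    estimator.fit(inputs="s3://<bucket>/train")              # run resolved via check_and_get_run_experiment_config
    run.log_metric(name="validation:accuracy", value=0.93)   # assumed Run helper
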
@@ -1122,6 +1123,7 @@ def fit( """ self._prepare_for_training(job_name=job_name) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config) self.jobs.append(self.latest_training_job) if wait: @@ -2023,8 +2025,8 @@ def start_new(cls, estimator, inputs, experiment_config): inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -2033,6 +2035,7 @@ def start_new(cls, estimator, inputs, experiment_config): * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. Returns: sagemaker.estimator._TrainingJob: Constructed object that captures all information about the started training job. @@ -2053,8 +2056,8 @@ def _get_train_args(cls, estimator, inputs, experiment_config): inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -2063,6 +2066,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config): * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. Returns: Dict: dict for `sagemaker.session.Session.train` method diff --git a/src/sagemaker/experiments/__init__.py b/src/sagemaker/experiments/__init__.py new file mode 100644 index 0000000000..b87656b1ab --- /dev/null +++ b/src/sagemaker/experiments/__init__.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Sagemaker Experiment Module""" +from __future__ import absolute_import + +from sagemaker.experiments.run import Run # noqa: F401 +from sagemaker.experiments.run import load_run # noqa: F401 +from sagemaker.experiments.run import list_runs # noqa: F401 +from sagemaker.experiments.run import SortOrderType # noqa: F401 +from sagemaker.experiments.run import SortByType # noqa: F401 diff --git a/src/sagemaker/experiments/_api_types.py b/src/sagemaker/experiments/_api_types.py new file mode 100644 index 0000000000..78f82565aa --- /dev/null +++ b/src/sagemaker/experiments/_api_types.py @@ -0,0 +1,251 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains API objects for SageMaker experiments.""" +from __future__ import absolute_import + +import enum +import numbers + +from sagemaker.apiutils import _base_types + + +class TrialComponentMetricSummary(_base_types.ApiObject): + """Summary model of a trial component. + + Attributes: + metric_name (str): The name of the metric. + source_arn (str): The ARN of the source. + time_stamp (datetime): Metric last updated value. + max (float): The max value of the metric. + min (float): The min value of the metric. + last (float): The last value of the metric. + count (float): The number of samples used to generate the metric. + avg (float): The average value of the metric. + std_dev (float): The standard deviation of the metric. + """ + + metric_name = None + source_arn = None + time_stamp = None + max = None + min = None + last = None + count = None + avg = None + std_dev = None + + def __init__(self, metric_name=None, source_arn=None, **kwargs): + super(TrialComponentMetricSummary, self).__init__( + metric_name=metric_name, source_arn=source_arn, **kwargs + ) + + +class TrialComponentParameters(_base_types.ApiObject): + """A dictionary of TrialComponentParameterValues""" + + @classmethod + def from_boto(cls, boto_dict, **kwargs): + """Converts a boto dict to a dictionary of TrialComponentParameterValues + + Args: + boto_dict (dict): boto response dictionary. + **kwargs: Arbitrary keyword arguments. + + Returns: + dict: Dictionary of parameter values. + """ + return_map = {} + for key, value in boto_dict.items(): + return_map[key] = value.get("NumberValue", value.get("StringValue", None)) + return return_map + + @classmethod + def to_boto(cls, parameters): + """Converts TrialComponentParameters to dict. + + Args: + parameters (TrialComponentParameters): Dictionary to convert. + + Returns: + dict: Dictionary of trial component parameters in boto format. + """ + boto_map = {} + for key, value in parameters.items(): + if isinstance(value, numbers.Number): + boto_map[key] = {"NumberValue": value} + else: + boto_map[key] = {"StringValue": str(value)} + return boto_map + + +class TrialComponentArtifact(_base_types.ApiObject): + """Trial component artifact. + + Attributes: + value (str): The artifact value. + media_type (str): The media type. 
+ """ + + value = None + media_type = None + + def __init__(self, value=None, media_type=None, **kwargs): + super(TrialComponentArtifact, self).__init__(value=value, media_type=media_type, **kwargs) + + +class _TrialComponentStatusType(enum.Enum): + """The type of trial component status""" + + InProgress = "InProgress" + Completed = "Completed" + Failed = "Failed" + + +class TrialComponentStatus(_base_types.ApiObject): + """Status of the trial component. + + Attributes: + primary_status (str): The status of a trial component. + message (str): Status message. + """ + + primary_status = None + message = None + + def __init__(self, primary_status=None, message=None, **kwargs): + super(TrialComponentStatus, self).__init__( + primary_status=primary_status, message=message, **kwargs + ) + + +class TrialComponentSummary(_base_types.ApiObject): + """Summary model of a trial component. + + Attributes: + trial_component_name (str): Name of trial component. + trial_component_arn (str): ARN of the trial component. + display_name (str): Friendly display name in UI. + source_arn (str): ARN of the trial component source. + status (str): Status. + start_time (datetime): Start time. + end_time (datetime): End time. + creation_time (datetime): Creation time. + created_by (str): Created by. + last_modified_time (datetime): Date last modified. + last_modified_by (datetime): User last modified. + """ + + _custom_boto_types = { + "status": (TrialComponentStatus, False), + } + trial_component_name = None + trial_component_arn = None + display_name = None + source_arn = None + status = None + start_time = None + end_time = None + creation_time = None + created_by = None + last_modified_time = None + last_modified_by = None + + +class TrialComponentSource(_base_types.ApiObject): + """Trial Component Source + + Attributes: + source_arn (str): The ARN of the source. + """ + + source_arn = None + + def __init__(self, source_arn=None, **kwargs): + super(TrialComponentSource, self).__init__(source_arn=source_arn, **kwargs) + + +class Parent(_base_types.ApiObject): + """The trial/experiment/run that a trial component is associated with. + + Attributes: + trial_name (str): Name of the trial. + experiment_name (str): Name of the experiment. + run_name (str): Name of the run. + """ + + trial_name = None + experiment_name = None + run_name = None + + +class TrialComponentSearchResult(_base_types.ApiObject): + """Summary model of an Trial Component search result. + + Attributes: + trial_component_arn (str): ARN of the trial component. + trial_component_name (str): Name of the trial component. + display_name (str): Display name of the trial component for UI display. + source (dict): The source of the trial component. + status (dict): The status of the trial component. + start_time (datetime): Start time. + end_time (datetime): End time. + creation_time (datetime): Creation time. + created_by (str): Created by. + last_modified_time (datetime): Date last modified. + last_modified_by (datetime): User last modified. + parameters (dict): The hyperparameters of the component. + input_artifacts (dict): The input artifacts of the component. + output_artifacts (dict): The output artifacts of the component. + metrics (list): The metrics for the component. + source_detail (dict): The source of the trial component. + tags (list): The list of tags that are associated with the trial component. + parents (list[Parent]): The parent of trial component. 
+ """ + + _custom_boto_types = { + "parents": (Parent, True), # parents is a collection (list) of Parent objects + } + trial_component_arn = None + trial_component_name = None + display_name = None + source = None + status = None + start_time = None + end_time = None + creation_time = None + created_by = None + last_modified_time = None + last_modified_by = None + parameters = None + input_artifacts = None + output_artifacts = None + metrics = None + source_detail = None + tags = None + parents = None + + +class TrialSummary(_base_types.ApiObject): + """Summary model of a trial. + + Attributes: + trial_arn (str): The ARN of the trial. + trial_name (str): The name of the trial. + creation_time (datetime): When the trial was created. + last_modified_time (datetime): When the trial was last modified. + """ + + trial_arn = None + trial_name = None + creation_time = None + last_modified_time = None diff --git a/src/sagemaker/experiments/_environment.py b/src/sagemaker/experiments/_environment.py new file mode 100644 index 0000000000..441661ae5a --- /dev/null +++ b/src/sagemaker/experiments/_environment.py @@ -0,0 +1,132 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the _RunEnvironment class.""" +from __future__ import absolute_import + +import enum +import json +import logging +import os + +from sagemaker.experiments import trial_component +from sagemaker.utils import retry_with_backoff + +TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN" +PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json" +TRANSFORM_JOB_ENV_BATCH_VAR = "SAGEMAKER_BATCH" +MAX_RETRY_ATTEMPTS = 7 + +logger = logging.getLogger(__name__) + + +class _EnvironmentType(enum.Enum): + """SageMaker jobs which data can be pulled from the environment.""" + + SageMakerTrainingJob = 1 + SageMakerProcessingJob = 2 + SageMakerTransformJob = 3 + + +class _RunEnvironment(object): + """Retrieves job specific data from the environment.""" + + def __init__(self, environment_type, source_arn): + """Init for _RunEnvironment. + + Args: + environment_type (_EnvironmentType): The environment type. + source_arn (str): The ARN of the current job. + """ + self.environment_type = environment_type + self.source_arn = source_arn + + @classmethod + def load( + cls, + training_job_arn_env=TRAINING_JOB_ARN_ENV, + processing_job_config_path=PROCESSING_JOB_CONFIG_PATH, + transform_job_batch_var=TRANSFORM_JOB_ENV_BATCH_VAR, + ): + """Loads source arn of current job from environment. + + Args: + training_job_arn_env (str): The environment key for training job ARN + (default: `TRAINING_JOB_ARN`). + processing_job_config_path (str): The processing job config path + (default: `/opt/ml/config/processingjobconfig.json`). + transform_job_batch_var (str): The environment variable indicating if + it is a transform job (default: `SAGEMAKER_BATCH`). + + Returns: + _RunEnvironment: Job data loaded from the environment. None if config does not exist. 
+ """ + if training_job_arn_env in os.environ: + environment_type = _EnvironmentType.SageMakerTrainingJob + source_arn = os.environ.get(training_job_arn_env) + return _RunEnvironment(environment_type, source_arn) + if os.path.exists(processing_job_config_path): + environment_type = _EnvironmentType.SageMakerProcessingJob + source_arn = json.loads(open(processing_job_config_path).read())["ProcessingJobArn"] + return _RunEnvironment(environment_type, source_arn) + if transform_job_batch_var in os.environ and os.environ[transform_job_batch_var] == "true": + environment_type = _EnvironmentType.SageMakerTransformJob + # TODO: need to figure out how to get source_arn from job env + # with Transform team's help. + source_arn = "" + return _RunEnvironment(environment_type, source_arn) + + return None + + def get_trial_component(self, sagemaker_session): + """Retrieves the trial component from the job in the environment. + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + _TrialComponent: The trial component created from the job. None if not found. + """ + # TODO: Remove this condition check once we have a way to retrieve source ARN + # from transform job env + if self.environment_type == _EnvironmentType.SageMakerTransformJob: + logger.error( + "Currently getting the job trial component from the transform job environment " + "is not supported. Returning None." + ) + return None + + def _get_trial_component(): + summaries = list( + trial_component._TrialComponent.list( + source_arn=self.source_arn.lower(), sagemaker_session=sagemaker_session + ) + ) + if summaries: + summary = summaries[0] + return trial_component._TrialComponent.load( + trial_component_name=summary.trial_component_name, + sagemaker_session=sagemaker_session, + ) + return None + + job_tc = None + try: + job_tc = retry_with_backoff(_get_trial_component, MAX_RETRY_ATTEMPTS) + except Exception as ex: # pylint: disable=broad-except + logger.error( + "Failed to get trail component in the current environment due to %s", str(ex) + ) + return job_tc diff --git a/src/sagemaker/experiments/_helper.py b/src/sagemaker/experiments/_helper.py new file mode 100644 index 0000000000..0c689b1125 --- /dev/null +++ b/src/sagemaker/experiments/_helper.py @@ -0,0 +1,266 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains the helper classes for SageMaker Experiment.""" +from __future__ import absolute_import + +import json +import logging +import os + +import botocore + +from sagemaker.experiments._utils import is_already_exist_error + +logger = logging.getLogger(__name__) + + +_DEFAULT_ARTIFACT_PREFIX = "trial-component-artifacts" +_DEFAULT_ARTIFACT_TYPE = "Tracker" + + +class _ArtifactUploader(object): + """Artifact uploader""" + + def __init__( + self, + trial_component_name, + sagemaker_session, + artifact_bucket=None, + artifact_prefix=_DEFAULT_ARTIFACT_PREFIX, + ): + """Initialize a `_ArtifactUploader` instance. + + Args: + trial_component_name (str): The name of the trial component, + which is used to generate the S3 path to upload the artifact to. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + artifact_bucket (str): The S3 bucket to upload the artifact to. + If not specified, the default bucket defined in `sagemaker_session` + will be used. + artifact_prefix (str): The S3 key prefix used to generate the S3 path + to upload the artifact to (default: "trial-component-artifacts"). + """ + self.sagemaker_session = sagemaker_session + self.trial_component_name = trial_component_name + self.artifact_bucket = artifact_bucket + self.artifact_prefix = artifact_prefix + self._s3_client = self.sagemaker_session.boto_session.client("s3") + + def upload_artifact(self, file_path): + """Upload an artifact file to S3. + + Args: + file_path (str): the file path of the artifact + + Returns: + (str, str): The s3 URI of the uploaded file and the etag of the file. + + Raises: + ValueError: If file does not exist. + """ + file_path = os.path.expanduser(file_path) + if not os.path.isfile(file_path): + raise ValueError( + "{} does not exist or is not a file. Please supply a file path.".format(file_path) + ) + if not self.artifact_bucket: + self.artifact_bucket = self.sagemaker_session.default_bucket() + artifact_name = os.path.basename(file_path) + artifact_s3_key = "{}/{}/{}".format( + self.artifact_prefix, self.trial_component_name, artifact_name + ) + self._s3_client.upload_file(file_path, self.artifact_bucket, artifact_s3_key) + etag = self._try_get_etag(artifact_s3_key) + return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag + + def upload_object_artifact(self, artifact_name, artifact_object, file_extension=None): + """Upload an artifact object to S3. + + Args: + artifact_name (str): the name of the artifact. + artifact_object (obj): the object of the artifact + file_extension (str): Optional file extension. + + Returns: + str: The s3 URI of the uploaded file and the version of the file. + """ + if not self.artifact_bucket: + self.artifact_bucket = self.sagemaker_session.default_bucket() + if file_extension: + artifact_name = ( + artifact_name + ("" if file_extension.startswith(".") else ".") + file_extension + ) + artifact_s3_key = "{}/{}/{}".format( + self.artifact_prefix, self.trial_component_name, artifact_name + ) + self._s3_client.put_object( + Body=json.dumps(artifact_object), Bucket=self.artifact_bucket, Key=artifact_s3_key + ) + etag = self._try_get_etag(artifact_s3_key) + return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag + + def _try_get_etag(self, key): + """Get ETag of given key and return None if not allowed + + Args: + key (str): The S3 object key. + + Returns: + str: The S3 object ETag if it allows, otherwise return None. 
+ """ + try: + response = self._s3_client.head_object(Bucket=self.artifact_bucket, Key=key) + return response["ETag"] + except botocore.exceptions.ClientError as error: + # requires read permissions + logger.warning("Failed to get ETag of %s due to %s", key, error) + return None + + +class _LineageArtifactManager(object): + """A helper class to manage Lineage Artifacts""" + + def __init__( + self, + name, + source_uri, + etag, + source_arn=None, + dest_arn=None, + artifact_type=_DEFAULT_ARTIFACT_TYPE, + ): + """Initialize a `_LineageArtifactManager` instance. + + Args: + name (str): The name of the Lineage artifact to be created. + source_uri (str): The source URI used to create the Lineage artifact. + etag (str): The S3 Etag used to create the Lineage artifact. + source_arn (str): The source ARN of a trail component to associate + this Lineage artifact with (default: None). + dest_arn (str): The destination ARN of a trial component to associate + this Lineage artifact with (default: None). + artifact_type (str): The type of the Lineage artifact (default: "Tracker"). + """ + self.name = name + self.source_uri = source_uri + self.etag = etag + self.source_arn = source_arn + self.dest_arn = dest_arn + self.artifact_arn = None + self.artifact_type = artifact_type + + def create_artifact(self, sagemaker_session): + """Create the artifact by calling `CreateArtifact` API + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + source_ids = [] + if self.etag: + source_ids.append({"SourceIdType": "S3ETag", "Value": self.etag}) + + try: + response = sagemaker_session.sagemaker_client.create_artifact( + ArtifactName=self.name, + ArtifactType=self.artifact_type, + Source={"SourceUri": self.source_uri, "SourceTypes": source_ids}, + ) + self.artifact_arn = response["ArtifactArn"] + except botocore.exceptions.ClientError as err: + err_info = err.response["Error"] + if not is_already_exist_error(err_info): + raise + logger.warning( + "Skip creating the artifact since it already exists: %s", err_info["Message"] + ) + + def add_association(self, sagemaker_session): + """Associate the artifact with a source/destination ARN (e.g. trial component arn) + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + source_arn = self.source_arn if self.source_arn else self.artifact_arn + dest_arn = self.dest_arn if self.dest_arn else self.artifact_arn + # if the trial component (job) is the source then it produced the artifact, + # otherwise the artifact contributed to the trial component (job) + association_edge_type = "Produced" if self.source_arn else "ContributedTo" + try: + sagemaker_session.sagemaker_client.add_association( + SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_edge_type + ) + except botocore.exceptions.ClientError as err: + err_info = err.response["Error"] + if not is_already_exist_error(err_info): + raise + logger.warning( + "Skip associating since the association already exists: %s", err_info["Message"] + ) + + +class _LineageArtifactTracker(object): + """Lineage Artifact Tracker""" + + def __init__(self, trial_component_arn, sagemaker_session): + """Initialize a `_LineageArtifactTracker` instance. + + Args: + trial_component_arn (str): The ARN of the trial component to be + associated with the input/output artifacts. 
+ sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + self.trial_component_arn = trial_component_arn + self.sagemaker_session = sagemaker_session + self.artifacts = [] + + def add_input_artifact(self, name, source_uri, etag, artifact_type): + """Add a Lineage input artifact locally + + Args: + name (str): The name of the Lineage input artifact to be added. + source_uri (str): The source URI used to create the Lineage input artifact. + etag (str): The S3 Etag used to create the Lineage input artifact. + artifact_type (str): The type of the Lineage input artifact. + """ + artifact = _LineageArtifactManager( + name, source_uri, etag, dest_arn=self.trial_component_arn, artifact_type=artifact_type + ) + self.artifacts.append(artifact) + + def add_output_artifact(self, name, source_uri, etag, artifact_type): + """Add a Lineage output artifact locally + + Args: + name (str): The name of the Lineage output artifact to be added. + source_uri (str): The source URI used to create the Lineage output artifact. + etag (str): The S3 Etag used to create the Lineage output artifact. + artifact_type (str): The type of the Lineage output artifact. + """ + artifact = _LineageArtifactManager( + name, source_uri, etag, source_arn=self.trial_component_arn, artifact_type=artifact_type + ) + self.artifacts.append(artifact) + + def save(self): + """Persist any artifact data saved locally""" + for artifact in self.artifacts: + artifact.create_artifact(self.sagemaker_session) + artifact.add_association(self.sagemaker_session) diff --git a/src/sagemaker/experiments/_metrics.py b/src/sagemaker/experiments/_metrics.py new file mode 100644 index 0000000000..f80c43f337 --- /dev/null +++ b/src/sagemaker/experiments/_metrics.py @@ -0,0 +1,413 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes to manage metrics for Sagemaker Experiment""" +from __future__ import absolute_import + +import datetime +import json +import logging +import os +import time +import threading +import queue + +import dateutil.tz + +from sagemaker.session import Session + +METRICS_DIR = os.environ.get("SAGEMAKER_METRICS_DIRECTORY", ".") +METRIC_TS_LOWER_BOUND_TO_NOW = 1209600 # on seconds +METRIC_TS_UPPER_BOUND_FROM_NOW = 7200 # on seconds + +BATCH_SIZE = 10 + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# TODO: remove this _SageMakerFileMetricsWriter class +# when _MetricsManager is fully ready +class _SageMakerFileMetricsWriter(object): + """Write metric data to file.""" + + def __init__(self, metrics_file_path=None): + """Construct a `_SageMakerFileMetricsWriter` object""" + self._metrics_file_path = metrics_file_path + self._file = None + self._closed = False + + def log_metric(self, metric_name, value, timestamp=None, step=None): + """Write a metric to file. + + Args: + metric_name (str): The name of the metric. + value (float): The value of the metric. 
+ timestamp (datetime.datetime): Timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): Iteration number of the metric (default: None). + + Raises: + SageMakerMetricsWriterException: If the metrics file is closed. + AttributeError: If file has been initialized and the writer hasn't been closed. + """ + raw_metric_data = _RawMetricData( + metric_name=metric_name, value=value, timestamp=timestamp, step=step + ) + try: + logger.debug("Writing metric: %s", raw_metric_data) + self._file.write(json.dumps(raw_metric_data.to_record())) + self._file.write("\n") + except AttributeError as attr_err: + if self._closed: + raise SageMakerMetricsWriterException("log_metric called on a closed writer") + if not self._file: + self._file = open(self._get_metrics_file_path(), "a", buffering=1) + self._file.write(json.dumps(raw_metric_data.to_record())) + self._file.write("\n") + else: + raise attr_err + + def close(self): + """Closes the metric file.""" + if not self._closed and self._file: + self._file.close() + self._file = None # invalidate reference, causing subsequent log_metric to fail. + self._closed = True + + def __enter__(self): + """Return self""" + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Execute self.close()""" + self.close() + + def __del__(self): + """Execute self.close()""" + self.close() + + def _get_metrics_file_path(self): + """Get file path to store metrics""" + pid_filename = "{}.json".format(str(os.getpid())) + metrics_file_path = self._metrics_file_path or os.path.join(METRICS_DIR, pid_filename) + logger.debug("metrics_file_path = %s", metrics_file_path) + return metrics_file_path + + +class SageMakerMetricsWriterException(Exception): + """SageMakerMetricsWriterException""" + + def __init__(self, message, errors=None): + """Construct a `SageMakerMetricsWriterException` instance""" + super().__init__(message) + if errors: + self.errors = errors + + +class _RawMetricData(object): + """A Raw Metric Data Object""" + + MetricName = None + Value = None + Timestamp = None + Step = None + + def __init__(self, metric_name, value, timestamp=None, step=None): + """Construct a `_RawMetricData` instance. + + Args: + metric_name (str): The name of the metric. + value (float): The value of the metric. + timestamp (datetime.datetime or float or str): Timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): Iteration number of the metric (default: None). + """ + if timestamp is None: + timestamp = time.time() + elif isinstance(timestamp, datetime.datetime): + # If the input is a datetime then convert it to UTC time. + # Assume a naive datetime is in local timezone + if not timestamp.tzinfo: + timestamp = timestamp.replace(tzinfo=dateutil.tz.tzlocal()) + timestamp = (timestamp - timestamp.utcoffset()).replace(tzinfo=datetime.timezone.utc) + timestamp = timestamp.timestamp() + else: + timestamp = float(timestamp) + + if timestamp < (time.time() - METRIC_TS_LOWER_BOUND_TO_NOW) or timestamp > ( + time.time() + METRIC_TS_UPPER_BOUND_FROM_NOW + ): + raise ValueError( + "Supplied timestamp %f is invalid." + " Timestamps must be between two weeks before and two hours from now." 
% timestamp + ) + value = float(value) + + self.MetricName = metric_name + self.Value = float(value) + self.Timestamp = timestamp + if step is not None: + if not isinstance(step, int): + raise ValueError("step must be int.") + self.Step = step + + def to_record(self): + """Convert the `_RawMetricData` object to dict""" + return self.__dict__ + + def to_raw_metric_data(self): + """Converts the metric data to a BatchPutMetrics RawMetricData item""" + # Convert timestamp from float to timestamp str. + # Otherwise will get ParamValidationError + raw_metric_data = { + "MetricName": self.MetricName, + "Value": self.Value, + "Timestamp": str(int(self.Timestamp)), + } + if self.Step is not None: + raw_metric_data["Step"] = int(self.Step) + return raw_metric_data + + def __str__(self): + """String representation of the `_RawMetricData` object.""" + return repr(self) + + def __repr__(self): + """Return a string representation of this _RawMetricData` object.""" + return "{}({})".format( + type(self).__name__, + ",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]), + ) + + +class _MetricsManager(object): + """Collects metrics and sends them directly to SageMaker Metrics data plane APIs.""" + + def __init__(self, trial_component_name: str, sagemaker_session: Session, sink=None) -> None: + """Initialize a `_MetricsManager` instance + + Args: + trial_component_name (str): The Name of the Trial Component to log metrics to + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + sink (object): The metrics sink to use. + """ + if sink is None: + self.sink = _SyncMetricsSink( + trial_component_name, sagemaker_session.sagemaker_metrics_client + ) + else: + self.sink = sink + + def log_metric(self, metric_name, value, timestamp=None, step=None): + """Sends a metric to metrics service.""" + + metric_data = _RawMetricData(metric_name, value, timestamp, step) + self.sink.log_metric(metric_data) + + def __enter__(self): + """Return self""" + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Execute self.close()""" + self.sink.close() + + def close(self): + """Close the metrics object.""" + self.sink.close() + + +class _SyncMetricsSink(object): + """Collects metrics and sends them directly to metrics service.""" + + def __init__(self, trial_component_name, metrics_client) -> None: + """Initialize a `_SyncMetricsSink` instance + + Args: + trial_component_name (str): The Name of the Trial Component to log metrics. 
+ metrics_client (boto3.client): boto client for metrics service + """ + self._trial_component_name = trial_component_name + self._metrics_client = metrics_client + self._buffer = [] + + def log_metric(self, metric_data): + """Sends a metric to metrics service.""" + + # this is a simplistic solution which calls BatchPutMetrics + # on the same thread as the client code + self._buffer.append(metric_data) + self._drain() + + def _drain(self, close=False): + """Pops off all metrics in the buffer and starts sending them to metrics service.""" + + if not self._buffer: + return + + if len(self._buffer) < BATCH_SIZE and not close: + return + + # pop all the available metrics + available_metrics, self._buffer = self._buffer, [] + + self._send_metrics(available_metrics) + + def _send_metrics(self, metrics): + """Calls BatchPutMetrics directly on the metrics service.""" + while metrics: + batch, metrics = ( + metrics[:BATCH_SIZE], + metrics[BATCH_SIZE:], + ) + request = self._construct_batch_put_metrics_request(batch) + response = self._metrics_client.batch_put_metrics(**request) + errors = response["Errors"] if "Errors" in response else None + if errors: + message = errors[0]["Message"] + raise Exception(f'{len(errors)} errors with message "{message}"') + + def _construct_batch_put_metrics_request(self, batch): + """Creates dictionary object used as request to metrics service.""" + return { + "TrialComponentName": self._trial_component_name.lower(), + "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)), + } + + def close(self): + """Drains any remaining metrics.""" + self._drain(close=True) + + +class _MetricQueue(object): + """A thread safe queue for sending metrics to SageMaker. + + Args: + trial_component_name (str): the ARN of the resource + metric_name (str): the name of the metric + metrics_client (boto_client): the boto client for SageMaker Metrics service + """ + + _CONSUMER_SLEEP_SECONDS = 5 + + def __init__(self, trial_component_name, metric_name, metrics_client): + # infinite queue size + self._queue = queue.Queue() + self._buffer = [] + self._thread = threading.Thread(target=self._run) + self._started = False + self._finished = False + self._trial_component_name = trial_component_name + self._metrics_client = metrics_client + self._metric_name = metric_name + self._logged_metrics = 0 + + def log_metric(self, metric_data): + """Adds a metric data point to the queue""" + self._buffer.append(metric_data) + + if len(self._buffer) < BATCH_SIZE: + return + + self._enqueue_all() + + if not self._started: + self._thread.start() + self._started = True + + def _run(self): + """Starts the metric thread which sends metrics to SageMaker in batches""" + + while not self._queue.empty() or not self._finished: + if self._queue.empty(): + time.sleep(self._CONSUMER_SLEEP_SECONDS) + else: + batch = self._queue.get() + self._send_metrics(batch) + + def _send_metrics(self, metrics_batch): + """Calls BatchPutMetrics directly on the metrics service.""" + request = self._construct_batch_put_metrics_request(metrics_batch) + self._logged_metrics += len(metrics_batch) + self._metrics_client.batch_put_metrics(**request) + + def _construct_batch_put_metrics_request(self, batch): + """Creates dictionary object used as request to metrics service.""" + + return { + "TrialComponentName": self._trial_component_name, + "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)), + } + + def _enqueue_all(self): + """Enqueue all buffered metrics to be sent to SageMaker""" + + available_metrics, self._buffer 
= self._buffer, [] + if available_metrics: + self._queue.put(available_metrics) + + def close(self): + """Flushes any buffered metrics""" + + self._enqueue_all() + self._finished = True + + def is_active(self): + """Is the thread active (still draining metrics to SageMaker)""" + + return self._thread.is_alive() + + +class _AsyncMetricsSink(object): + """Collects metrics and sends them directly to metrics service.""" + + _COMPLETE_SLEEP_SECONDS = 1.0 + + def __init__(self, trial_component_name, metrics_client) -> None: + """Initialize a `_AsyncMetricsSink` instance + + Args: + trial_component_name (str): The Name of the Trial Component to log metrics to. + metrics_client (boto3.client): boto client for metrics service + """ + self._trial_component_name = trial_component_name + self._metrics_client = metrics_client + self._buffer = [] + self._is_draining = False + self._metric_queues = {} + + def log_metric(self, metric_data): + """Sends a metric to metrics service.""" + + if metric_data.MetricName in self._metric_queues: + self._metric_queues[metric_data.MetricName].log_metric(metric_data) + else: + cur_metric_queue = _MetricQueue( + self._trial_component_name, metric_data.MetricName, self._metrics_client + ) + self._metric_queues[metric_data.MetricName] = cur_metric_queue + cur_metric_queue.log_metric(metric_data) + + def close(self): + """Closes the metric file.""" + logging.debug("Closing") + for q in self._metric_queues.values(): + q.close() + + # TODO should probably use join + while any(map(lambda x: x.is_active(), self._metric_queues.values())): + time.sleep(self._COMPLETE_SLEEP_SECONDS) + logging.debug("Closed") diff --git a/src/sagemaker/experiments/_run_context.py b/src/sagemaker/experiments/_run_context.py new file mode 100644 index 0000000000..9a7dada5f4 --- /dev/null +++ b/src/sagemaker/experiments/_run_context.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment _RunContext class.""" +from __future__ import absolute_import + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sagemaker.experiments import Run + + +class _RunContext: + """A static context variable to keep track of the current Run object""" + + _context_run = None + + @classmethod + def add_run_object(cls, run: "Run"): + """Keep track of the current executing Run object + + by adding it to a class static variable. + + Args: + run (Run): The current Run object to be tracked. + """ + cls._context_run = run + + @classmethod + def drop_current_run(cls) -> "Run": + """Drop the Run object tracked in the global static variable + + as its execution finishes (its "with" block ends). + + Return: + Run: the dropped Run object. + """ + current_run = cls._context_run + cls._context_run = None + return current_run + + @classmethod + def get_current_run(cls) -> "Run": + """Return the current Run object without dropping it. + + Return: + Run: the current Run object to be returned. 
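+
+        For illustration, a minimal sketch of the expected behavior (the
+        experiment and run names are made up and assume a configured
+        SageMaker session):
+
+        .. code:: python
+
+            from sagemaker.experiments.run import Run
+            from sagemaker.experiments._run_context import _RunContext
+
+            assert _RunContext.get_current_run() is None
+            with Run(experiment_name="my-exp", run_name="my-run") as run:
+                # the run entered via "with" is tracked in the context
+                assert _RunContext.get_current_run() is run
+            # dropped again once the "with" block exits
+            assert _RunContext.get_current_run() is None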
+ """ + return cls._context_run diff --git a/src/sagemaker/experiments/_utils.py b/src/sagemaker/experiments/_utils.py new file mode 100644 index 0000000000..5ef5d99dad --- /dev/null +++ b/src/sagemaker/experiments/_utils.py @@ -0,0 +1,218 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment utility methods.""" +from __future__ import absolute_import + +import logging +import os + +import mimetypes +import urllib +from functools import wraps +from typing import Optional + +from sagemaker import Session +from sagemaker.apiutils import _utils +from sagemaker.experiments._environment import _RunEnvironment, _EnvironmentType +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression +from sagemaker.utils import retry_with_backoff + + +def resolve_artifact_name(file_path): + """Resolve artifact name from given file path. + + If not specified, will auto create one. + + Args: + file_path (str): Path to the file. + + Returns: + str: The resolved artifact name. + """ + _, filename = os.path.split(file_path) + if filename: + return filename + + return _utils.name("artifact") + + +def guess_media_type(file_path): + """Infer the media type of a file based on its file name. + + Args: + file_path (str): Path to the file. + + Returns: + str: The guessed media type. + """ + file_url = urllib.parse.urljoin("file:", urllib.request.pathname2url(file_path)) + guessed_media_type, _ = mimetypes.guess_type(file_url, strict=False) + return guessed_media_type + + +def verify_length_of_true_and_predicted(true_labels, predicted_attrs, predicted_attrs_name): + """Verify if lengths match between lists of true labels and predicted attributes. + + Args: + true_labels (list or array): The list of the true labels. + predicted_attrs (list or array): The list of the predicted labels/probabilities/scores. + predicted_attrs_name (str): The name of the predicted attributes. + + Raises: + ValueError: If lengths mismatch between true labels and predicted attributes. 
+ """ + if len(true_labels) != len(predicted_attrs): + raise ValueError( + "Lengths mismatch between true labels and {}: " + "({} vs {}).".format(predicted_attrs_name, len(true_labels), len(predicted_attrs)) + ) + + +def validate_invoked_inside_run_context(func): + """A Decorator to force the decorated method called under Run context.""" + + @wraps(func) + def wrapper(*args, **kwargs): + self_instance = args[0] + if not self_instance._inside_load_context and not self_instance._inside_init_context: + raise RuntimeError("This method should be called inside context of 'with' statement.") + return func(*args, **kwargs) + + return wrapper + + +def is_already_exist_error(error): + """Check if the error indicates resource already exists + + Args: + error (dict): The "Error" field in the response of the + `botocore.exceptions.ClientError` + """ + return error["Code"] == "ValidationException" and "already exists" in error["Message"] + + +def get_tc_and_exp_config_from_job_env( + environment: _RunEnvironment, + sagemaker_session: Session, +) -> dict: + """Retrieve an experiment config from the job environment. + + Args: + environment (_RunEnvironment): The run environment object with job specific data. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + """ + job_name = environment.source_arn.split("/")[-1] + if environment.environment_type == _EnvironmentType.SageMakerTrainingJob: + job_response = retry_with_backoff( + callable_func=lambda: sagemaker_session.describe_training_job(job_name), + num_attempts=4, + ) + elif environment.environment_type == _EnvironmentType.SageMakerProcessingJob: + job_response = retry_with_backoff( + callable_func=lambda: sagemaker_session.describe_processing_job(job_name), + num_attempts=4, + ) + else: # environment.environment_type == _EnvironmentType.SageMakerTransformJob + raise RuntimeError( + "Failed to load the Run as loading experiment config " + "from transform job environment is not currently supported. " + "As a workaround, please explicitly pass in " + "the experiment_name and run_name in load_run." + ) + + job_exp_config = job_response.get("ExperimentConfig", dict()) + from sagemaker.experiments.run import RUN_NAME + + if job_exp_config.get(RUN_NAME, None): + return job_exp_config + raise RuntimeError( + "Not able to fetch RunName in ExperimentConfig of the sagemaker job. " + "Please make sure the ExperimentConfig is correctly set." + ) + + +def verify_load_input_names( + run_name: Optional[str] = None, + experiment_name: Optional[str] = None, +): + """Verify the run_name and the experiment_name inputs in load_run. + + Args: + run_name (str): The run_name supplied by the user (default: None). + experiment_name (str): The experiment_name supplied by the user + (default: None). + + Raises: + ValueError: If run_name is supplied while experiment_name is not. + """ + if not run_name and experiment_name: + logging.warning( + "No run_name is supplied. Ignoring the provided experiment_name " + "since it only takes effect along with run_name. " + "Will load the Run object from the job environment or current Run context." + ) + if run_name and not experiment_name: + raise ValueError( + "Invalid input: experiment_name is missing when run_name is supplied. " + "Please supply a valid experiment_name when the run_name is not None." 
+ ) + + +def is_run_trial_component(trial_component_name: str, sagemaker_session: Session) -> bool: + """Check if a trial component is generated by `sagemaker.experiments.Run` + + Args: + trial_component_name (str): The name of the trial component. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + bool: Indicate whether the trial component is created by + `sagemaker.experiments.Run` or not. + """ + search_filter = Filter( + name="TrialComponentName", + operator=Operator.EQUALS, + value=trial_component_name, + ) + search_expression = SearchExpression(filters=[search_filter]) + + def search(): + return list( + _TrialComponent.search( + search_expression=search_expression, + max_results=1, # TrialComponentName is unique in an account + sagemaker_session=sagemaker_session, + ) + )[0] + + try: + tc_search_res = retry_with_backoff(search, 4) + from sagemaker.experiments.run import RUN_TC_TAG + + if not tc_search_res.tags or RUN_TC_TAG not in tc_search_res.tags: + return False + return True + except Exception as ex: # pylint: disable=broad-except + logging.warning( + "Failed to inspect the type of the trial component (%s), due to (%s)", + trial_component_name, + str(ex), + ) + return False diff --git a/src/sagemaker/experiments/experiment.py b/src/sagemaker/experiments/experiment.py new file mode 100644 index 0000000000..8f59ff36b3 --- /dev/null +++ b/src/sagemaker/experiments/experiment.py @@ -0,0 +1,237 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment class.""" +from __future__ import absolute_import + +import time + +from sagemaker.apiutils import _base_types +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +class _Experiment(_base_types.Record): + """An Amazon SageMaker experiment, which is a collection of related trials. + + New experiments are created by calling `experiments.experiment._Experiment.create`. + Existing experiments can be reloaded by calling `experiments.experiment._Experiment.load`. + + Attributes: + experiment_name (str): The name of the experiment. The name must be unique + within an account. + display_name (str): Name of the experiment that will appear in UI, + such as SageMaker Studio. + description (str): A description of the experiment. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment. 
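+
+    For example, a sketch of loading (or creating) an experiment and listing
+    its trials, assuming a valid ``sagemaker_session`` is available:
+
+    .. code:: python
+
+        from sagemaker.experiments.experiment import _Experiment
+
+        experiment = _Experiment._load_or_create(
+            experiment_name="my-experiment",
+            description="An illustrative experiment",
+            sagemaker_session=sagemaker_session,
+        )
+        for trial_summary in experiment.list_trials():
+            print(trial_summary.trial_name)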
+ """ + + experiment_name = None + display_name = None + description = None + tags = None + + _boto_create_method = "create_experiment" + _boto_load_method = "describe_experiment" + _boto_update_method = "update_experiment" + _boto_delete_method = "delete_experiment" + + _boto_update_members = ["experiment_name", "description", "display_name"] + _boto_delete_members = ["experiment_name"] + + _MAX_DELETE_ALL_ATTEMPTS = 3 + + def save(self): + """Save the state of this Experiment to SageMaker. + + Returns: + dict: Update experiment API response. + """ + return self._invoke_api(self._boto_update_method, self._boto_update_members) + + def delete(self): + """Delete this Experiment from SageMaker. + + Deleting an Experiment does not delete associated Trials and their Trial Components. + It requires that each Trial in the Experiment is first deleted. + + Returns: + dict: Delete experiment API response. + """ + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, experiment_name, sagemaker_session=None): + """Load an existing experiment and return an `_Experiment` object representing it. + + Args: + experiment_name: (str): Name of the experiment + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + return cls._construct( + cls._boto_load_method, + experiment_name=experiment_name, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def create( + cls, + experiment_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=None, + ): + """Create a new experiment in SageMaker and return an `_Experiment` object. + + Args: + experiment_name: (str): Name of the experiment. Must be unique. Required. + display_name: (str): Name of the experiment that will appear in UI, + such as SageMaker Studio (default: None). + description: (str): Description of the experiment (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment + (default: None). + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + return cls._construct( + cls._boto_create_method, + experiment_name=experiment_name, + display_name=display_name, + description=description, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def _load_or_create( + cls, + experiment_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=None, + ): + """Load an experiment by name and create a new one if it does not exist. + + Args: + experiment_name: (str): Name of the experiment. Must be unique. Required. + display_name: (str): Name of the experiment that will appear in UI, + such as SageMaker Studio (default: None). This is used only when the + given `experiment_name` does not exist and a new experiment has to be created. + description: (str): Description of the experiment (default: None). + This is used only when the given `experiment_name` does not exist and + a new experiment has to be created. 
+ sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment + (default: None). This is used only when the given `experiment_name` does not + exist and a new experiment has to be created. + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + sagemaker_client = sagemaker_session.sagemaker_client + try: + experiment = _Experiment.load(experiment_name, sagemaker_session) + except sagemaker_client.exceptions.ResourceNotFound: + experiment = _Experiment.create( + experiment_name=experiment_name, + display_name=display_name, + description=description, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return experiment + + def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None): + """List trials in this experiment matching the specified criteria. + + Args: + created_before (datetime.datetime): Return trials created before this instant + (default: None). + created_after (datetime.datetime): Return trials created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). + + Returns: + collections.Iterator[experiments._api_types.TrialSummary] : + An iterator over trials matching the criteria. + """ + return _Trial.list( + experiment_name=self.experiment_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by, + sort_order=sort_order, + sagemaker_session=self.sagemaker_session, + ) + + def _delete_all(self, action): + """Force to delete the experiment and associated trials, trial components. + + Args: + action (str): The string '--force' is required to pass in to confirm recursively + delete the experiments, and all its trials and trial components. + """ + if action != "--force": + raise ValueError( + "Must confirm with string '--force' in order to delete the experiment and " + "associated trials, trial components." + ) + + delete_attempt_count = 0 + last_exception = None + while True: + if delete_attempt_count == self._MAX_DELETE_ALL_ATTEMPTS: + raise Exception("Failed to delete, please try again.") from last_exception + try: + for trial_summary in self.list_trials(): + trial = _Trial.load( + sagemaker_session=self.sagemaker_session, + trial_name=trial_summary.trial_name, + ) + for ( + trial_component_summary + ) in trial.list_trial_components(): # pylint: disable=no-member + tc = _TrialComponent.load( + sagemaker_session=self.sagemaker_session, + trial_component_name=trial_component_summary.trial_component_name, + ) + tc.delete(force_disassociate=True) + # to prevent throttling + time.sleep(1.2) + trial.delete() # pylint: disable=no-member + # to prevent throttling + time.sleep(1.2) + self.delete() + break + except Exception as ex: # pylint: disable=broad-except + last_exception = ex + finally: + delete_attempt_count = delete_attempt_count + 1 diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py new file mode 100644 index 0000000000..1492b6bafa --- /dev/null +++ b/src/sagemaker/experiments/run.py @@ -0,0 +1,882 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment Run class.""" +from __future__ import absolute_import + +import datetime +import logging +from enum import Enum +from math import isnan, isinf +from numbers import Number +from typing import Optional, List, Dict, TYPE_CHECKING, Union + +import dateutil +from numpy import array + +from sagemaker.apiutils import _utils +from sagemaker.experiments import _api_types +from sagemaker.experiments._api_types import TrialComponentArtifact, _TrialComponentStatusType +from sagemaker.experiments._helper import ( + _ArtifactUploader, + _LineageArtifactTracker, +) +from sagemaker.experiments._environment import _RunEnvironment +from sagemaker.experiments._run_context import _RunContext +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments._metrics import _MetricsManager +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + +from sagemaker.utils import ( + get_module, + unique_name_from_base, +) + +from sagemaker.experiments._utils import ( + guess_media_type, + resolve_artifact_name, + verify_length_of_true_and_predicted, + validate_invoked_inside_run_context, + get_tc_and_exp_config_from_job_env, + verify_load_input_names, + is_run_trial_component, +) + +if TYPE_CHECKING: + from sagemaker import Session + +logger = logging.getLogger(__name__) + +RUN_NAME_BASE = "Sagemaker-Run".lower() +TRIAL_NAME_TEMPLATE = "Default-Run-Group-{}" +MAX_RUN_TC_ARTIFACTS_LEN = 30 +MAX_NAME_LEN_IN_BACKEND = 120 +EXPERIMENT_NAME = "ExperimentName" +TRIAL_NAME = "TrialName" +RUN_NAME = "RunName" +DELIMITER = "-" +RUN_TC_TAG_KEY = "sagemaker:trial-component-source" +RUN_TC_TAG_VALUE = "run" +RUN_TC_TAG = {"Key": RUN_TC_TAG_KEY, "Value": RUN_TC_TAG_VALUE} + + +class SortByType(Enum): + """The type of property by which to sort the `list_runs` results.""" + + CREATION_TIME = "CreationTime" + NAME = "Name" + + +class SortOrderType(Enum): + """The type of order to sort the list or search results.""" + + ASCENDING = "Ascending" + DESCENDING = "Descending" + + +class Run(object): + """A collection of parameters, metrics, and artifacts to create a ML model.""" + + def __init__( + self, + experiment_name: str, + run_name: Optional[str] = None, + experiment_display_name: Optional[str] = None, + run_display_name: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + sagemaker_session: Optional["Session"] = None, + ): + """Construct a `Run` instance. + + SageMaker Experiments automatically tracks the inputs, parameters, configurations, + and results of your iterations as runs. + You can assign, group, and organize these runs into experiments. + You can also create, compare, and evaluate runs. + + The code sample below shows how to initialize a run, log parameters to the Run object + and invoke a training job under the context of this Run object, which automatically + passes the run's ``experiment_config`` (including the experiment name, run name etc.) + to the training job. + + Note: + All log methods (e.g. ``log_parameter``, ``log_metric``, etc.) 
have to be called within + the run context (i.e. the ``with`` statement). Otherwise, a ``RuntimeError`` is thrown. + + .. code:: python + + with Run(experiment_name="my-exp", run_name="my-run", ...) as run: + run.log_parameter(...) + ... + estimator.fit(job_name="my-job") # Create a training job + + In order to reuse an existing run to log extra data, ``load_run`` is recommended. + The code snippet below displays how to load the run initialized above + in a custom training job script, where no ``run_name`` or ``experiment_name`` + is presented as they are automatically retrieved from the experiment config + in the job environment. + + Note: + Instead of the ``Run`` constructor, the ``load_run`` is recommended to use + in a job script to load the existing run created before the job launch. + Otherwise, a new run may be created each time you launch a job. + + .. code:: python + + with load_run() as run: + run.log_metric(...) + ... + + Args: + experiment_name (str): The name of the experiment. The name must be unique + within an account. + run_name (str): The name of the run. If it is not specified, one is auto generated. + experiment_display_name (str): Name of the experiment that will appear in UI, + such as SageMaker Studio. (default: None). This display name is used in + a create experiment call. If an experiment with the specified name already exists, + this display name won't take effect. + run_display_name (str): The display name of the run used in UI (default: None). + This display name is used in a create run call. If a run with the + specified name already exists, this display name won't take effect. + tags (List[Dict[str, str]]): A list of tags to be used for all create calls, + e.g. to create an experiment, a run group, etc. (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + """ + # TODO: we should revert the lower casting once backend fix reaches prod + self.experiment_name = experiment_name.lower() + sagemaker_session = sagemaker_session or _utils.default_session() + self.run_name = run_name or unique_name_from_base(RUN_NAME_BASE) + + # avoid confusion due to mis-match in casing between run name and TC name + self.run_name = self.run_name.lower() + + trial_component_name = Run._generate_trial_component_name( + run_name=self.run_name, experiment_name=self.experiment_name + ) + self.run_group_name = Run._generate_trial_name(self.experiment_name) + + self._experiment = _Experiment._load_or_create( + experiment_name=self.experiment_name, + display_name=experiment_display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + self._trial = _Trial._load_or_create( + experiment_name=self.experiment_name, + trial_name=self.run_group_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + self._trial_component, is_existed = _TrialComponent._load_or_create( + trial_component_name=trial_component_name, + display_name=run_display_name, + tags=Run._append_run_tc_label_to_tags(tags), + sagemaker_session=sagemaker_session, + ) + if is_existed: + logger.info( + "The run (%s) under experiment (%s) already exists. Loading it. 
" + "Note: sagemaker.experiments.load_run is recommended to use when " + "the desired run already exists.", + self.run_name, + self.experiment_name, + ) + self._trial.add_trial_component(self._trial_component) + + self._artifact_uploader = _ArtifactUploader( + trial_component_name=self._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + self._lineage_artifact_tracker = _LineageArtifactTracker( + trial_component_arn=self._trial_component.trial_component_arn, + sagemaker_session=sagemaker_session, + ) + self._metrics_manager = _MetricsManager( + trial_component_name=self._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + self._inside_init_context = False + self._inside_load_context = False + self._in_load = False + + @property + def experiment_config(self) -> dict: + """Get experiment config from run attributes.""" + return { + EXPERIMENT_NAME: self.experiment_name, + TRIAL_NAME: self.run_group_name, + RUN_NAME: self._trial_component.trial_component_name, + } + + @validate_invoked_inside_run_context + def log_parameter(self, name: str, value: Union[str, int, float]): + """Record a single parameter value for this run. + + Overwrites any previous value recorded for the specified parameter name. + + Args: + name (str): The name of the parameter. + value (str or int or float): The value of the parameter. + """ + if self._is_input_valid("parameter", name, value): + self._trial_component.parameters[name] = value + + @validate_invoked_inside_run_context + def log_parameters(self, parameters: Dict[str, Union[str, int, float]]): + """Record a collection of parameter values for this run. + + Args: + parameters (dict[str, str or int or float]): The parameters to record. + """ + filtered_parameters = { + key: value + for (key, value) in parameters.items() + if self._is_input_valid("parameter", key, value) + } + self._trial_component.parameters.update(filtered_parameters) + + @validate_invoked_inside_run_context + def log_metric( + self, + name: str, + value: float, + timestamp: Optional[datetime.datetime] = None, + step: Optional[int] = None, + ): + """Record a custom scalar metric value for this run. + + Note: + This method is for manual custom metrics, for automatic metrics see the + ``enable_sagemaker_metrics`` parameter on the ``estimator`` class. + + Args: + name (str): The name of the metric. + value (float): The value of the metric. + timestamp (datetime.datetime): The timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): The integer iteration number of the metric value (default: None). + """ + if self._is_input_valid("metric", name, value): + self._metrics_manager.log_metric( + metric_name=name, value=value, timestamp=timestamp, step=step + ) + + @validate_invoked_inside_run_context + def log_precision_recall( + self, + y_true: Union[list, array], + predicted_probabilities: Union[list, array], + positive_label: Optional[Union[str, int]] = None, + title: Optional[str] = None, + is_output: bool = True, + no_skill: Optional[int] = None, + ): + """Create and log a precision recall graph artifact for Studio UI to render. + + The artifact is stored in S3 and represented as a lineage artifact + with an association with the run. + + You can view the artifact in the UI. + If your job is created by a pipeline execution you can view the artifact + by selecting the corresponding step in the pipelines UI. + See also `SageMaker Pipelines `_ + + This method requires sklearn library. 
+ + Args: + y_true (list or array): True labels. If labels are not binary + then positive_label should be given. + predicted_probabilities (list or array): Estimated/predicted probabilities. + positive_label (str or int): Label of the positive class (default: None). + title (str): Title of the graph (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + no_skill (int): The precision threshold under which the classifier cannot discriminate + between the classes and would predict a random class or a constant class in + all cases (default: None). + """ + + verify_length_of_true_and_predicted( + true_labels=y_true, + predicted_attrs=predicted_probabilities, + predicted_attrs_name="predicted probabilities", + ) + + get_module("sklearn") + from sklearn.metrics import precision_recall_curve, average_precision_score + + kwargs = {} + if positive_label is not None: + kwargs["pos_label"] = positive_label + + precision, recall, _ = precision_recall_curve(y_true, predicted_probabilities, **kwargs) + + kwargs["average"] = "micro" + ap = average_precision_score(y_true, predicted_probabilities, **kwargs) + + data = { + "type": "PrecisionRecallCurve", + "version": 0, + "title": title, + "precision": precision.tolist(), + "recall": recall.tolist(), + "averagePrecisionScore": ap, + "noSkill": no_skill, + } + self._log_graph_artifact( + artifact_name=title, data=data, graph_type="PrecisionRecallCurve", is_output=is_output + ) + + @validate_invoked_inside_run_context + def log_roc_curve( + self, + y_true: Union[list, array], + y_score: Union[list, array], + title: Optional[str] = None, + is_output: bool = True, + ): + """Create and log a receiver operating characteristic (ROC curve) artifact. + + The artifact is stored in S3 and represented as a lineage artifact + with an association with the run. + + You can view the artifact in the UI. + If your job is created by a pipeline execution you can view the artifact + by selecting the corresponding step in the pipelines UI. + See also `SageMaker Pipelines `_ + + This method requires sklearn library. + + Args: + y_true (list or array): True labels. If labels are not binary + then positive_label should be given. + y_score (list or array): Estimated/predicted probabilities. + title (str): Title of the graph (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + """ + verify_length_of_true_and_predicted( + true_labels=y_true, predicted_attrs=y_score, predicted_attrs_name="predicted scores" + ) + + get_module("sklearn") + from sklearn.metrics import roc_curve, auc + + fpr, tpr, _ = roc_curve(y_true, y_score) + + auc = auc(fpr, tpr) + + data = { + "type": "ROCCurve", + "version": 0, + "title": title, + "falsePositiveRate": fpr.tolist(), + "truePositiveRate": tpr.tolist(), + "areaUnderCurve": auc, + } + self._log_graph_artifact( + artifact_name=title, data=data, graph_type="ROCCurve", is_output=is_output + ) + + @validate_invoked_inside_run_context + def log_confusion_matrix( + self, + y_true: Union[list, array], + y_pred: Union[list, array], + title: Optional[str] = None, + is_output: bool = True, + ): + """Create and log a confusion matrix artifact. + + The artifact is stored in S3 and represented as a lineage artifact + with an association with the run. + + You can view the artifact in the UI. 
+ If your job is created by a pipeline execution you can view the + artifact by selecting the corresponding step in the pipelines UI. + See also `SageMaker Pipelines `_ + This method requires sklearn library. + + Args: + y_true (list or array): True labels. If labels are not binary + then positive_label should be given. + y_pred (list or array): Predicted labels. + title (str): Title of the graph (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + """ + verify_length_of_true_and_predicted( + true_labels=y_true, predicted_attrs=y_pred, predicted_attrs_name="predicted labels" + ) + + get_module("sklearn") + from sklearn.metrics import confusion_matrix + + matrix = confusion_matrix(y_true, y_pred) + + data = { + "type": "ConfusionMatrix", + "version": 0, + "title": title, + "confusionMatrix": matrix.tolist(), + } + self._log_graph_artifact( + artifact_name=title, data=data, graph_type="ConfusionMatrix", is_output=is_output + ) + + @validate_invoked_inside_run_context + def log_artifact( + self, name: str, value: str, media_type: Optional[str] = None, is_output: bool = True + ): + """Record a single artifact for this run. + + Overwrites any previous value recorded for the specified name. + + Args: + name (str): The name of the artifact. + value (str): The value. + media_type (str): The MediaType (MIME type) of the value (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + """ + self._verify_trial_component_artifacts_length(is_output=is_output) + if is_output: + self._trial_component.output_artifacts[name] = TrialComponentArtifact( + value, media_type=media_type + ) + else: + self._trial_component.input_artifacts[name] = TrialComponentArtifact( + value, media_type=media_type + ) + + @validate_invoked_inside_run_context + def log_file( + self, + file_path: str, + name: Optional[str] = None, + media_type: Optional[str] = None, + is_output: bool = True, + ): + """Upload a file to s3 and store it as an input/output artifact in this run. + + Args: + file_path (str): The path of the local file to upload. + name (str): The name of the artifact (default: None). + media_type (str): The MediaType (MIME type) of the file. + If not specified, this library will attempt to infer the media type + from the file extension of ``file_path``. + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. 
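+
+        For example (the file path below is illustrative):
+
+        .. code:: python
+
+            with load_run(experiment_name="my-exp", run_name="my-run") as run:
+                run.log_file("training_data.csv", name="training_data", is_output=False)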
+ """ + self._verify_trial_component_artifacts_length(is_output) + media_type = media_type or guess_media_type(file_path) + name = name or resolve_artifact_name(file_path) + s3_uri, _ = self._artifact_uploader.upload_artifact(file_path) + if is_output: + self._trial_component.output_artifacts[name] = TrialComponentArtifact( + value=s3_uri, media_type=media_type + ) + else: + self._trial_component.input_artifacts[name] = TrialComponentArtifact( + value=s3_uri, media_type=media_type + ) + + def close(self): + """Persist any data saved locally.""" + try: + # Update the trial component with additions from the Run object + self._trial_component.save() + # Create Lineage entities for the artifacts + self._lineage_artifact_tracker.save() + finally: + if self._metrics_manager: + self._metrics_manager.close() + + @staticmethod + def _generate_trial_name(base_name) -> str: + """Generate the reserved trial name based on experiment name + + Args: + base_name (str): The ``experiment_name`` of this ``Run`` object. + """ + available_length = MAX_NAME_LEN_IN_BACKEND - len(TRIAL_NAME_TEMPLATE) + return TRIAL_NAME_TEMPLATE.format(base_name[:available_length]) + + @staticmethod + def _is_input_valid(input_type, field_name, field_value) -> bool: + """Check if the input is valid or not + + Args: + input_type (str): The type of the input, one of ``parameter``, ``metric``. + field_name (str): The name of the field to be checked. + field_value (str or int or float): The value of the field to be checked. + """ + if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)): + logger.warning( + "Failed to log %s %s. Received invalid value: %s.", + input_type, + field_name, + field_value, + ) + return False + return True + + def _log_graph_artifact(self, data, graph_type, is_output, artifact_name=None): + """Log an artifact. + + Logs an artifact by uploading data to S3, creating an artifact, and associating that + artifact with the run trial component. + + Args: + data (dict): Artifacts data that will be saved to S3. + graph_type (str): The type of the artifact. + is_output (bool): Determines direction of association to the + trial component. Defaults to True (output artifact). + If set to False then represented as input association. + artifact_name (str): Name of the artifact (default: None). + """ + # generate an artifact name + if not artifact_name: + unique_name_from_base(graph_type) + + # create a json file in S3 + s3_uri, etag = self._artifact_uploader.upload_object_artifact( + artifact_name, data, file_extension="json" + ) + + # create an artifact and association for the table + if is_output: + self._lineage_artifact_tracker.add_output_artifact( + name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type + ) + else: + self._lineage_artifact_tracker.add_input_artifact( + name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type + ) + + def _verify_trial_component_artifacts_length(self, is_output): + """Verify the length of trial component artifacts + + Args: + is_output (bool): Determines direction of association to the + trial component. + + Raises: + ValueError: If the length of trial component artifacts exceeds the limit. 
+ """ + err_msg_template = "Cannot add more than {} {}_artifacts under run" + if is_output: + if len(self._trial_component.output_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN: + raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "output")) + else: + if len(self._trial_component.input_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN: + raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "input")) + + @staticmethod + def _generate_trial_component_name(run_name: str, experiment_name: str) -> str: + """Generate the TrialComponentName based on run_name and experiment_name + + Args: + run_name (str): The run_name supplied by the user. + experiment_name (str): The experiment_name supplied by the user, + which is prepended to the run_name to generate the TrialComponentName. + + Returns: + str: The TrialComponentName used to create a trial component + which is unique in an account. + + Raises: + ValueError: If either the run_name or the experiment_name exceeds + the length limit. + """ + buffer = 1 # leave length buffers for delimiters + max_len = int(MAX_NAME_LEN_IN_BACKEND / 2) - buffer + err_msg_template = "The {} (length: {}) must have length less than or equal to {}" + if len(run_name) > max_len: + raise ValueError(err_msg_template.format("run_name", len(run_name), max_len)) + if len(experiment_name) > max_len: + raise ValueError( + err_msg_template.format("experiment_name", len(experiment_name), max_len) + ) + trial_component_name = "{}{}{}".format(experiment_name, DELIMITER, run_name) + # due to mixed-case concerns on the backend + trial_component_name = trial_component_name.lower() + return trial_component_name + + @staticmethod + def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: str) -> str: + """Extract the user supplied run name from a trial component name. + + Args: + trial_component_name (str): The name of a run trial component. + experiment_name (str): The experiment_name supplied by the user, + which was prepended to the run_name to generate the trial_component_name. + + Returns: + str: The name of the Run object supplied by a user. + """ + return trial_component_name.replace("{}{}".format(experiment_name, DELIMITER), "", 1) + + @staticmethod + def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list: + """Append the run trial component label to tags used to create a trial component. + + Args: + tags (List[Dict[str, str]]): The tags supplied by users to initialize a Run object. + + Returns: + list: The updated tags with the appended run trial component label. + """ + if not tags: + tags = [] + tags.append(RUN_TC_TAG) + return tags + + def __enter__(self): + """Updates the start time of the run. + + Returns: + object: self. + """ + nested_with_err_msg_template = ( + "It is not allowed to use nested 'with' statements on the {}." 
+ ) + if self._in_load: + if self._inside_load_context: + raise RuntimeError(nested_with_err_msg_template.format("load_run")) + self._inside_load_context = True + else: + if _RunContext.get_current_run(): + raise RuntimeError(nested_with_err_msg_template.format("Run")) + self._inside_init_context = True + _RunContext.add_run_object(self) + + if not self._trial_component.start_time: + start_time = datetime.datetime.now(dateutil.tz.tzlocal()) + self._trial_component.start_time = start_time + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, + message="Within a run context", + ) + # Save the start_time and status changes to backend + self._trial_component.save() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Updates the end time of the run. + + Args: + exc_type (str): The exception type. + exc_value (str): The exception value. + exc_traceback (str): The stack trace of the exception. + """ + if self._in_load: + self._inside_load_context = False + self._in_load = False + else: + self._inside_init_context = False + _RunContext.drop_current_run() + + end_time = datetime.datetime.now(dateutil.tz.tzlocal()) + self._trial_component.end_time = end_time + if exc_value: + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.Failed.value, message=str(exc_value) + ) + else: + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.Completed.value + ) + + self.close() + + +def load_run( + run_name: Optional[str] = None, + experiment_name: Optional[str] = None, + sagemaker_session: Optional["Session"] = None, +) -> Run: + """Load an existing run. + + In order to reuse an existing run to log extra data, ``load_run`` is recommended. + It can be used in several ways: + + 1. Use ``load_run`` by explicitly passing in ``run_name`` and ``experiment_name``. + + If ``run_name`` and ``experiment_name`` are passed in, they are honored over + the default experiment config in the job environment or the run context + (i.e. within the ``with`` block). + + Note: + Both ``run_name`` and ``experiment_name`` should be supplied to make this usage work. + Otherwise, you may get a ``ValueError``. + + .. code:: python + + with load_run(experiment_name="my-exp", run_name="my-run") as run: + run.log_metric(...) + ... + + 2. Use the ``load_run`` in a job script without supplying ``run_name`` and ``experiment_name``. + + In this case, the default experiment config (specified when creating the job) is fetched + from the job environment to load the run. + + .. code:: python + + # In a job script + with load_run() as run: + run.log_metric(...) + ... + + 3. Use the ``load_run`` in a notebook within a run context (i.e. the ``with`` block) + but without supplying ``run_name`` and ``experiment_name``. + + Every time we call ``with Run(...) as run1:``, the initialized ``run1`` is tracked + in the run context. Then when we call ``load_run()`` under this with statement, the ``run1`` + in the context is loaded by default. + + .. code:: python + + # In a notebook + with Run(experiment_name="my-exp", run_name="my-run", ...) as run1: + run1.log_parameter(...) + + with load_run() as run2: # run2 is the same object as run1 + run2.log_metric(...) + ... + + Args: + run_name (str): The name of the run to be loaded (default: None). + If it is None, the ``RunName`` in the ``ExperimentConfig`` of the job will be + fetched to load the run. 
+ experiment_name (str): The name of the Experiment that the to be loaded run + is associated with (default: None). + Note: the experiment_name must be supplied along with a valid run_name. + Otherwise, it will be ignored. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + Run: The loaded Run object. + """ + sagemaker_session = sagemaker_session or _utils.default_session() + environment = _RunEnvironment.load() + + verify_load_input_names(run_name=run_name, experiment_name=experiment_name) + + if run_name or environment: + if run_name: + logger.warning( + "run_name is explicitly supplied in load_run, " + "which will be prioritized to load the Run object. " + "In other words, the run name in the experiment config, fetched from the " + "job environment or the current run context, will be ignored." + ) + else: + exp_config = get_tc_and_exp_config_from_job_env( + environment=environment, sagemaker_session=sagemaker_session + ) + run_name = Run._extract_run_name_from_tc_name( + trial_component_name=exp_config[RUN_NAME], + experiment_name=exp_config[EXPERIMENT_NAME], + ) + experiment_name = exp_config[EXPERIMENT_NAME] + + run_instance = Run( + experiment_name=experiment_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ) + elif _RunContext.get_current_run(): + run_instance = _RunContext.get_current_run() + else: + raise RuntimeError( + "Failed to load a Run object. " + "Please make sure a Run object has been initialized already." + ) + + run_instance._in_load = True + return run_instance + + +def list_runs( + experiment_name: str, + created_before: Optional[datetime.datetime] = None, + created_after: Optional[datetime.datetime] = None, + sagemaker_session: Optional["Session"] = None, + max_results: Optional[int] = None, + next_token: Optional[str] = None, + sort_by: SortByType = SortByType.CREATION_TIME, + sort_order: SortOrderType = SortOrderType.DESCENDING, +) -> list: + """Return a list of ``Run`` objects matching the given criteria. + + Args: + experiment_name (str): Only Run objects related to the specified experiment + are returned. + created_before (datetime.datetime): Return Run objects created before this instant + (default: None). + created_after (datetime.datetime): Return Run objects created after this instant + (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + max_results (int): Maximum number of Run objects to retrieve (default: None). + next_token (str): Token for next page of results (default: None). + sort_by (SortByType): The property to sort results by. One of NAME, CREATION_TIME + (default: CREATION_TIME). + sort_order (SortOrderType): One of ASCENDING, or DESCENDING (default: DESCENDING). + + Returns: + list: A list of ``Run`` objects. 
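+
+    For example (the experiment name is illustrative):
+
+    .. code:: python
+
+        from sagemaker.experiments.run import list_runs, SortByType, SortOrderType
+
+        runs = list_runs(
+            experiment_name="my-exp",
+            sort_by=SortByType.CREATION_TIME,
+            sort_order=SortOrderType.DESCENDING,
+        )
+        for run in runs:
+            print(run.run_name)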
+ """ + tc_summaries = _TrialComponent.list( + experiment_name=experiment_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by.value, + sort_order=sort_order.value, + sagemaker_session=sagemaker_session, + max_results=max_results, + next_token=next_token, + ) + run_list = [] + for tc_summary in tc_summaries: + if not is_run_trial_component( + trial_component_name=tc_summary.trial_component_name, + sagemaker_session=sagemaker_session, + ): + continue + run_instance = Run( + experiment_name=experiment_name, + run_name=Run._extract_run_name_from_tc_name( + trial_component_name=tc_summary.trial_component_name, + experiment_name=experiment_name, + ), + sagemaker_session=sagemaker_session, + ) + run_list.append(run_instance) + return run_list diff --git a/src/sagemaker/experiments/trial.py b/src/sagemaker/experiments/trial.py new file mode 100644 index 0000000000..146b24f18b --- /dev/null +++ b/src/sagemaker/experiments/trial.py @@ -0,0 +1,289 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the Trial class.""" +from __future__ import absolute_import + +from sagemaker.apiutils import _base_types +from sagemaker.experiments import _api_types +from sagemaker.experiments.trial_component import _TrialComponent + + +class _Trial(_base_types.Record): + """An execution of a data-science workflow with an experiment. + + Consists of a list of trial component objects, which document individual + activities within the workflow. + + Attributes: + trial_name (str): The name of the trial. + experiment_name (str): The name of the trial's experiment. + display_name (str): The name of the trial that will appear in UI, + such as SageMaker Studio. + tags (List[Dict[str, str]]): A list of tags to associate with the trial. + """ + + trial_name = None + experiment_name = None + display_name = None + tags = None + + _boto_create_method = "create_trial" + _boto_load_method = "describe_trial" + _boto_delete_method = "delete_trial" + _boto_update_method = "update_trial" + + _boto_update_members = ["trial_name", "display_name"] + _boto_delete_members = ["trial_name"] + + @classmethod + def _boto_ignore(cls): + """Response fields to ignore by default.""" + return super(_Trial, cls)._boto_ignore() + ["CreatedBy"] + + def save(self): + """Save the state of this Trial to SageMaker. + + Returns: + dict: Update trial response. + """ + return self._invoke_api(self._boto_update_method, self._boto_update_members) + + def delete(self): + """Delete this Trial from SageMaker. + + Does not delete associated Trial Components. + + Returns: + dict: Delete trial response. + """ + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, trial_name, sagemaker_session=None): + """Load an existing trial and return a `_Trial` object. + + Args: + trial_name: (str): Name of the Trial. 
+ sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + return super(_Trial, cls)._construct( + cls._boto_load_method, + trial_name=trial_name, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def create( + cls, experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None + ): + """Create a new trial and return a `_Trial` object. + + Args: + experiment_name: (str): Name of the experiment to create this trial in. + trial_name: (str): Name of the Trial. + display_name (str): Name of the trial that will appear in UI, + such as SageMaker Studio (default: None). + tags (List[dict]): A list of tags to associate with the trial (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + trial = super(_Trial, cls)._construct( + cls._boto_create_method, + trial_name=trial_name, + experiment_name=experiment_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return trial + + @classmethod + def list( + cls, + experiment_name=None, + trial_component_name=None, + created_before=None, + created_after=None, + sort_by=None, + sort_order=None, + sagemaker_session=None, + ): + """List all trials matching the specified criteria. + + Args: + experiment_name (str): Name of the experiment. If specified, only trials in + the experiment will be returned (default: None). + trial_component_name (str): Name of the trial component. If specified, only + trials with this trial component name will be returned (default: None). + created_before (datetime.datetime): Return trials created before this instant + (default: None). + created_after (datetime.datetime): Return trials created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + Returns: + collections.Iterator[experiments._api_types.TrialSummary]: An iterator over trials + matching the specified criteria. + """ + return super(_Trial, cls)._list( + "list_trials", + _api_types.TrialSummary.from_boto, + "TrialSummaries", + experiment_name=experiment_name, + trial_component_name=trial_component_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by, + sort_order=sort_order, + sagemaker_session=sagemaker_session, + ) + + def add_trial_component(self, trial_component): + """Add the specified trial component to this trial. + + A trial component may belong to many trials and a trial may have many trial components. + + Args: + trial_component (str or _TrialComponent): The trial component to add. + Can be one of a _TrialComponent instance, or a string containing + the name of the trial component to add. 
+        """
+        if isinstance(trial_component, _TrialComponent):
+            trial_component_name = trial_component.trial_component_name
+        elif isinstance(trial_component, str):
+            trial_component_name = trial_component
+        else:
+            raise TypeError(
+                "Unsupported type of trial component {}. "
+                "It has to be one type of _TrialComponent or str".format(trial_component)
+            )
+        self.sagemaker_session.sagemaker_client.associate_trial_component(
+            TrialName=self.trial_name, TrialComponentName=trial_component_name
+        )
+
+    def remove_trial_component(self, trial_component):
+        """Remove the specified trial component from this trial.
+
+        Args:
+            trial_component (str or _TrialComponent): The trial component to remove.
+                Can be one of a _TrialComponent instance, or a string containing
+                the name of the trial component to remove.
+        """
+        if isinstance(trial_component, _TrialComponent):
+            trial_component_name = trial_component.trial_component_name
+        elif isinstance(trial_component, str):
+            trial_component_name = trial_component
+        else:
+            raise TypeError(
+                "Unsupported type of trial component {}. "
+                "It has to be one type of _TrialComponent or str".format(trial_component)
+            )
+        self.sagemaker_session.sagemaker_client.disassociate_trial_component(
+            TrialName=self.trial_name, TrialComponentName=trial_component_name
+        )
+
+    def list_trial_components(
+        self,
+        created_before=None,
+        created_after=None,
+        sort_by=None,
+        sort_order=None,
+        max_results=None,
+        next_token=None,
+    ):
+        """List trial components in this trial matching the specified criteria.
+
+        Args:
+            created_before (datetime.datetime): Return trials created before this instant
+                (default: None).
+            created_after (datetime.datetime): Return trials created after this instant
+                (default: None).
+            sort_by (str): Which property to sort results by. One of 'Name',
+                'CreationTime' (default: None).
+            sort_order (str): One of 'Ascending', or 'Descending' (default: None).
+            max_results (int): maximum number of trial components to retrieve (default: None).
+            next_token (str): token for next page of results (default: None).
+
+        Returns:
+            collections.Iterator[experiments._api_types.TrialComponentSummary]: An iterator over
+                trial components matching the criteria.
+        """
+        return _TrialComponent.list(
+            trial_name=self.trial_name,
+            created_before=created_before,
+            created_after=created_after,
+            sort_by=sort_by,
+            sort_order=sort_order,
+            max_results=max_results,
+            next_token=next_token,
+            sagemaker_session=self.sagemaker_session,
+        )
+
+    @classmethod
+    def _load_or_create(
+        cls, experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None
+    ):
+        """Load a trial by name and create a new one if it does not exist.
+
+        Args:
+            experiment_name: (str): Name of the experiment to create this trial in.
+            trial_name: (str): Name of the Trial.
+            display_name (str): Name of the trial that will appear in UI,
+                such as SageMaker Studio (default: None). This is used only when the given
+                `trial_name` does not exist and a new trial has to be created.
+            tags (List[dict]): A list of tags to associate with the trial (default: None).
+                This is used only when the given `trial_name` does not exist and
+                a new trial has to be created.
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+ + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + sagemaker_client = sagemaker_session.sagemaker_client + try: + trial = _Trial.load(trial_name, sagemaker_session) + if trial.experiment_name != experiment_name: # pylint: disable=no-member + raise ValueError( + "The given experiment_name {} ".format(experiment_name) + + "does not match that in the loaded trial {}".format( + trial.experiment_name # pylint: disable=no-member + ) + ) + except sagemaker_client.exceptions.ResourceNotFound: + trial = _Trial.create( + experiment_name=experiment_name, + trial_name=trial_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return trial diff --git a/src/sagemaker/experiments/trial_component.py b/src/sagemaker/experiments/trial_component.py new file mode 100644 index 0000000000..e5701b2119 --- /dev/null +++ b/src/sagemaker/experiments/trial_component.py @@ -0,0 +1,341 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the TrialComponent class.""" +from __future__ import absolute_import + +import time + +from sagemaker.apiutils import _base_types +from sagemaker.experiments import _api_types +from sagemaker.experiments._api_types import TrialComponentSearchResult + + +class _TrialComponent(_base_types.Record): + """This class represents a SageMaker trial component object. + + A trial component is a stage in a trial. + Trial components are created automatically within the SageMaker runtime and + may not be created directly. To automatically associate trial components with + a trial and experiment, supply an experiment config when creating a job. + For example: https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html + + Attributes: + trial_component_name (str): The name of the trial component. Generated by SageMaker + from the name of the source job with a suffix specific to the type of source job. + trial_component_arn (str): The ARN of the trial component. + display_name (str): The name of the trial component that will appear in UI, + such as SageMaker Studio. + source (TrialComponentSource): A TrialComponentSource object with a source_arn attribute. + status (str): Status of the source job. + start_time (datetime): When the source job started. + end_time (datetime): When the source job ended. + creation_time (datetime): When the source job was created. + created_by (obj): Contextual info on which account created the trial component. + last_modified_time (datetime): When the trial component was last modified. + last_modified_by (obj): Contextual info on which account last modified the trial component. + parameters (dict): Dictionary of parameters to the source job. + input_artifacts (dict): Dictionary of input artifacts. + output_artifacts (dict): Dictionary of output artifacts. + metrics (obj): Aggregated metrics for the job. + parameters_to_remove (list): The hyperparameters to remove from the component. + input_artifacts_to_remove (list): The input artifacts to remove from the component. 
+        output_artifacts_to_remove (list): The output artifacts to remove from the component.
+        tags (List[Dict[str, str]]): A list of tags to associate with the trial component.
+    """
+
+    trial_component_name = None
+    trial_component_arn = None
+    display_name = None
+    source = None
+    status = None
+    start_time = None
+    end_time = None
+    creation_time = None
+    created_by = None
+    last_modified_time = None
+    last_modified_by = None
+    parameters = None
+    input_artifacts = None
+    output_artifacts = None
+    metrics = None
+    parameters_to_remove = None
+    input_artifacts_to_remove = None
+    output_artifacts_to_remove = None
+    tags = None
+
+    _boto_load_method = "describe_trial_component"
+    _boto_create_method = "create_trial_component"
+    _boto_update_method = "update_trial_component"
+    _boto_delete_method = "delete_trial_component"
+
+    _custom_boto_types = {
+        "source": (_api_types.TrialComponentSource, False),
+        "status": (_api_types.TrialComponentStatus, False),
+        "parameters": (_api_types.TrialComponentParameters, False),
+        "input_artifacts": (_api_types.TrialComponentArtifact, True),
+        "output_artifacts": (_api_types.TrialComponentArtifact, True),
+        "metrics": (_api_types.TrialComponentMetricSummary, True),
+    }
+
+    _boto_update_members = [
+        "trial_component_name",
+        "display_name",
+        "status",
+        "start_time",
+        "end_time",
+        "parameters",
+        "input_artifacts",
+        "output_artifacts",
+        "parameters_to_remove",
+        "input_artifacts_to_remove",
+        "output_artifacts_to_remove",
+    ]
+    _boto_delete_members = ["trial_component_name"]
+
+    def __init__(self, sagemaker_session=None, **kwargs):
+        """Init for _TrialComponent"""
+        super().__init__(sagemaker_session, **kwargs)
+        self.parameters = self.parameters or {}
+        self.input_artifacts = self.input_artifacts or {}
+        self.output_artifacts = self.output_artifacts or {}
+
+    @classmethod
+    def _boto_ignore(cls):
+        """Response fields to ignore by default."""
+        return super(_TrialComponent, cls)._boto_ignore() + ["CreatedBy"]
+
+    def save(self):
+        """Save the state of this TrialComponent to SageMaker."""
+        return self._invoke_api(self._boto_update_method, self._boto_update_members)
+
+    def delete(self, force_disassociate=False):
+        """Delete this TrialComponent from SageMaker.
+
+        Args:
+            force_disassociate (boolean): Indicates whether to force disassociate the
+                trial component from the trials before deletion (default: False).
+                If set to true, force disassociate the trial component from associated trials
+                first, then delete the trial component.
+                If it's not set or set to false, it will delete the trial component directly
+                without disassociation.
+
+        Returns:
+            dict: Delete trial component response.
+ """ + if force_disassociate: + next_token = None + + while True: + if next_token: + list_trials_response = self.sagemaker_session.sagemaker_client.list_trials( + TrialComponentName=self.trial_component_name, NextToken=next_token + ) + else: + list_trials_response = self.sagemaker_session.sagemaker_client.list_trials( + TrialComponentName=self.trial_component_name + ) + + # Disassociate the trials and trial components + for per_trial in list_trials_response["TrialSummaries"]: + # to prevent DisassociateTrialComponent throttling + time.sleep(1.2) + self.sagemaker_session.sagemaker_client.disassociate_trial_component( + TrialName=per_trial["TrialName"], + TrialComponentName=self.trial_component_name, + ) + + if "NextToken" in list_trials_response: + next_token = list_trials_response["NextToken"] + else: + break + + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, trial_component_name, sagemaker_session=None): + """Load an existing trial component and return an `_TrialComponent` object representing it. + + Args: + trial_component_name (str): Name of the trial component + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object + """ + trial_component = cls._construct( + cls._boto_load_method, + trial_component_name=trial_component_name, + sagemaker_session=sagemaker_session, + ) + return trial_component + + @classmethod + def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None): + """Create a trial component and return a `_TrialComponent` object representing it. + + Args: + trial_component_name (str): The name of the trial component. + display_name (str): Display name of the trial component used by Studio (default: None). + tags (List[Dict[str, str]]): Tags to add to the trial component (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object. + """ + return super(_TrialComponent, cls)._construct( + cls._boto_create_method, + trial_component_name=trial_component_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def list( + cls, + source_arn=None, + created_before=None, + created_after=None, + sort_by=None, + sort_order=None, + sagemaker_session=None, + trial_name=None, + experiment_name=None, + max_results=None, + next_token=None, + ): + """Return a list of trial component summaries. + + Args: + source_arn (str): A SageMaker Training or Processing Job ARN (default: None). + created_before (datetime.datetime): Return trial components created before this instant + (default: None). + created_after (datetime.datetime): Return trial components created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). 
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+            trial_name (str): If provided, only trial components related to the trial are returned
+                (default: None).
+            experiment_name (str): If provided, only trial components related to the experiment are
+                returned (default: None).
+            max_results (int): maximum number of trial components to retrieve (default: None).
+            next_token (str): token for next page of results (default: None).
+        Returns:
+            collections.Iterator[experiments._api_types.TrialComponentSummary]: An iterator
+                over `TrialComponentSummary` objects.
+        """
+        return super(_TrialComponent, cls)._list(
+            "list_trial_components",
+            _api_types.TrialComponentSummary.from_boto,
+            "TrialComponentSummaries",
+            source_arn=source_arn,
+            created_before=created_before,
+            created_after=created_after,
+            sort_by=sort_by,
+            sort_order=sort_order,
+            sagemaker_session=sagemaker_session,
+            trial_name=trial_name,
+            experiment_name=experiment_name,
+            max_results=max_results,
+            next_token=next_token,
+        )
+
+    @classmethod
+    def search(
+        cls,
+        search_expression=None,
+        sort_by=None,
+        sort_order=None,
+        max_results=None,
+        sagemaker_session=None,
+    ):
+        """Search Experiment Trial Components.
+
+        Returns SearchResults in the account matching the search criteria.
+
+        Args:
+            search_expression: (SearchExpression): A Boolean conditional statement (default: None).
+                Resource objects must satisfy this condition to be included in search results.
+                You must provide at least one subexpression, filter, or nested filter.
+            sort_by (str): The name of the resource property used to sort the SearchResults
+                (default: None).
+            sort_order (str): How SearchResults are ordered. Valid values are Ascending or
+                Descending (default: None).
+            max_results (int): The maximum number of results to return in a SearchResponse
+                (default: None).
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+
+        Returns:
+            collections.Iterator[SearchResult]: An iterator over search results matching the
+                search criteria.
+        """
+        return super(_TrialComponent, cls)._search(
+            search_resource="ExperimentTrialComponent",
+            search_item_factory=TrialComponentSearchResult.from_boto,
+            search_expression=None if search_expression is None else search_expression.to_boto(),
+            sort_by=sort_by,
+            sort_order=sort_order,
+            max_results=max_results,
+            sagemaker_session=sagemaker_session,
+        )
+
+    @classmethod
+    def _load_or_create(
+        cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None
+    ):
+        """Load a trial component by name and create a new one if it does not exist.
+
+        Args:
+            trial_component_name (str): The name of the trial component.
+            display_name (str): Display name of the trial component used by Studio (default: None).
+                This is used only when the given `trial_component_name` does not
+                exist and a new trial component has to be created.
+            tags (List[Dict[str, str]]): Tags to add to the trial component (default: None).
+                This is used only when the given `trial_component_name` does not
+                exist and a new trial component has to be created.
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+
+        Returns:
+            experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
+            bool: A boolean variable indicating whether the trial component already exists.
+        """
+        sagemaker_client = sagemaker_session.sagemaker_client
+        is_existed = False
+        try:
+            run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
+            is_existed = True
+        except sagemaker_client.exceptions.ResourceNotFound:
+            run_tc = _TrialComponent.create(
+                trial_component_name=trial_component_name,
+                display_name=display_name,
+                tags=tags,
+                sagemaker_session=sagemaker_session,
+            )
+        return run_tc, is_existed
diff --git a/src/sagemaker/lineage/_utils.py b/src/sagemaker/lineage/_utils.py
index 28732b0174..7c833a468e 100644
--- a/src/sagemaker/lineage/_utils.py
+++ b/src/sagemaker/lineage/_utils.py
@@ -12,7 +12,6 @@
 # language governing permissions and limitations under the License.
 """SageMaker lineage utility methods."""
 from __future__ import absolute_import
-from importlib import import_module
 
 from sagemaker.lineage import association
 
@@ -38,22 +37,6 @@ def _disassociate(source_arn=None, destination_arn=None, sagemaker_session=None)
         curr_association.delete()
 
 
-def get_module(module_name):
-    """Import a module.
-
-    Args:
-        module_name (str): name of the module to import.
-
-    Returns:
-        [obj]: The imported module.
-        Raises exceptions when the module name is not found
-    """
-    try:
-        return import_module(module_name)
-    except ImportError:
-        raise Exception("Cannot import module {}, please try again.".format(module_name))
-
-
 def get_resource_name_from_arn(arn):
     """Extract the resource name from an ARN string.
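Note on the pattern above: the `_load_or_create` helpers added to `_Trial` and `_TrialComponent` follow the same describe-then-create flow, loading the entity by name and falling back to creating it only when SageMaker raises `ResourceNotFound`. The minimal sketch below restates that flow against the raw boto3 SageMaker client; it is illustrative only, the function name and client setup are assumptions rather than part of this patch, and it needs AWS credentials with access to the Experiments APIs.

import boto3

# Illustrative sketch only: the SDK classes above go through sagemaker_session.sagemaker_client,
# not a standalone boto3 client.
sagemaker_client = boto3.client("sagemaker")


def load_or_create_trial(experiment_name, trial_name):
    """Return the trial named trial_name, creating it under experiment_name if it does not exist."""
    try:
        # Load path: DescribeTrial succeeds when the trial already exists.
        return sagemaker_client.describe_trial(TrialName=trial_name)
    except sagemaker_client.exceptions.ResourceNotFound:
        # Create path: taken only when the trial is missing.
        sagemaker_client.create_trial(TrialName=trial_name, ExperimentName=experiment_name)
        return sagemaker_client.describe_trial(TrialName=trial_name)

On top of this flow, `_Trial._load_or_create` additionally validates that a pre-existing trial belongs to the requested experiment, and `_TrialComponent._load_or_create` returns an `is_existed` flag so callers can tell whether the trial component was newly created.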
diff --git a/src/sagemaker/lineage/artifact.py b/src/sagemaker/lineage/artifact.py index 3921562beb..718344095a 100644 --- a/src/sagemaker/lineage/artifact.py +++ b/src/sagemaker/lineage/artifact.py @@ -29,8 +29,9 @@ LineageEntityEnum, LineageQueryDirectionEnum, ) -from sagemaker.lineage._utils import get_module, _disassociate, get_resource_name_from_arn +from sagemaker.lineage._utils import _disassociate, get_resource_name_from_arn from sagemaker.lineage.association import Association +from sagemaker.utils import get_module LOGGER = logging.getLogger("sagemaker") diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 01d4361197..af52da6288 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -33,7 +33,12 @@ from sagemaker.job import _Job from sagemaker.local import LocalSession from sagemaker.network import NetworkConfig -from sagemaker.utils import base_name_from_image, get_config_value, name_from_base +from sagemaker.utils import ( + base_name_from_image, + get_config_value, + name_from_base, + check_and_get_run_experiment_config, +) from sagemaker.session import Session from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.functions import Join @@ -203,6 +208,7 @@ def run( outputs=outputs, ) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_job = ProcessingJob.start_new( processor=self, inputs=normalized_inputs, @@ -605,6 +611,7 @@ def run( kms_key=kms_key, ) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_job = ProcessingJob.start_new( processor=self, inputs=normalized_inputs, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 72df570496..ce6a3b99cd 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -89,6 +89,7 @@ def __init__( sagemaker_featurestore_runtime_client=None, default_bucket=None, settings=SessionSettings(), + sagemaker_metrics_client=None, ): """Initialize a SageMaker ``Session``. @@ -116,6 +117,10 @@ def __init__( Example: "sagemaker-my-custom-bucket". settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional parameters to apply to the session. + sagemaker_metrics_client (boto3.SageMakerMetrics.Client): + Client which makes SageMaker Metrics related calls to Amazon SageMaker + (default: None). If not provided, one will be created using + this instance's ``boto_session``. """ self._default_bucket = None self._default_bucket_name_override = default_bucket @@ -130,6 +135,7 @@ def __init__( sagemaker_client=sagemaker_client, sagemaker_runtime_client=sagemaker_runtime_client, sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client, + sagemaker_metrics_client=sagemaker_metrics_client, ) def _initialize( @@ -138,6 +144,7 @@ def _initialize( sagemaker_client, sagemaker_runtime_client, sagemaker_featurestore_runtime_client, + sagemaker_metrics_client, ): """Initialize this SageMaker Session. @@ -172,6 +179,12 @@ def _initialize( "sagemaker-featurestore-runtime" ) + if sagemaker_metrics_client: + self.sagemaker_metrics_client = sagemaker_metrics_client + else: + self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics") + prepend_user_agent(self.sagemaker_metrics_client) + self.local_mode = False @property @@ -548,8 +561,8 @@ def train( # noqa: C901 checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). experiment_config (dict[str, str]): Experiment management configuration. 
- Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -558,6 +571,7 @@ def train( # noqa: C901 * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. enable_sagemaker_metrics (bool): enable SageMaker Metrics Time Series. For more information see: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries @@ -703,8 +717,8 @@ def _get_train_request( # noqa: C901 checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -713,6 +727,7 @@ def _get_train_request( # noqa: C901 * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. enable_sagemaker_metrics (bool): enable SageMaker Metrics Time Series. For more information see: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 97278abdd0..40ed143ebc 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -27,7 +27,11 @@ from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.execution_variables import ExecutionVariables -from sagemaker.utils import base_name_from_image, name_from_base +from sagemaker.utils import ( + base_name_from_image, + name_from_base, + check_and_get_run_experiment_config, +) class Transformer(object): @@ -251,6 +255,7 @@ def transform( ) self._reset_output_path = True + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_transform_job = _TransformJob.start_new( self, data, diff --git a/src/sagemaker/utilities/search_expression.py b/src/sagemaker/utilities/search_expression.py new file mode 100644 index 0000000000..5b2aaf3226 --- /dev/null +++ b/src/sagemaker/utilities/search_expression.py @@ -0,0 +1,133 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Simplify Search Expression by providing a simplified DSL"""
+from __future__ import absolute_import
+
+from enum import Enum, unique
+
+from sagemaker.apiutils._base_types import ApiObject
+
+
+# TODO: we should update the lineage to use search expressions
+# defined here in a separate change
+@unique
+class Operator(Enum):
+    """Search operators"""
+
+    EQUALS = "Equals"
+    NOT_EQUALS = "NotEquals"
+    GREATER_THAN = "GreaterThan"
+    GREATER_THAN_OR_EQUAL = "GreaterThanOrEqualTo"
+    LESS_THAN = "LessThan"
+    LESS_THAN_OR_EQUAL = "LessThanOrEqualTo"
+    CONTAINS = "Contains"
+    EXISTS = "Exists"
+    NOT_EXISTS = "NotExists"
+
+
+@unique
+class BooleanOperator(Enum):
+    """Boolean search operation enum"""
+
+    AND = "And"
+    OR = "Or"
+
+
+class SearchObject(ApiObject):
+    """Search Object"""
+
+    def to_boto(self):
+        """Convert a search object to boto"""
+        return ApiObject.to_boto(self)
+
+
+class Filter(SearchObject):
+    """A Python class representing a Search Filter object."""
+
+    name = None
+    operator = None
+    value = None
+
+    def __init__(self, name, operator=None, value=None, **kwargs):
+        """Construct a Filter object
+
+        Args:
+            name (str): filter field name
+            operator (Operator): one of the Operator enum values
+            value (str): value of the field
+        """
+        super().__init__(**kwargs)
+        self.name = name
+        self.operator = None if operator is None else operator.value
+        self.value = value
+
+
+class NestedFilter(SearchObject):
+    """A Python class representing a Nested Filter object."""
+
+    nested_property_name = None
+    filters = None
+
+    def __init__(self, property_name, filters, **kwargs):
+        """Construct a Nested Filter object
+
+        Args:
+            property_name (str): nested property name
+            filters (List[Filter]): list of Filter objects
+        """
+        super().__init__(**kwargs)
+        self.nested_property_name = property_name
+        self.filters = list(map(lambda x: x.to_boto(), filters))
+
+
+class SearchExpression(SearchObject):
+    """A Python class representation of a Search Expression object.
+ + A sample search expression defined in here: + https://boto3.amazonaws.com/v1/documentation/api/1.12.8/reference/services/sagemaker.html#SageMaker.Client.search + """ + + filters = None + nested_filters = None + operator = None + sub_expressions = None + + def __init__( + self, + filters=None, + nested_filters=None, + sub_expressions=None, + boolean_operator=BooleanOperator.AND, + **kwargs + ): + """Construct a Search Expression object + + Args: + filters (List[Filter]): list of Filter objects + nested_filters (List[NestedFilter]): list of Nested Filters objects + sub_expressions (List[SearchExpression]): list of Search Expression objects + boolean_operator (BooleanOperator): one of the boolean operator enums + """ + super().__init__(**kwargs) + if filters is None and nested_filters is None and sub_expressions is None: + raise ValueError( + "You must specify at least one subexpression, filter, or nested filter" + ) + self.filters = None if filters is None else list(map(lambda x: x.to_boto(), filters)) + self.nested_filters = ( + None if nested_filters is None else list(map(lambda x: x.to_boto(), nested_filters)) + ) + self.sub_expressions = ( + None if sub_expressions is None else list(map(lambda x: x.to_boto(), sub_expressions)) + ) + self.operator = boolean_operator.value diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index e668b2a8ed..9d28e3bf4e 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -29,6 +29,7 @@ from datetime import datetime from typing import Optional +from importlib import import_module import botocore from six.moves.urllib import parse @@ -590,6 +591,27 @@ def retries( ) +def retry_with_backoff(callable_func, num_attempts=8): + """Retry with backoff until maximum attempts are reached + + Args: + callable_func (callable): The callable function to retry. + num_attempts (int): The maximum number of attempts to retry. + """ + if num_attempts < 1: + raise ValueError( + "The num_attempts must be >= 1, but the given value is {}.".format(num_attempts) + ) + for i in range(num_attempts): + try: + return callable_func() + except Exception as ex: # pylint: disable=broad-except + if i == num_attempts - 1: + raise ex + logger.error("Retrying in attempt %s, due to %s", (i + 1), str(ex)) + time.sleep(2**i) + + def _botocore_resolver(): """Get the DNS suffix for the given region. @@ -874,3 +896,47 @@ def _start_waiting(waiting_time: int): print(progress, end="\r") time.sleep(interval) print(len(progress) * " ", end="\r") + + +def get_module(module_name): + """Import a module. + + Args: + module_name (str): name of the module to import. + + Returns: + object: The imported module. + + Raises: + Exception: when the module name is not found + """ + try: + return import_module(module_name) + except ImportError: + raise Exception("Cannot import module {}, please try again.".format(module_name)) + + +def check_and_get_run_experiment_config(experiment_config: Optional[dict] = None) -> dict: + """Check user input experiment_config or get it from the current Run object if exists. + + Args: + experiment_config (dict): The experiment_config supplied by the user. + + Returns: + dict: Return the user supplied experiment_config if it is not None. + Otherwise fetch the experiment_config from the current Run object if exists. 
+ """ + from sagemaker.experiments._run_context import _RunContext + + run_obj = _RunContext.get_current_run() + if experiment_config: + if run_obj: + logger.warning( + "The function is invoked within an Experiment Run context " + "but another experiment_config (%s) was supplied, so " + "ignoring the experiment_config fetched from the Run object.", + experiment_config, + ) + return experiment_config + + return run_obj.experiment_config if run_obj else None diff --git a/tests/data/experiment/inference.py b/tests/data/experiment/inference.py new file mode 100644 index 0000000000..cdb9a7b8c6 --- /dev/null +++ b/tests/data/experiment/inference.py @@ -0,0 +1,85 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +import logging +import os +import pickle as pkl + +import boto3 +import numpy as np +import sagemaker_xgboost_container.encoder as xgb_encoders + +sdk_name = "sagemaker-dev-1.0.tar.gz" +code_dir = "/opt/ml/code" + +sdk_file = f"{code_dir}/{sdk_name}" +os.system(f"pip install {sdk_file}") + +from sagemaker.session import Session +from sagemaker.experiments import load_run + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + + +def model_fn(model_dir): + """ + Deserialize and return fitted model. + """ + with load_run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameters({"p3": 3.0, "p4": 4.0}) + run.log_metric("test-job-load-log-metric", 0.1) + + model_file = "xgboost-model" + booster = pkl.load(open(os.path.join(model_dir, model_file), "rb")) + return booster + + +def input_fn(request_body, request_content_type): + """ + The SageMaker XGBoost model server receives the request data body and the content type, + and invokes the `input_fn`. + Return a DMatrix (an object that can be passed to predict_fn). + """ + if request_content_type == "text/libsvm": + return xgb_encoders.libsvm_to_dmatrix(request_body) + else: + raise ValueError("Content type {} is not supported.".format(request_content_type)) + + +def predict_fn(input_data, model): + """ + SageMaker XGBoost model server invokes `predict_fn` on the return value of `input_fn`. + Return a two-dimensional NumPy array where the first columns are predictions + and the remaining columns are the feature contributions (SHAP values) for that prediction. + """ + prediction = model.predict(input_data) + feature_contribs = model.predict(input_data, pred_contribs=True, validate_features=False) + output = np.hstack((prediction[:, np.newaxis], feature_contribs)) + return output + + +def output_fn(predictions, content_type): + """ + After invoking predict_fn, the model server invokes `output_fn`. 
+ """ + if content_type == "text/csv" or content_type == "application/json": + return ",".join(str(x) for x in predictions[0]) + else: + raise ValueError("Content type {} is not supported.".format(content_type)) diff --git a/tests/data/experiment/process_job_script_for_run_clz.py b/tests/data/experiment/process_job_script_for_run_clz.py new file mode 100644 index 0000000000..32fd0ab4f6 --- /dev/null +++ b/tests/data/experiment/process_job_script_for_run_clz.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This script file runs on SageMaker processing job""" +from __future__ import absolute_import + +import logging +import os +import boto3 + +sdk_file = "sagemaker-dev-1.0.tar.gz" +os.system(f"pip install {sdk_file}") + + +from sagemaker import Session +from sagemaker.experiments import load_run + + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + + +with load_run(sagemaker_session=sagemaker_session) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameters({"p3": 3.0, "p4": 4.0}) + run.log_metric("test-job-load-log-metric", 0.1) diff --git a/tests/data/experiment/train_job_script_for_run_clz.py b/tests/data/experiment/train_job_script_for_run_clz.py new file mode 100644 index 0000000000..34c86e0993 --- /dev/null +++ b/tests/data/experiment/train_job_script_for_run_clz.py @@ -0,0 +1,71 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This script file runs on SageMaker training job""" +from __future__ import absolute_import + +import logging +import time +import os +import boto3 + +sdk_file = "sagemaker-dev-1.0.tar.gz" +os.system(f"pip install {sdk_file}") + +from sagemaker import Session +from sagemaker.experiments import load_run, Run + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + +if os.environ["RUN_OPERATION"] == "init": + logging.info("Initializing a Run") + with Run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameter("p1", 1.0) + run.log_parameter("p2", 2) + + for i in range(2): + run.log_metric("A", i) + for i in range(2): + run.log_metric("B", i) + for i in range(2): + run.log_metric("C", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("D", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("E", i) + time.sleep(15) + +else: + logging.info("Loading a Run") + logging.info("Invoking load_run with name arguments") + with load_run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + run.log_parameters({"p3": 3.0, "p4": 4}) + run.log_metric("test-job-load-log-metric", 0.1) + + if os.environ.get("CALL_RUN_LOAD_WITH_NO_NAME_ARGS", None) == "True": + logging.info("Invoking load_run without name arguments") + with load_run(sagemaker_session=sagemaker_session) as run: + run.log_parameters({"p5": 5.0, "p6": 6}) diff --git a/tests/data/experiment/transform_job_materials/data.csv b/tests/data/experiment/transform_job_materials/data.csv new file mode 100644 index 0000000000..9f1b6c0bb0 --- /dev/null +++ b/tests/data/experiment/transform_job_materials/data.csv @@ -0,0 +1 @@ +-99 1:3 2:0.37 3:0.29 4:0.095 5:0.249 6:0.1045 7:0.058 8:0.067 \ No newline at end of file diff --git a/tests/data/experiment/transform_job_materials/xgb_model.tar.gz b/tests/data/experiment/transform_job_materials/xgb_model.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..3969bede9e315f8f51d27f3df2de623e670459c6 GIT binary patch literal 35946 zcmV(%K;pk2iwFP!00000|Lncnj+{r6B-pdF*o%IOhOTA+W+6Pz(U-H}<=!vQXb_ZC zlGsC$`bny%dm0VwKGeM1Uap(Fd1O{G>s&nigAq$)Rvx}nei32rZf^E3zyA3C{l`y- z-{1dy`Sx$V%zsHz>b3q|^8c>>TSgtZ{+-lZavP&c|GOl))bTfem%h;PT>0UA{(t}SAO8I>|J#51 zzyA+?+i$Pm{rvXwFaPnUAOC#w_S2hpAOH5pfBkg%`oo9U|N6Io`QQJ`|M>s@!{7Yd z5C7-;cfY*(^@qRzw;$f>OYbf};Nh>A`ryq^ul{)b;q~u$;_}_=AKzZSy8M?v^!eW} z-+g-h_SHXqeE;s%NB#Zv+c*FH^`|%Q-~I6SKgbU+e)#3}o42p@wKwnnSzr43)vtg2 z`RdKj`eM=-_b&(GZ-2c0_43unH~;;T?|6p~eB8hM)B7LsOT{mre*19w`e*+3K~L}f z@2~V$_!J-gczOBr$-`HAw6EUim5HOn3wBKU5kI+0ef6Uq_rL4We0cry8$G#KKl%l= zOY!5U50{s(K7Dxo=H2mny!-Qa`{nK1%eSxa2A=vKA1;4>^V6sI`q3X=|M}^?J@x*l z%a0#lefrlw@UEXPe|i1q+fVv}ZC4NAxA!>a%YS+O6Q2F^4;&mm|LceM|LxUJ@8ACZ z`oCU&c=5-FKi^#*U)Kx&mmA)rC-P2D46oX;<6``A`O}+U-slzCoB!qY+mDwgyZg7R z+rks^`1;Ae|KroEpD*8iyu`cX+Fg6k>$iXW_L^7xr`JFIcKJ#_#fzk$`uO_yKlFb3 z@n7Fw{`2MAGta-%sqM_uCzqNR^3~68K3x99@7Hc&eT}#4FPE=B{rTbYm7dw3f4q73 zE05>$lO5Sb=CCig5nuoQ`)l4_o5}X3=l<@a zo%AQ%KlwkIm;~N}tZk@mD_I}Y_e%hbl^E$CtltGDkz##sCP{k!*{-oLxNcAo#~KV#%w{_^4S<8RpOFFfB5m%qL`{m$|G zl?!5DAuhJK0OL{?6M_ zpIJYgZ0*{J*~{v2W%~KeN4t;XR>zs(`YHP0=6|+RJg%qjfBYv5hJVJ+^eg!LkAL~| 
zwXV9CA3o^m>%RTNdT{^n!-wDXmj1^pANjXf8$bIfzfiqBjtBhlr`LLHf4u*Ayqn|) zyI%jSu)zl8Kgf&nqWq&i^?%{gT?10DKk32z_~$3wxc_zo!uUNtN2qb_E5PN&^SIp2 zkc~nmyDwk;^Ch81@)cM={N~*+dJp|Wdhx~o{G;y?7n=8pUn4xRE@~~kw)Qf;7#BI$ zEO;mvX{J?Wy_6U2qKz(N=+pBerQGZCca&LH?fMS6$WwHL;BmQFQpd=6d)KY#M!IcV zLlE0Sm7HvM>!Ryxg}WO=!;@kgYpu5|!ye0;qwCw{MZW0vA2wVsMC$DLXw!q%P0wok zE3K!+L(O^uZQ0Yhh0;g0r&k-ZRXe1i8=tB20I(B%xwbt01an?lKWN91q^wf>u~ZT1 zetcDr!%uu%w7zPsJaN5Zz50RFR9aeg#!G{>wQT1}TF=dow6>(%#y+J;#(}Jh(pwsS zATo3k67e18tnXd+mO2XGk@P4gzO8@63is*q*Y_XZ*cX1~_g7fRK7aP(<*x`7&b*-a zFQ1K3|Mv119b5nK>BFCwFP`$3gW8jar+VGB?GTTdCWPr?lO%O_u`K3v8u z(q4=g>xJm|QC>)UA>)Ot7g_Jo@*=kv2YoQ`xBukD2StB$2C_Gl@uIBbd+z$t!?F49 zCyf+;weMtJjDef!r5?dWnQg3M534?xde+DJLdWN9I>53%PHokdvpy~tC9N!T>SG-} zWvwpyP;WD&3v?i15T{&5NzL!;oM)46x7^a3a*DrY%{04ya_VRt*h{i|vHCsizsCRh zVw5~|u;@qQxpq-TAi(SV@EjCh5H~xVC^+M`= z&AVMEj!~yhf5Rq9{W@RhrNS!lqUxWff4cq|`e*8&h3`DskJwMPpK3o7Kez?-C)-cq z?SP-IKf3D=n08NZrmVZ^x@m0FH`Q%B$+l4OLEJ&}crTO`5A;Cmi(+pm^F>V$x|p7N6vGp0hm^`_J|}F+uCh$U{n1 zSAjO{hVeB;mzTL@`Ylu>4X!mg&Z5PM3J*#y#$2h{0#f3AI1qo7b)Cz(6hHo{d!NIn z?IvZ$sv)^Ff40x0W7MD4!IE_`^AVD4g!nP8lx$7Yjdwn&Th-NJ+LJ}su$hbRu~vPr zZ$lSW9Xevq>vyzznRa*k7yCxP*Pnczuv<6V-(7ELz1Qrf?s~&dA_TF&tG$-&owuV2 z--oA3o*%$q2a2SAyZrd^&1)MZ7x)%2x#_BJ|6cL{+!3FofjP6eWupLR;^ zI?(rLxcdK4h(kwxPma49Zs{h;VM`lWXn@9w&dOyt^2uFN?G6^_?0Ww? z)>|Xl2+?nIj?q`QPwJC3*;{bYFtmNVPif6Y(avm5u%zBhn_{<^ zY!nffG>4%_9qrTbu0E|BsVUpHTe_~MbJ|-JGq36MgMv!^v;1it+KFIp>6$$JR@BvY z20C9lqw*}JFTDWbOL88H;r;1s7}SBReoZ!^{Ix*B$K$#4pS=5t`{O`)3=UKfm~X~- zjt2R-#e;Fnntcrc>SCyCUZ0p6^-f;05OrZ@Q?xb*XOyh>{on*y^#;m5(djKabQQ9x zYxOzA=dGr{NY198g{8DMj|O#&W7#ACxb z01M_dKzQzdOz_fU8~V9ygCkzn*&^epPd+g*D01%oZbbM6kYBMM!+%73)qVzk6sVQ- zNX(~zuiT919QpBaj|cOXJLfFg({IsGP1i_5L6TK&1vgn5T5aPa;@~Sqg^MHY=;fFmwNUW#~! zO~ri#hU~f|Az5LFc8;~`U=*a|X!$|3+3Os{b4BCyT8Rl;ozpLGHV8QhLbE~6oADfw z+O%7=*uB*4@6T~t<8%Gw@2`dB&f+)(rYjKJYSZFW2V&iy1db?%0DMMrEaXA#;Dl9d znlsE+9W;qQSsZzF|GsSnot`!swgK^=w|D?sZJt@aL>K!dIXuhNHYT#cTDA}AU=SXq z9YVHv)#Kr|>~PxN0}?lO4wyrSb#@=d3fn4Acxbqn{Yb!|Sl#oCqHAVpK5~~v zYu(?Hs-8x5v^KJ?nVmCF(Xi65ebxc*$Wd2PjTF;#sI1_x`L^{*4+M1^eU#zPYEY_y zq8;i|Xl$q3kwnzFyMA{d?`+!ceB$pG@OHGte77(bgOXCkLp~ceQy4V$)h9=gf=F{c0^cnLPhM+8#)O)Y3$^f|=*2C1*bY&l{ zZ%o!Okea6gbvQ_^J|~x%keG8gx%B=mf$@+{FFM8a{ksmXnilgqb>(yw*X2AsenWu) z#Rk5r5&q26GvXRzr>cY3X4G1Bxz8RUX^1zUk`jC>K)f|A*zITF2LLZcitrRJK0HN; z`|SiV;x^kj@M?b0A)@C8YktK2c0N(K0=>=L95*UXip-s^#eoEUk&)wT6@Wc5c7ynH zuAKFrV~#OWjM>p8GJy8Nj7*RhK_xQgV6fY2Z|#}-&66)8FQW&NgTFjpMNddcl0tm^_Uz~`(DIn`QhNU5h)kP z%}fs1qTt)S6RPIBbByI7P*vv<(K)5~l%ExdRtHo%JbD+;e(OW2hTM>buKyk88Hr>kvMK5fd{tir35d#bb_`D#Z@H+Fb z^edU3>eX}5tKB-W{=($sSijcXJ^IzVR;M_t@pn}7yC0!mxT=#^Bwlr@8I=*ktOjCr zDpr{O9b(u=fZMwY*Y~VmKQqK7lr48d*=soVS}prpGIQftEPeV@9*Qo`0$RlS8AWEV z<)CV3!MHB2Yn9=70F5hqv%+yla%IROaYVqX@OsReC8bG3n-Kmv8qhfbF|kd)t&Z6| zhF1i-vSm9ll4}vBBPuJ-nSfMbtQ5I|UKQ#_b|-7#=rzG(PIKY_r$?VHInFjrtcUFt zdE}cES@*)t>U5|}_~Z<)QD?PF@I-w3Ma$oYaBnPv*T`#jPmbOT=UxN3*Gl2n@}nE% z;%3xM)yLh8=M$wwyA3PDGN5{`GLBO6E_Mg%UHfBmIn3-;?{AGg4YUIH} z{LF12VPDqrv_;L4zSLCHh9tDj0N)bBJa*o#tp%Q}?&NQPp>)cnx(o@{kV=RwtR zwIYs$(SrKDb4|iIUT3KEv0NR?n3wTpblT^MktHiLSv+CF^zUBe4Di z+OosST^C&83%v|oXZ$oo?^zxL_)>E&Li#x|aQ7jVl!>O1NA}f62L7!C#5Vz7jHX2Btmk5n8d2%_Puv?oe-+?-6L@Y;I7-E9@(!k1 z79S3Kk@_Ask(azH;Rru$iniIgdZKq+u|W*1Wj|K~;7@X7KD-QM!6y5Je z(d{898p-c92NSzGOL!d~2y|4^9CgKM%do31k@_ewZ&4Jvn8h4Au{lLjmp>g-gY=U; zFvU1GrYWLc1m;T@xJLL|zz=oc*>vXcXLV9wu6813mk4HYxYHLyogT{cHB_gSo*Iu9 z*)4~r$Tr2fb?`=g=L01m2dqU;QP(Fn0Tg5J_W9W(Br;)&!lb;%eM!rV=*g(;VzA;zdou<3Z8op*W#A=VZIBn1+e zCEa6>JWfj*X?tp>YUt7JSP~JRcwHNFiiW^723T*E(Kbr$;yC~XL>NhoTxwqAW0chw 
zxrbe63&~=~GjtN4{`RU-kIulpmg#DsAMkiWgY(q9kgBu)Y~8ay`pJdq>&bca?;s{A zLC9%WGVmUs=&g*q`7|5im=u%c9NGYDl$8ymk6OTwFpT;z~=7@Rv#f&gSSLs$ANjHBsUrFSXCsJg8AV=!I z8j+YFT$f2OA8eTDy0SJwSV!R)y^xLT2im=UU>VwDkQ8x`9?q@|q!j1aEJ%ysY3w3K zjr9Fhata3W3U-LDoOxzm87~&whh!uxvLAdZX~@(uKAKaZ%rTlNqecOvp)H>~Ekjb` zC0(-sF5!qYiGfr@xv&vTDO3PV@ega{ z*lN8q#F{$H`wp@CjRDiJ_xQh!pFy4CDdE-kS;cr@kcO|NSlvkMUyXQeeH34tz<0H# zk>s(PWuqo3s~-1w5n_0JC57Qr_d1D3Q`KmJ-FeDbT=(zAi92Xe@gjzA#uv%Nt`m<8 zt%SljZA5VcVH$JtzQ!7nCXp880o?I*gbvO=v=!bDtMV;8l)>$ISQ+f=)VQPKC%hxX z>{XXJcEDF`s}j_rJ3_p9XNcFv!$Ld_Ad16%d>u*0d}c6_$Kq6FNFh7g7>WZbFMt9U z=>#$8eG6*#KyE3fV6fRdv;tws!%4Il=-e{B(4lc9$TtjhhQb>U^OjD!nX%0>W{T&8 z^mSo)rBg=A!l^>RN{hI(fXiKY$eHbi2vLxv@gvA-wdK_ieMc$&FzfT#!w&QiuIfRg zPPewSSzI`W&q83|^wfEXqEl)clc&@bQAHg^;goN&>crN$IOUFZE3f(tG*D$fu9Z;W z@aOd%V`YD=5p)_n8oli3iuBM;W4F@CInTB6Eerhg?ubsK>2CcT_^7T|jUjNyewzo0 z3h4B5lZ^yrmrmxQXb3C*rGd#&JFn!_6@kU1sWX`@|1veImdbl*AZ}wRZ6zc2vOXot zA8LH5_XI=9ly$lGtkJ>8Ln;^jm9&8#zLZ^5^SFxso zqcy!TM{o2Z#UD3?c!kNutTW+asa;2mv^<#`tVlpRML^-M*w7~7y{9>n7yNa2mbTlk zd&V4Vi%qib7Mp|=e~%iSk$TD+Wm+e4Lov&vL$n>OgmYkmTqgl>KC2#T>OrouCS?mv z0q%OKAUb}v=N}FL>UA2{wZIP}=J{j4n<`pmv08N! zuoZ+{r^P6zibTjZHswp|N!?3FP9Hmdn`|Ti@+-Spv_!C1*cnfig7w8oUURVUmltc8!$T# zr&O>0Lo++DU2HiAc71eNnHfevykOS&-gm z2k=mIQI=ncG-mdUm0ncDZHvXiGIch|UX#?=xOjZ3aHJOStv76H(mk1_z*NBrHg#j% zH#Un5gWlslG_RhI#^!E-NQ@J`vGq|bw7-^^4oVOOOj*WcjNO{|xK@u}S{Li0s9t09 z#Y((YRh&eTJtkXT(?X8aSZ=nMB&UdB{c(E!tosSJVN`O$^@)*DwOwaMoP0 z(1~T2>XkBX+CxTg@9<$i%wVCtU`&YLEt;&yS{wJHN9Kng^fjlv0pq@mrj6oJrFYNT zPWi02!j*h-YW>cTu0C@}hnA9@MI8FitSEwQ-5RzmC;DR|Ry0oTmY`8Qm~9I+Cx`xd zD6sF2A9Nm_4(m}0$B2xAQdR)uAq3POl81>Mycq@r=vDIZs-ERv5(PmAxlJ(D=ehS1(j@^%{LQ$b{2d zoOLE&Gh$(lEmtrIHCSv+ch!o%eTR;iZW!_Dx|;pnRtR0JD3n(CCF%)?GagY`&9(^H zK9jJY45^u%dEN9R8>@91e7EJNmh_q+-Sbxyf;~E_dMK&xtHooCuliGoc4Zd}PNJ;S z5bMiCQ^V>(!}SxIh6CbVY&xnm^dW_~>q?q@kw6B5vD7FYWQcXi_w3QuKqFUSGD$$r zWeeRh3X^TSD!eFzEqLge_eHHa8!xlVSD+KHe5*Rd6?Fv14!O-X*7PY|A(oB}*{{t~ zz;7q3FDumZK9-JNOXh^yh@O@KGJJ79Ad>IdKL{qVYg2)|U340b0PL?PHo74ydSDgo zIe2@~gmM}Wod!c+GVbYPw`_!djp>5y@MHIXpffL|Y*1q5({&9pys@gR9L12m-6(#Q zEgl1iyaMa5U0DLN!<&~;NZWnV($O**_+L@Aq@h0d6=J@fnHRKqj`=SFfU3Y;X@T?@4;-{qR^K4O$()pnF#PYVH z8jUq-v{GxhP+NRRQ4V@34=!L)7K0hl5M<1#7v`B;V#ow7k;G3o4Uu-6jRrQA(T}dg zb1Yu>JkoGlW3$Xkk2pt>ZteWKruDx@B&f51_oy|bLV1}7f?(OzzM^x19r!KnGG4H% zJ0Nkbo$~>SGXkfLUmwlwmO@v$^_A}}kU1B;kS&Ocr?Ccew>U+-d77-jnWts+JcZSr ztkAqwX1w$p$YX#}P5r_e$^nT;1eYC5utBCom?G@L$sRWpHhMctuz9C9JGoce3_~w9 zc}vgaV-Hh$V@x4V6?8b8PXi2+3w>pyAWS_H3B}89{entuGIq!&ekjGOagAKV=U^7L z2WbGIRi9x4D4jYohKCRZ%0Ii_A%iN0x*S5 zvS(te+|t{{)q%3{Yf=Mtfv#fn z4=cJ~Lgf?DBO}!h6N-2Qpm1Oc`|YOA`sJ!cLmNi+`0^Qj>bKd!9slNW#$)+U~jN4*~ zgWRn>^ttHdV|#UOi5zrY%&s^zF&s)}g++sUW0WC$&S)L?EKVP?dtGLz;fWU{${6hB z!Mwq8eMdReCq!r}=;-sZy?!(=PVfoh3#{4;CIV6d5Lqv_iYtb&q<}Zs=w20*aMl@u zQn?*=ZyMn*zHMG0IFIU){u$iv1ZWy|*u}~y<@tkxhXex8R3JGY3fuww4|2DZ(dT-V zkL}aB1Lk8@3Or5L_f{BM*J=_6%)yG}1HZ1_ZlJwrAp7N+f=*UYF6~h~0dYUYIZi3g zWb5E6%;4C8{aBC0rSmBRH@sP< zw9Sty7hCn0t(Fc+`7H6ZW8CI#l!MwEn~FtFRbh9Gb)M^NxVcoxOAEzI-wx)jJ(JWO z&-0UPZP5Fr%WK;Aex!g=NK=S^9pj55H3%t>bfz$Rn4S47KHCiHYZ(2 zyzyj#3YklE{wXMP&XS;m0)=~fe_i5zq#F1@#o30{!J755M3~=2SE_((XgKc#xmn?B zpuv`!s}tcnC_PO_t-6kmTuFR`Twvmu#zH;2hc0NXcd;VOk_wRrk5t#hLT*zA{gv2I z7LYuf+R+_IsM??!MX6kUOY4~(HeD#9;O?jD*f)dy?V{mp4Tl5X7*&pJpU{5j^}}S}^M*kq)-1o)yJL z6+{8Ie{1ahSSbFa9-HrgL3JZepO4XrI*ZfKCj8qovqC;=y4S&K0-v5#fzYk57?7Dp zL$A%df(}5}BeoXao-N_8f&Ui!%*kj=Hd)sK5qYBGVR39{Z^{6f-j|R9&2SfZLLqcz zDoxF{lDu>p1<<*Fu(l!X7(2A0Q%v#&1C-FQNYT=5P6iKp9I1Lj2C&Pj2tY(Sb5h+j zLmYgu>d!+sb^NDdMHQgnQ8)`YgU~N4bBz>+Jj}Q z9#I)+R}=VX>j*G!>ZpOpeN{@krzMy{H6%_CxJhRM|Jp;2;K4+D|2Wmgsu!P^9^wd= 
z55RtO1ctmzHCz@d?zE^G1qj6x%I0Ua%`cb6H@3=O<;wuVr&YO!SoEP;Rgx`ht z!t8;5>Ai(F!jz3NHf^|fm{IhjvRy+8V4~4h&8ZoLI?#mhTJ}IzhCjgeUk<;zo35Kf^48{1(=LaxKPAw> z@L(~i)`mxRPDsv6eWHKxnf0UWgb&*Rf0XCovG@U1)_caWukr=T?O_~ygS^_cpsjSk zzOr@3kzMs^Ft#P1^rEbYoZ*S1YuL&j7_+64t=kf^F<{tfX|a?TaN&6?w$}yLzWp9v z&8gZkCA(Lp)HEspY>`c_snlAHbsz_$fD@L&>6H+YiPfPef zJ@ls_=q?+?zRDK}w1+|L4T^0KWgVD)pnpyPU3GnvHW=9F@QOP;WTn7`LNuG`Oc>T@ zA&Nf3qbdQNv;DLkSq?#hc$20V3d_h2vrFZG;vESIOX*ej!CgoJBpCP`AK=XjM@u%} zXXw~&W$Q2+6mV-QT@$y|J@MBI1y=ImWLr5o(^aPci@sCVfKNRnV;aO+oNFO>3u>-B z5irR1RHMLn-7a4Wm9(960P#ZilT6UA8=?3)mPv8-&j3>DS8QNAZD3KokH~}!8gQ@$ z7NK~T{`oLEvX9mL*}Ch~EncZ1k7U>315UNKpP+)!d*l<~zL^;9G@~fYs7K%pSk{xg zO0qx;*6ffx#AH&GZcAvMVh2P`N6Tbs)Sh5 zmuw=+-hEmpW7C4BJcs#E*3^Svw!xH@SA8(o`fRQVDZ8~fLSHVR=Hb4qQZSa7c1bPM zYa;o$0}ty=^uE%}q<0p1FRVsY{Db6M$+=BwWfxYFoeu%89|ElZK;7TwdOh=N1902?I8H4~QI}H+}T)-8&!csB3c1IIP%s`Yp+u*dj zxe3$-w&x01BCzx`(vGD4qjuAP?Yj+1fXL48Tgj9!ltx^I+TYx7c#@6$t>M3;HaKu` zBhYkmGLY}0I)}XBy@8;oZ#Y12qLY&H_&7dg@h!CQkAH>0BiSN2b4Md4~QH> zpp_41d%r=xsnu;RDhLq)nM&}S$knT*%2Kp}9gtg%=|I_OO>dVhgmf1ssss28H_*&m z?Xxez;m!W4(r$e*u2xl_xM;He-ZAt6T?i-@R`wZ^M#Uxhh(YF*I!8@G9VhK&zfjHg z9Fu4lHMKh}lWb4?Q>E0_+}U9eDCNg?y?{KWQ0JJYfjBYP9bZbzVH2&* z(H-=*B@IvbQT&oVPaeF;jEgKBFBMB^9cK}#d2FptY+V*38tjM?d1puLX*-xh~w;FkU_*EK)vnHbQ}>A+`pxIkOi!T>?a(+YAA#Oh}s^jTr#@XrC?!HC`S$ zbSZ6;ksQpU)RI2{U z+4K}PpFyCvh`%+dfQz@*)Q&?ln>J_Edy-%kF||k)xX(<1sF>Ox)f0U@Nw`fnT@j%z z(_jR#u#e`9t&vumK6QYEQrnYr^OeW)9Jfk(Y>J(59+5@9%r5{V5ZQ5Uxz_o@pvz7} z9EKws^rXXNoM7~(oVn3`3rPOFd$;1Jw-C#KSkoBOwJfC4UeypJIIril83>rh?3+R{ zbTNBfr7mLzY;Z4%rN0|fgqd8`yh!%ETk%Osrs)cjWDu4sePZWG5S|XF0@YU3rv{>) zd{t;OtxhyH7y(&T4gpMn@iU9+4kZk}2?gHHeuku7a-kn|&cM(4yHl#9_>@%bBlGX8 zN<1bPtkS8C3k3i>*JGjsIj`Pq^%XhEmR5Ptpv!I9i0X$Pbv#nkKQQ#TEeHA@lAOnq zp9iH!?{^O@_`PLNcM^Cq0n!&$cUsEJ&bE^zN6(PM&_*<76UNIr57kk^S!0-EsiWsdpzL%{a^NW%q_#wE)ONbx zG2v=Ep%TyXPNq%z;@VE?i^FM2Qwv&lSdzQurI2r0K|9dnuY{(OTF$!gA5Q8NWab`1 zrA)+Nqg16X_Tm>^&3v-SLvtpEE(>X6BXbNT0>qF`7qKK~_p5lz8!!}`c8$4AyrRBi ze&n+yDSP4n2;>0zM8h=-0jC8<+DG-!>1Jk|Aixplc}hNHkpQI{f0o|@jY2>G&=-@T6LfGXo`#3*oQgZ7qB~Z0@r)V`d68Wd&mvdiUH@M)=b;+MEeZPfqnHJX{7H z|H)JIt!yvssEUwr$AufF~D@8to$ zKAu-I-8aAaMxO*bi_i0$ufKTn=E{}GqkZ$`f6Fue?ti}0Yq#h6_N(9i`Q{Gv{coTD z8GfGb{>EYx@Ex^xZ%IGJpPa|M4I6Q~vsQzyIAI@BZbN 0 + assert list(filter(lambda x: x.metric_name == "test-x-step", metrics)) + assert list(filter(lambda x: x.metric_name == "test-x-timestamp", metrics)) + + # metrics -> eureka propagation + retry_with_backoff(verify_metrics) diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py new file mode 100644 index 0000000000..713a6a3792 --- /dev/null +++ b/tests/integ/sagemaker/experiments/test_run.py @@ -0,0 +1,662 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import datetime +import os + +import pytest + +from tests.integ.sagemaker.experiments.conftest import TAGS +from sagemaker.experiments._api_types import _TrialComponentStatusType +from sagemaker.experiments._utils import is_run_trial_component +from sagemaker.processing import FrameworkProcessor +from sagemaker.pytorch import PyTorch +from sagemaker.s3 import S3Uploader +from sagemaker.xgboost import XGBoostModel +from tests.integ import DATA_DIR +from sagemaker.experiments._metrics import BATCH_SIZE +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.sklearn import SKLearn +from sagemaker.utils import retry_with_backoff, unique_name_from_base +from tests.integ.sagemaker.experiments.helpers import name, cleanup_exp_resources +from sagemaker.experiments.run import ( + RUN_NAME_BASE, + DELIMITER, +) +from sagemaker.experiments import Run, load_run, list_runs +from sagemaker.experiments._helper import _DEFAULT_ARTIFACT_PREFIX + + +# when running integration tests locally modify this to your test account's execution role +EXECUTION_ROLE = "SageMakerRole" + + +@pytest.fixture +def artifact_file_path(tempdir): + file_contents = "test artifact file" + file_path = os.path.join(tempdir, "artifact_file.txt") + with open(file_path, "w") as foo_file: + foo_file.write(file_contents) + return file_path + + +artifact_name = unique_name_from_base("Test-Artifact") +file_artifact_name = f"File-Artifact-{name()}" +metric_name = "Test-Local-Init-Log-Metric" + + +def test_local_run_with_load(sagemaker_session, artifact_file_path): + exp_name = f"My-Local-Exp-{name()}" + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + # Run name is not provided, will create a new TC + with Run(experiment_name=exp_name, sagemaker_session=sagemaker_session) as run1: + run1_name = run1.run_name + assert RUN_NAME_BASE in run1_name + _local_run_log_behaviors( + artifact_file_path=artifact_file_path, + sagemaker_session=sagemaker_session, + ) + + def verify_load_run(): + with load_run( + experiment_name=exp_name, + run_name=run1_name, + sagemaker_session=sagemaker_session, + ) as run2: + assert run2.run_name == run1_name + assert ( + run2._trial_component.trial_component_name + == f"{run2.experiment_name}{DELIMITER}{run1_name}" + ) + _check_run_from_local_end_result( + sagemaker_session=sagemaker_session, tc=run2._trial_component + ) + + # Add retry to make sure metrics -> eureka propagation is consistent + retry_with_backoff(verify_load_run, 4) + + +def test_two_local_run_init_with_same_run_name_and_different_exp_names(sagemaker_session): + exp_name1 = f"my-two-local-exp1-{name()}" + exp_name2 = f"my-two-local-exp2-{name()}" + run_name = "test-run" + with cleanup_exp_resources( + exp_names=[exp_name1, exp_name2], sagemaker_session=sagemaker_session + ): + # Run name is not provided, will create a new TC + with Run( + experiment_name=exp_name1, run_name=run_name, sagemaker_session=sagemaker_session + ) as run1: + pass + with Run( + experiment_name=exp_name2, run_name=run_name, sagemaker_session=sagemaker_session + ) as run2: + pass + + assert run1.experiment_name != run2.experiment_name + assert run1.run_name == run2.run_name + assert ( + run1._trial_component.trial_component_name != run2._trial_component.trial_component_name + ) + assert run1._trial_component.trial_component_name == f"{exp_name1}{DELIMITER}{run_name}" + assert run2._trial_component.trial_component_name == f"{exp_name2}{DELIMITER}{run_name}" + + 
+@pytest.mark.parametrize( + "input_names", + [ + (f"my-local-exp-{name()}", "test-run", None), # both have delimiter - + ("my-test-1", "my-test-1", None), # exp_name equals run_name + ("my-test-3", "my-test-3-run", None), # is subset of run_name + ("x" * 59, "test-run", None), # long exp_name + ("test-exp", "y" * 59, None), # long run_name + ("e" * 59, "y" * 59, None), # long exp_name and run_name + ("my-test4", "test-run", "run-display-name-test"), # with supplied display name + ], +) +def test_run_name_vs_trial_component_name_edge_cases(sagemaker_session, input_names): + exp_name, run_name, run_display_name = input_names + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + sagemaker_session=sagemaker_session, + run_name=run_name, + run_display_name=run_display_name, + ) as run1: + assert not run1._experiment.tags + assert not run1._trial.tags + is_run_tc = is_run_trial_component( + trial_component_name=run1._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + assert is_run_tc + + with load_run( + experiment_name=exp_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ) as run2: + assert run2.experiment_name == exp_name + assert run2.run_name == run_name + assert run2._trial_component.trial_component_name == f"{exp_name}{DELIMITER}{run_name}" + assert run2._trial_component.display_name in ( + run_display_name, + run2._trial_component.trial_component_name, + ) + + +_EXP_NAME_BASE_IN_SCRIPT = "job-exp-in-script" +_RUN_NAME_IN_SCRIPT = "job-run-in-script" + +_EXP_DIR = os.path.join(DATA_DIR, "experiment") +_ENTRY_POINT_PATH = os.path.join(_EXP_DIR, "train_job_script_for_run_clz.py") +_PYTHON_PROCESS_SCRIPT = "process_job_script_for_run_clz.py" +_TRANSFORM_MATERIALS = os.path.join(_EXP_DIR, "transform_job_materials") + +_RUN_INIT = "init" +_RUN_LOAD = "load" + + +def test_run_from_local_and_train_job_and_all_exp_cfg_match(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. The 1st Run TC created locally and its exp config was auto passed to the job + # 2. In training job, the same exp and run names are given in the Run constructor + # which will load the 1st Run TC in training job and log parameters + # and metrics there + # 3. In a different training job, load the same Run TC and log more parameters there. 
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, sagemaker_session=sagemaker_session, exp_name=exp_name + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + run_name=_RUN_NAME_IN_SCRIPT, + sagemaker_session=sagemaker_session, + ) as run: + init_start_time = _check_tc_status_when_entering(run._trial_component) + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + # experiment_config is auto passed in by _RunContext + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + sagemaker_session=sagemaker_session, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + assert run.experiment_name == exp_name + assert run.run_name == _RUN_NAME_IN_SCRIPT + _check_run_from_local_end_result( + tc=run._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + with run: + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.environment["CALL_RUN_LOAD_WITH_NO_NAME_ARGS"] = "True" + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + is_init=False, + has_extra_load=True, + ) + + +def test_run_from_local_and_train_job_and_exp_cfg_not_match(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. The 1st Run TC created locally and its exp config was auto passed to the job + # 2. In training job, different exp and run names (i.e. 2nd Run TC) are given + # in the Run constructor which will create a Run TC according to the run_name + # passed in there and ignore the exp config in the job + # 3. Both metrics and parameters are logged in the Run TC created in job + # 4. In a different training job, load the 2nd Run TC and log more parameters there. 
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + exp_name2 = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, sagemaker_session=sagemaker_session, exp_name=exp_name + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources( + exp_names=[exp_name, exp_name2], sagemaker_session=sagemaker_session + ): + with Run( + experiment_name=exp_name2, + run_name=f"{_RUN_NAME_IN_SCRIPT}2", + sagemaker_session=sagemaker_session, + ) as run: + init_start_time = _check_tc_status_when_entering(run._trial_component) + # experiment_config is auto passed in by _RunContext + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_tc_status_intermediate( + trial_component=run._trial_component, + sagemaker_session=sagemaker_session, + init_start_time=init_start_time, + ) + + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + sagemaker_session=sagemaker_session, + ) + assert run.experiment_name != exp_name + assert run.run_name != _RUN_NAME_IN_SCRIPT + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + with run: + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_tc_status_intermediate( + trial_component=run._trial_component, + sagemaker_session=sagemaker_session, + init_start_time=init_start_time, + old_end_time=old_end_time, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +def test_run_from_train_job_only(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. No Run TC created locally or specified in experiment config + # 2. In training job, Run is initialized + # which will create a Run TC according to the run_name passed in there + # 3. Both metrics and parameters are logged in the Run TC created in job + # 4. In a different training job, load the same Run TC and log more parameters there. 
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, + sagemaker_session=sagemaker_session, + exp_name=exp_name, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +# dev_sdk_tar is required to trigger generating the dev SDK tar +def test_run_from_processing_job_and_override_default_exp_config( + sagemaker_session, dev_sdk_tar, run_obj +): + # Notes: + # 1. The 1st Run TC (run) created locally + # 2. Within the 2nd Run TC (run_obj)'s context, invoke processor.run + # but override the default experiment config in context of 2nd Run TC + # with the experiment config of the 1st Run TC + # 3. In the processing job script, load the 1st Run TC via the experiment config + # fetched from the job env + # 4. All data are logged in the Run TC either locally or in the processing job + exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + processor = FrameworkProcessor( + estimator_cls=PyTorch, + framework_version="1.10", + py_version="py38", + instance_count=1, + instance_type="ml.m5.xlarge", + role=EXECUTION_ROLE, + sagemaker_session=sagemaker_session, + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + run_name=_RUN_NAME_IN_SCRIPT, + sagemaker_session=sagemaker_session, + ) as run: + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + + with run_obj: + # Override the default experiment_config in _RunContext of run_obj + # with the experiment_config of run + processor.run( + code=_PYTHON_PROCESS_SCRIPT, + source_dir=_EXP_DIR, + job_name=f"process-job-{name()}", + wait=True, # wait the job to finish + logs=False, + experiment_config=run.experiment_config, + ) + + assert run_obj.experiment_name != run.experiment_name + assert run_obj.run_name != run.run_name + _check_run_from_local_end_result( + tc=run._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=run.experiment_name, run_name=run.run_name + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + with run_obj: + # Not to override the exp config and use the default one in the context + processor.run( + code=_PYTHON_PROCESS_SCRIPT, + source_dir=_EXP_DIR, + job_name=f"process-job-{name()}", + wait=True, # wait the job to finish + logs=False, + ) + + tc_name = Run._generate_trial_component_name( + experiment_name=run_obj.experiment_name, run_name=run_obj.run_name + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +# dev_sdk_tar is required to trigger generating the dev SDK tar +def 
test_run_from_transform_job(sagemaker_session, dev_sdk_tar, run_obj, xgboost_latest_version): + # Notes: + # 1. The 1st Run TC (run) created locally + # 2. In the inference script running in a transform job, load the 1st Run TC + # via explicitly passing the experiment_name and run_name of the 1st Run TC + # TODO: once we're able to retrieve exp config from the transform job env, + # we should expand this test and add the load_run() without explicitly supplying the names + # 3. All data are logged in the Run TC either locally or in the transform job + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_TRANSFORM_MATERIALS, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + xgboost_model = XGBoostModel( + sagemaker_session=sagemaker_session, + model_data=xgb_model_data_s3, + role=EXECUTION_ROLE, + entry_point="inference.py", + source_dir=_EXP_DIR, + framework_version=xgboost_latest_version, + env={ + "EXPERIMENT_NAME": run_obj.experiment_name, + "RUN_NAME": run_obj.run_name, + }, + ) + transformer = xgboost_model.transformer( + instance_count=1, + instance_type="ml.m5.4xlarge", + max_concurrent_transforms=5, + max_payload=1, + strategy="MultiRecord", + ) + uri = "s3://{}/{}/input/data/{}".format( + sagemaker_session.default_bucket(), + "transform-test", + unique_name_from_base("json-data"), + ) + input_data = S3Uploader.upload( + os.path.join(_TRANSFORM_MATERIALS, "data.csv"), uri, sagemaker_session=sagemaker_session + ) + + with run_obj: + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + transformer.transform( + data=input_data, + content_type="text/libsvm", + split_type="Line", + wait=True, + job_name=f"transform-job-{name()}", + ) + + _check_run_from_local_end_result( + tc=run_obj._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=run_obj.experiment_name, run_name=run_obj.run_name + ) + _check_run_from_job_result(tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False) + + +def test_list(run_obj, sagemaker_session): + tc1 = _TrialComponent.create( + trial_component_name=f"non-run-tc1-{name()}", + sagemaker_session=sagemaker_session, + ) + tc2 = _TrialComponent.create( + trial_component_name=f"non-run-tc2-{name()}", + sagemaker_session=sagemaker_session, + tags=TAGS, + ) + run_obj._trial.add_trial_component(tc1) + run_obj._trial.add_trial_component(tc2) + + run_tcs = list_runs( + experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session + ) + assert len(run_tcs) == 1 + assert run_tcs[0].run_name == run_obj.run_name + assert run_tcs[0].experiment_name == run_obj.experiment_name + assert run_tcs[0].experiment_config == run_obj.experiment_config + + +def _generate_estimator(exp_name, sdk_tar, sagemaker_session): + return SKLearn( + framework_version="0.23-1", + entry_point=_ENTRY_POINT_PATH, + dependencies=[sdk_tar], + role=EXECUTION_ROLE, + instance_type="ml.m5.large", + instance_count=1, + volume_size=10, + max_run=900, + enable_sagemaker_metrics=True, + environment={ + "EXPERIMENT_NAME": exp_name, + "RUN_NAME": _RUN_NAME_IN_SCRIPT, + "RUN_OPERATION": _RUN_INIT, + }, + sagemaker_session=sagemaker_session, + ) + + +def _local_run_log_behaviors( + sagemaker_session, + artifact_file_path=None, + is_complete_log=True, +): + with load_run(sagemaker_session=sagemaker_session) as run: + run.log_parameter("pa", 1.0) + run.log_parameter("pb", "p2-value") + run.log_parameters({"pc": 2.0, 
"pd": "p4-value"}) + + if is_complete_log: + run.log_file(file_path=artifact_file_path, name=file_artifact_name) + run.log_artifact(name=artifact_name, value="s3://Output") + run.log_artifact(name=artifact_name, value="s3://Input", is_output=False) + + for i in range(BATCH_SIZE): + run.log_metric(name=metric_name, value=i, step=i) + + +def _check_run_from_local_end_result(sagemaker_session, tc, is_complete_log=True): + assert tc.parameters == {"pa": 1.0, "pb": "p2-value", "pc": 2.0, "pd": "p4-value"} + + if not is_complete_log: + return + + s3_prefix = f"s3://{sagemaker_session.default_bucket()}/{_DEFAULT_ARTIFACT_PREFIX}" + assert s3_prefix in tc.output_artifacts[file_artifact_name].value + assert "text/plain" == tc.output_artifacts[file_artifact_name].media_type + assert "s3://Output" == tc.output_artifacts[artifact_name].value + assert not tc.output_artifacts[artifact_name].media_type + assert "s3://Input" == tc.input_artifacts[artifact_name].value + assert not tc.input_artifacts[artifact_name].media_type + + # TODO: revert to len(tc.metrics) == 1 once backend fix reaches prod + assert len(tc.metrics) > 0 + metric_summary = tc.metrics[0] + assert metric_summary.metric_name == metric_name + assert metric_summary.max == 9.0 + assert metric_summary.min == 0.0 + + +def _check_run_from_job_result(sagemaker_session, tc_name=None, is_init=True, has_extra_load=False): + def validate_tc_updated_in_init(): + assert tc.start_time + assert tc.end_time + assert tc.status.primary_status == _TrialComponentStatusType.Completed.value + assert tc.parameters["p1"] == 1.0 + assert tc.parameters["p2"] == 2.0 + # TODO: revert to assert len(tc.metrics) == 5 once + # backend fix hits prod + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + # metrics deletion is not supported at this point + # so its count would accumulate + assert metric_summary.count > 0 + assert metric_summary.min == 0.0 + assert metric_summary.max == 1.0 + + def validate_tc_updated_in_load(): + assert tc.parameters["p3"] == 3.0 + assert tc.parameters["p4"] == 4.0 + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + if metric_summary.metric_name != "test-job-load-log-metric": + continue + assert metric_summary.last == 0.1 + assert metric_summary.max == 0.1 + assert metric_summary.min == 0.1 + if has_extra_load: + assert tc.parameters["p5"] == 5.0 + assert tc.parameters["p6"] == 6.0 + + tc = _TrialComponent.load(trial_component_name=tc_name, sagemaker_session=sagemaker_session) + if is_init: + # Add retry since the load behavior is inconsistent sometimes + retry_with_backoff(validate_tc_updated_in_init, 4) + else: + retry_with_backoff(validate_tc_updated_in_load, 4) + + +def _check_tc_status_when_entering(trial_component): + assert isinstance(trial_component.start_time, datetime.datetime) + assert not trial_component.end_time + assert trial_component.status.primary_status == _TrialComponentStatusType.InProgress.value + return trial_component.start_time + + +def _check_tc_status_when_exiting( + trial_component_name, sagemaker_session, init_start_time, old_end_time=None +): + tc = _TrialComponent.load( + trial_component_name=trial_component_name, sagemaker_session=sagemaker_session + ) + # There will be deviation (< 1s) caused by different TS precisions used in Backend and SDK + assert abs(tc.start_time.timestamp() - init_start_time.timestamp()) < 1 + assert tc.status.primary_status == _TrialComponentStatusType.Completed.value + assert isinstance(tc.end_time, datetime.datetime) + if old_end_time: + assert 
tc.end_time > old_end_time
+    return tc.end_time
+
+
+def _check_tc_status_intermediate(
+    trial_component, sagemaker_session, init_start_time, old_end_time=None
+):
+    tc_load = _TrialComponent.load(
+        trial_component_name=trial_component.trial_component_name,
+        sagemaker_session=sagemaker_session,
+    )
+    assert abs(tc_load.start_time.timestamp() - init_start_time.timestamp()) < 1
+    assert tc_load.status.primary_status == _TrialComponentStatusType.InProgress.value
+    if not old_end_time:
+        assert not trial_component.end_time
+        return
+    assert isinstance(tc_load.end_time, datetime.datetime)
+    assert tc_load.end_time == old_end_time
diff --git a/tests/integ/sagemaker/experiments/test_trial.py b/tests/integ/sagemaker/experiments/test_trial.py
new file mode 100644
index 0000000000..08f646c086
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_trial.py
@@ -0,0 +1,75 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import logging
+
+from sagemaker.experiments import trial
+from sagemaker.utils import retry_with_backoff
+
+
+def test_create_delete(trial_obj):
+    # Fixture creates / deletes, just ensure used at least once.
+    assert trial_obj.trial_name
+
+
+def test_create_tags(trial_obj, sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    while True:
+        actual_tags = client.list_tags(ResourceArn=trial_obj.trial_arn)["Tags"]
+        if actual_tags:
+            break
+    for tag in actual_tags:
+        if "aws:tag" in tag.get("Key"):
+            actual_tags.remove(tag)
+    assert actual_tags == trial_obj.tags
+
+
+def test_save_load(trial_obj, sagemaker_session):
+    trial_obj.display_name = "foo"
+    trial_obj.save()
+    assert (
+        "foo"
+        == trial._Trial.load(
+            trial_name=trial_obj.trial_name,
+            sagemaker_session=sagemaker_session,
+        ).display_name
+    )
+
+
+def test_add_remove_trial_component(trial_obj, trial_component_obj):
+    trial_obj.add_trial_component(trial_component_obj)
+    logging.info(
+        f"Added trial component {trial_component_obj.trial_component_name} to trial {trial_obj.trial_name}"
+    )
+
+    def validate_add():
+        trial_components = list(trial_obj.list_trial_components())
+        assert 1 == len(
+            trial_components
+        ), "Expected the trial component to be in the trial's list of trial components"
+
+    retry_with_backoff(validate_add)
+
+    trial_obj.remove_trial_component(trial_component_obj)
+    logging.info(
+        f"Removed trial component {trial_component_obj.trial_component_name} from trial {trial_obj.trial_name}"
+    )
+
+    def validate_remove():
+        trial_components = list(trial_obj.list_trial_components())
+        assert 0 == len(
+            trial_components
+        ), "Expected the trial component to be removed from the trial's list of trial components"
+
+    retry_with_backoff(validate_remove)
diff --git a/tests/integ/sagemaker/experiments/test_trial_component.py b/tests/integ/sagemaker/experiments/test_trial_component.py
new file mode 100644
index 0000000000..3d79e41cc4
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_trial_component.py
@@ -0,0 +1,144 @@
+# Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import datetime +import uuid + +from sagemaker.experiments._api_types import _TrialComponentStatusType +from tests.integ.sagemaker.experiments.helpers import EXP_INTEG_TEST_NAME_PREFIX +from sagemaker.experiments import _api_types, trial_component +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression + + +def test_create_delete(trial_component_obj): + # Fixture does create / delete, just need to ensure called at least once + assert trial_component_obj.trial_component_name + assert trial_component_obj.input_artifacts == {} + assert trial_component_obj.parameters == {} + assert trial_component_obj.output_artifacts == {} + + +def test_create_tags(trial_component_obj, sagemaker_session): + client = sagemaker_session.sagemaker_client + while True: + actual_tags = client.list_tags(ResourceArn=trial_component_obj.trial_component_arn)["Tags"] + if actual_tags: + break + for tag in actual_tags: + if "aws:tag" in tag.get("Key"): + actual_tags.remove(tag) + assert actual_tags == trial_component_obj.tags + + +def test_delete_with_force_disassociate( + trial_component_with_force_disassociation_obj, sagemaker_session +): + assert trial_component_with_force_disassociation_obj.trial_component_name + trials = sagemaker_session.sagemaker_client.list_trials( + TrialComponentName=trial_component_with_force_disassociation_obj.trial_component_name + )["TrialSummaries"] + assert len(trials) == 3 + + +def test_save(trial_component_obj, sagemaker_session): + trial_component_obj.display_name = str(uuid.uuid4()) + trial_component_obj.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="Message" + ) + trial_component_obj.start_time = datetime.datetime.now( + datetime.timezone.utc + ) - datetime.timedelta(days=1) + trial_component_obj.end_time = datetime.datetime.now(datetime.timezone.utc) + trial_component_obj.parameters = {"foo": "bar", "whizz": 100.1} + trial_component_obj.input_artifacts = { + "snizz": _api_types.TrialComponentArtifact(value="s3:/foo/bar", media_type="text/plain"), + "snizz1": _api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2"), + } + trial_component_obj.output_artifacts = { + "fly": _api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow"), + "fly2": _api_types.TrialComponentArtifact( + value="s3:/sky/far2", media_type="away/tomorrow2" + ), + } + trial_component_obj.parameters_to_remove = ["foo"] + trial_component_obj.input_artifacts_to_remove = ["snizz"] + trial_component_obj.output_artifacts_to_remove = ["fly2"] + + trial_component_obj.save() + + loaded = trial_component._TrialComponent.load( + trial_component_name=trial_component_obj.trial_component_name, + sagemaker_session=sagemaker_session, + ) + + assert trial_component_obj.trial_component_name == loaded.trial_component_name + assert trial_component_obj.status == loaded.status + + assert trial_component_obj.start_time - loaded.start_time < 
datetime.timedelta(seconds=1) + assert trial_component_obj.end_time - loaded.end_time < datetime.timedelta(seconds=1) + + assert loaded.parameters == {"whizz": 100.1} + assert loaded.input_artifacts == { + "snizz1": _api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2") + } + assert loaded.output_artifacts == { + "fly": _api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow") + } + + +def test_load(trial_component_obj, sagemaker_session): + loaded = trial_component._TrialComponent.load( + trial_component_name=trial_component_obj.trial_component_name, + sagemaker_session=sagemaker_session, + ) + assert trial_component_obj.trial_component_arn == loaded.trial_component_arn + + +def test_list_sort(trial_components, sagemaker_session): + slack = datetime.timedelta(minutes=1) + now = datetime.datetime.now(datetime.timezone.utc) + trial_component_names = [tc.trial_component_name for tc in trial_components] + + for sort_order in ["Ascending", "Descending"]: + trial_component_names_listed = [ + s.trial_component_name + for s in trial_component._TrialComponent.list( + created_after=now - slack, + created_before=now + slack, + sort_by="CreationTime", + sort_order=sort_order, + sagemaker_session=sagemaker_session, + ) + if s.trial_component_name in trial_component_names + ] + + if sort_order == "Descending": + trial_component_names_listed = trial_component_names_listed[::-1] + assert trial_component_names == trial_component_names_listed + assert trial_component_names # sanity test + + +def test_search(sagemaker_session): + trial_component_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + search_expression = SearchExpression(filters=[search_filter]) + for s in trial_component._TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + trial_component_names_searched.append(s.trial_component_name) + + assert len(trial_component_names_searched) > 0 + assert trial_component_names_searched # sanity test diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 3c416ffd36..abfe6f6d0d 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -26,6 +26,7 @@ artifact, ) from sagemaker.model import ModelPackage +from sagemaker.utils import retry_with_backoff from tests.integ.sagemaker.workflow.test_workflow import ( test_end_to_end_pipeline_successful_execution, ) @@ -43,7 +44,7 @@ ) from sagemaker.lineage.lineage_trial_component import LineageTrialComponent -from tests.integ.sagemaker.lineage.helpers import name, names, retry +from tests.integ.sagemaker.lineage.helpers import name, names SLEEP_TIME_SECONDS = 1 SLEEP_TIME_TWO_SECONDS = 2 @@ -400,7 +401,7 @@ def model_obj(sagemaker_session): yield model time.sleep(SLEEP_TIME_SECONDS) - retry(lambda: model.delete(disassociate=True), num_attempts=4) + retry_with_backoff(lambda: model.delete(disassociate=True), num_attempts=4) @pytest.fixture diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index fb71d1d88c..5548c63cff 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -15,7 +15,6 @@ import uuid from datetime import datetime -import time def name(): @@ -33,19 +32,6 @@ def names(): ] -def retry(callable, num_attempts=8): - assert num_attempts >= 1 - for i in 
range(num_attempts): - try: - return callable() - except Exception as ex: - if i == num_attempts - 1: - raise ex - print("Retrying", ex) - time.sleep(2**i) - assert False, "logic error in retry" - - def traverse_graph_back(start_arn, sagemaker_session): def visit(arn, visited: set): visited.add(arn) diff --git a/tests/integ/sagemaker/lineage/test_artifact.py b/tests/integ/sagemaker/lineage/test_artifact.py index c629fcdc30..1980b51da2 100644 --- a/tests/integ/sagemaker/lineage/test_artifact.py +++ b/tests/integ/sagemaker/lineage/test_artifact.py @@ -20,7 +20,7 @@ import pytest from sagemaker.lineage import artifact -from tests.integ.sagemaker.lineage.helpers import retry +from sagemaker.utils import retry_with_backoff def test_create_delete(artifact_obj): @@ -125,7 +125,7 @@ def validate(): assert len(trials) == 1 assert trial_obj.trial_name in trials - retry(validate, num_attempts=3) + retry_with_backoff(validate, num_attempts=3) def test_downstream_trials_v2(trial_associated_artifact, trial_obj, sagemaker_session): diff --git a/tests/integ/sagemaker/utilities/__init__.py b/tests/integ/sagemaker/utilities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/utilities/test_search_expression.py b/tests/integ/sagemaker/utilities/test_search_expression.py new file mode 100644 index 0000000000..ea7f4476bf --- /dev/null +++ b/tests/integ/sagemaker/utilities/test_search_expression.py @@ -0,0 +1,67 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import pytest + +from tests.integ.sagemaker.experiments.helpers import EXP_INTEG_TEST_NAME_PREFIX +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression, NestedFilter + + +def test_search(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + search_expression = SearchExpression(filters=[search_filter]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched + + +@pytest.mark.skip(reason="failed validation, need to wait for NestedFilter bug to be fixed") +def test_nested_search(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + nested_filter = NestedFilter(property_name="TrialComponentName", filters=[search_filter]) + search_expression = SearchExpression(nested_filters=[nested_filter]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched + + +def test_sub_expression(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + sub_expression = SearchExpression(filters=[search_filter]) + search_expression = SearchExpression(sub_expressions=[sub_expression]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched diff --git a/tests/integ/test_marketplace.py b/tests/integ/test_marketplace.py index b9ff13c50e..28b537c1ea 100644 --- a/tests/integ/test_marketplace.py +++ b/tests/integ/test_marketplace.py @@ -23,6 +23,7 @@ import sagemaker import tests.integ +from tests.integ.utils import create_repository from sagemaker import AlgorithmEstimator, ModelPackage, Model from sagemaker.serializers import CSVSerializer from sagemaker.tuner import IntegerParameter, HyperparameterTuner @@ -33,7 +34,6 @@ from tests.integ.test_multidatamodel import ( _ecr_image_uri, _ecr_login, - _create_repository, _delete_repository, ) from tests.integ.retry import retries @@ -214,7 +214,7 @@ def iris_image(sagemaker_session): rm=True, ) image.tag(ecr_image, tag="latest") - _create_repository(ecr_client, algorithm_name) + create_repository(ecr_client, algorithm_name) # Retry docker image push for _ in retries(3, "Upload docker image to ECR repo", seconds_to_sleep=10): diff --git a/tests/integ/test_multidatamodel.py b/tests/integ/test_multidatamodel.py index 78ba62c3db..d6c14037a7 100644 --- a/tests/integ/test_multidatamodel.py +++ b/tests/integ/test_multidatamodel.py @@ -19,8 +19,8 @@ import docker import numpy import pytest -from botocore.exceptions import ClientError +from tests.integ.utils import create_repository from sagemaker import utils from sagemaker.amazon.randomcutforest import RandomCutForest from sagemaker.deserializers import StringDeserializer @@ -59,7 +59,7 @@ def 
container_image(sagemaker_session): image.tag(ecr_image, tag="latest") # Create AWS ECR and push the local docker image to it - _create_repository(ecr_client, algorithm_name) + create_repository(ecr_client, algorithm_name) # Retry docker image push for _ in retries(3, "Upload docker image to ECR repo", seconds_to_sleep=10): @@ -90,23 +90,6 @@ def _ecr_image_uri(sagemaker_session, algorithm_name): return "{}.dkr.{}/{}:latest".format(account_id, endpoint_data["hostname"], algorithm_name) -def _create_repository(ecr_client, repository_name): - """ - Creates an ECS Repository (ECR). When a new transform is being registered, - we'll need a repository to push the image (and composed model images) to - """ - try: - response = ecr_client.create_repository(repositoryName=repository_name) - return response["repository"]["repositoryUri"] - except ClientError as e: - # Handle when the repository already exists - if "RepositoryAlreadyExistsException" == e.response.get("Error", {}).get("Code"): - response = ecr_client.describe_repositories(repositoryNames=[repository_name]) - return response["repositories"][0]["repositoryUri"] - else: - raise - - def _delete_repository(ecr_client, repository_name): """ Deletes an ECS Repository (ECR). After the integration test completes diff --git a/tests/integ/utils.py b/tests/integ/utils.py index 53440f96f5..d7891321f2 100644 --- a/tests/integ/utils.py +++ b/tests/integ/utils.py @@ -14,6 +14,8 @@ import logging from functools import wraps +from botocore.exceptions import ClientError + from tests.conftest import NO_P3_REGIONS, NO_M4_REGIONS from sagemaker.exceptions import CapacityError @@ -69,3 +71,21 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def create_repository(ecr_client, repository_name): + """Creates an ECS Repository (ECR). + + When a new transform is being registered, + we'll need a repository to push the image (and composed model images) to + """ + try: + response = ecr_client.create_repository(repositoryName=repository_name) + return response["repository"]["repositoryUri"] + except ClientError as e: + # Handle when the repository already exists + if "RepositoryAlreadyExistsException" == e.response.get("Error", {}).get("Code"): + response = ecr_client.describe_repositories(repositoryNames=[repository_name]) + return response["repositories"][0]["repositoryUri"] + else: + raise diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000000..21fe49cc97 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,66 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import sagemaker + +from mock import Mock, PropertyMock + +_ROLE = "DummyRole" +_REGION = "us-west-2" +_DEFAULT_BUCKET = "my-bucket" + + +@pytest.fixture(scope="session") +def client(): + """Mock client. 
+ + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture(scope="session") +def boto_session(client): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value=_ROLE) + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name=_REGION) + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client + + return session_mock + + +@pytest.fixture(scope="session") +def sagemaker_session(boto_session, client): + # ideally this would mock Session instead of instantiating it + # most unit tests do mock the session correctly + return sagemaker.session.Session( + boto_session=boto_session, + sagemaker_client=client, + sagemaker_runtime_client=client, + default_bucket=_DEFAULT_BUCKET, + sagemaker_metrics_client=client, + ) diff --git a/tests/unit/sagemaker/experiments/__init__.py b/tests/unit/sagemaker/experiments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/experiments/conftest.py b/tests/unit/sagemaker/experiments/conftest.py new file mode 100644 index 0000000000..4d33ad759d --- /dev/null +++ b/tests/unit/sagemaker/experiments/conftest.py @@ -0,0 +1,86 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import unittest +from unittest.mock import patch, MagicMock, Mock + +import pytest + +from sagemaker import Session +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.run import RUN_NAME_BASE +from sagemaker.experiments import Run +from tests.unit.sagemaker.experiments.helpers import ( + mock_tc_load_or_create_func, + mock_trial_load_or_create_func, + TEST_EXP_NAME, +) + + +@pytest.fixture +def client(): + """Mock client. 
+ + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = unittest.mock.Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture +def sagemaker_session(client): + return Session( + sagemaker_client=client, + ) + + +@pytest.fixture +def run_obj(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.update_trial_component.return_value = {} + client.associate_trial_component.return_value = {} + with patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock( + return_value=_Experiment( + experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session + ) + ), + ): + with patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), + ): + with patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), + ): + run = Run( + experiment_name=TEST_EXP_NAME, + sagemaker_session=sagemaker_session, + ) + run._artifact_uploader = Mock() + run._lineage_artifact_tracker = Mock() + run._metrics_manager = Mock() + + assert run.run_name.startswith(RUN_NAME_BASE) + assert run.run_group_name == Run._generate_trial_name(TEST_EXP_NAME) + + return run diff --git a/tests/unit/sagemaker/experiments/helpers.py b/tests/unit/sagemaker/experiments/helpers.py new file mode 100644 index 0000000000..b7914010e5 --- /dev/null +++ b/tests/unit/sagemaker/experiments/helpers.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +TEST_EXP_NAME = "my-experiment" +TEST_RUN_NAME = "my-run" + + +def mock_tc_load_or_create_func( + trial_component_name, display_name=None, tags=None, sagemaker_session=None +): + tc = _TrialComponent( + trial_component_name=trial_component_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return tc, True + + +def mock_trial_load_or_create_func( + experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None +): + return _Trial( + trial_name=trial_name, + experiment_name=experiment_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) diff --git a/tests/unit/sagemaker/experiments/test_environment.py b/tests/unit/sagemaker/experiments/test_environment.py new file mode 100644 index 0000000000..8bb23db7b6 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_environment.py @@ -0,0 +1,107 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os +import shutil +import tempfile +import unittest.mock + +import pytest + +from sagemaker.experiments import _environment +from sagemaker.utils import retry_with_backoff + + +@pytest.fixture +def tempdir(): + dir = tempfile.mkdtemp() + yield dir + shutil.rmtree(dir) + + +@pytest.fixture +def training_job_env(): + old_value = os.environ.get("TRAINING_JOB_ARN") + os.environ["TRAINING_JOB_ARN"] = "arn:1234aBcDe" + yield os.environ + del os.environ["TRAINING_JOB_ARN"] + if old_value: + os.environ["TRAINING_JOB_ARN"] = old_value + + +@pytest.fixture +def transform_job_env(): + old_value = os.environ.get("SAGEMAKER_BATCH") + os.environ["SAGEMAKER_BATCH"] = "true" + yield os.environ + del os.environ["SAGEMAKER_BATCH"] + if old_value: + os.environ["SAGEMAKER_BATCH"] = old_value + + +def test_processing_job_environment(tempdir): + config_path = os.path.join(tempdir, "config.json") + with open(config_path, "w") as f: + f.write(json.dumps({"ProcessingJobArn": "arn:1234aBcDe"})) + environment = _environment._RunEnvironment.load(processing_job_config_path=config_path) + + assert _environment._EnvironmentType.SageMakerProcessingJob == environment.environment_type + assert "arn:1234aBcDe" == environment.source_arn + + +def test_training_job_environment(training_job_env): + environment = _environment._RunEnvironment.load() + assert _environment._EnvironmentType.SageMakerTrainingJob == environment.environment_type + assert "arn:1234aBcDe" == environment.source_arn + + +def test_transform_job_environment(transform_job_env): + environment = _environment._RunEnvironment.load() + assert _environment._EnvironmentType.SageMakerTransformJob == environment.environment_type + # TODO: update if we figure out how to get source_arn from the transform job + assert not environment.source_arn + + +def test_no_environment(): + assert _environment._RunEnvironment.load() is None + + +def test_resolve_trial_component(training_job_env, sagemaker_session): + trial_component_name = "foo-bar" + client = sagemaker_session.sagemaker_client + client.list_trial_components.return_value = { + "TrialComponentSummaries": [{"TrialComponentName": trial_component_name}] + } + client.describe_trial_component.return_value = {"TrialComponentName": trial_component_name} + environment = _environment._RunEnvironment.load() + tc = environment.get_trial_component(sagemaker_session) + + assert trial_component_name == tc.trial_component_name + client.describe_trial_component.assert_called_with(TrialComponentName=trial_component_name) + client.list_trial_components.assert_called_with(SourceArn="arn:1234abcde") + + +@unittest.mock.patch("sagemaker.experiments._environment.retry_with_backoff") +def test_resolve_trial_component_fails(mock_retry, sagemaker_session, training_job_env): + mock_retry.side_effect = lambda func: retry_with_backoff(func, 2) + client = sagemaker_session.sagemaker_client + client.list_trial_components.side_effect = Exception("Failed test") + environment = _environment._RunEnvironment.load() + assert environment.get_trial_component(sagemaker_session) is None + + +def 
test_resolve_transform_job_trial_component_fail(transform_job_env, sagemaker_session): + environment = _environment._RunEnvironment.load() + assert environment.get_trial_component(sagemaker_session) is None diff --git a/tests/unit/sagemaker/experiments/test_experiment.py b/tests/unit/sagemaker/experiments/test_experiment.py new file mode 100644 index 0000000000..b0ad55c27f --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_experiment.py @@ -0,0 +1,306 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import unittest.mock +import datetime + +from unittest.mock import patch + +from sagemaker import Session +from sagemaker.experiments import experiment +from sagemaker.experiments._api_types import TrialSummary + + +@pytest.fixture +def datetime_obj(): + return datetime.datetime(2017, 6, 16, 15, 55, 0) + + +def test_load(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.describe_experiment.return_value = {"Description": "description-value"} + experiment_obj = experiment._Experiment.load( + experiment_name="name-value", sagemaker_session=sagemaker_session + ) + assert experiment_obj.experiment_name == "name-value" + assert experiment_obj.description == "description-value" + + client.describe_experiment.assert_called_with(ExperimentName="name-value") + + +def test_create(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_experiment.return_value = {"Arn": "arn:aws:1234"} + experiment_obj = experiment._Experiment.create( + experiment_name="name-value", sagemaker_session=sagemaker_session + ) + assert experiment_obj.experiment_name == "name-value" + client.create_experiment.assert_called_with(ExperimentName="name-value") + + +def test_create_with_tags(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_experiment.return_value = {"Arn": "arn:aws:1234"} + tags = [{"Key": "foo", "Value": "bar"}] + experiment_obj = experiment._Experiment.create( + experiment_name="name-value", sagemaker_session=sagemaker_session, tags=tags + ) + assert experiment_obj.experiment_name == "name-value" + client.create_experiment.assert_called_with(ExperimentName="name-value", Tags=tags) + + +def test_save(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + client.update_experiment.return_value = {} + obj.save() + client.update_experiment.assert_called_with(ExperimentName="foo", Description="bar") + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + client.delete_experiment.return_value = {} + obj.delete() + client.delete_experiment.assert_called_with(ExperimentName="foo") + + +@patch("sagemaker.experiments.experiment._Experiment.load") +def test_load_or_create_when_exist(mock_load, sagemaker_session): + exp_name = "exp_name" + 
experiment._Experiment._load_or_create( + experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + mock_load.assert_called_once_with(exp_name, sagemaker_session) + + +@patch("sagemaker.experiments.experiment._Experiment.load") +@patch("sagemaker.experiments.experiment._Experiment.create") +def test_load_or_create_when_not_exist(mock_create, mock_load): + sagemaker_session = Session() + client = sagemaker_session.sagemaker_client + exp_name = "exp_name" + not_found_err = client.exceptions.ResourceNotFound( + error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, + operation_name="foo", + ) + mock_load.side_effect = not_found_err + + experiment._Experiment._load_or_create( + experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + + mock_create.assert_called_once_with( + experiment_name=exp_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=sagemaker_session, + ) + + +def test_list_trials_empty(sagemaker_session): + sagemaker_session.sagemaker_client.list_trials.return_value = {"TrialSummaries": []} + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + assert list(experiment_obj.list_trials()) == [] + + +def test_list_trials_single(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + sagemaker_session.sagemaker_client.list_trials.return_value = { + "TrialSummaries": [ + {"Name": "trial-foo", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj} + ] + } + + assert list(experiment_obj.list_trials()) == [ + TrialSummary(name="trial-foo", creation_time=datetime_obj, last_modified_time=datetime_obj) + ] + + +def test_list_trials_two_values(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + sagemaker_session.sagemaker_client.list_trials.return_value = { + "TrialSummaries": [ + {"Name": "trial-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}, + {"Name": "trial-foo-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}, + ] + } + + assert list(experiment_obj.list_trials()) == [ + TrialSummary( + name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + + +def test_next_token(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session) + client = sagemaker_session.sagemaker_client + client.list_trials.side_effect = [ + { + "TrialSummaries": [ + { + "Name": "trial-foo-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "Name": "trial-foo-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ], + "NextToken": "foo", + }, + { + "TrialSummaries": [ + { + "Name": "trial-foo-3", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + } + ] + }, + ] + + assert list(experiment_obj.list_trials()) == [ + TrialSummary( + name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-3", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + + client.list_trials.assert_any_call(**{}) + client.list_trials.assert_any_call(NextToken="foo") + + +def test_list_trials_call_args(sagemaker_session): + client = 
sagemaker_session.sagemaker_client
+    created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
+    created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
+    experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
+    client.list_trials.return_value = {}
+    assert [] == list(
+        experiment_obj.list_trials(created_after=created_after, created_before=created_before)
+    )
+    client.list_trials.assert_called_with(CreatedBefore=created_before, CreatedAfter=created_after)
+
+
+def test_delete_all_with_incorrect_action_name(sagemaker_session):
+    obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+    with pytest.raises(ValueError) as err:
+        obj._delete_all(action="abc")
+
+    assert "Must confirm with string '--force'" in str(err)
+
+
+def test_delete_all(sagemaker_session, datetime_obj):
+    obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+    client = sagemaker_session.sagemaker_client
+    client.list_trials.return_value = {
+        "TrialSummaries": [
+            {
+                "TrialName": "trial-1",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+            {
+                "TrialName": "trial-2",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+        ]
+    }
+    client.describe_trial.side_effect = [
+        {"Trialname": "trial-1", "ExperimentName": "experiment-name-value"},
+        {"Trialname": "trial-2", "ExperimentName": "experiment-name-value"},
+    ]
+    client.list_trial_components.side_effect = [
+        {
+            "TrialComponentSummaries": [
+                {
+                    "TrialComponentName": "trial-component-1",
+                    "CreationTime": datetime_obj,
+                    "LastModifiedTime": datetime_obj,
+                },
+                {
+                    "TrialComponentName": "trial-component-2",
+                    "CreationTime": datetime_obj,
+                    "LastModifiedTime": datetime_obj,
+                },
+            ]
+        },
+        {
+            "TrialComponentSummaries": [
+                {
+                    "TrialComponentName": "trial-component-3",
+                    "CreationTime": datetime_obj,
+                    "LastModifiedTime": datetime_obj,
+                },
+                {
+                    "TrialComponentName": "trial-component-4",
+                    "CreationTime": datetime_obj,
+                    "LastModifiedTime": datetime_obj,
+                },
+            ]
+        },
+    ]
+
+    client.describe_trial_component.side_effect = [
+        {"TrialComponentName": "trial-component-1"},
+        {"TrialComponentName": "trial-component-2"},
+        {"TrialComponentName": "trial-component-3"},
+        {"TrialComponentName": "trial-component-4"},
+    ]
+
+    client.delete_trial_component.return_value = {}
+    client.delete_trial.return_value = {}
+    client.delete_experiment.return_value = {}
+
+    obj._delete_all(action="--force")
+
+    client.delete_experiment.assert_called_with(ExperimentName="foo")
+
+    delete_trial_expected_calls = [
+        unittest.mock.call(TrialName="trial-1"),
+        unittest.mock.call(TrialName="trial-2"),
+    ]
+    assert delete_trial_expected_calls == client.delete_trial.mock_calls
+
+    delete_trial_component_expected_calls = [
+        unittest.mock.call(TrialComponentName="trial-component-1"),
+        unittest.mock.call(TrialComponentName="trial-component-2"),
+        unittest.mock.call(TrialComponentName="trial-component-3"),
+        unittest.mock.call(TrialComponentName="trial-component-4"),
+    ]
+    assert delete_trial_component_expected_calls == client.delete_trial_component.mock_calls
+
+
+def test_delete_all_fail(sagemaker_session):
+    obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+    sagemaker_session.sagemaker_client.list_trials.side_effect = Exception
+    with pytest.raises(Exception) as e:
+        obj._delete_all(action="--force")
+
+    assert str(e.value) == "Failed to delete, please try again."
diff --git a/tests/unit/sagemaker/experiments/test_helper.py b/tests/unit/sagemaker/experiments/test_helper.py new file mode 100644 index 0000000000..a11f67389b --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_helper.py @@ -0,0 +1,195 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os +import shutil +import tempfile + +from mock import Mock, PropertyMock, call +import pytest + +from src.sagemaker.experiments._helper import ( + _LineageArtifactTracker, + _ArtifactUploader, +) +from src.sagemaker.experiments._utils import resolve_artifact_name +from src.sagemaker.session import Session + + +@pytest.fixture +def client(): + """Mock client. + + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture +def boto_session(client): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value="DummyRole") + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name="us-west-2") + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client + + return session_mock + + +@pytest.fixture +def sagemaker_session(client, boto_session): + return Session( + sagemaker_client=client, + boto_session=boto_session, + ) + + +@pytest.fixture +def lineage_artifact_tracker(sagemaker_session): + return _LineageArtifactTracker("test_trial_component_arn", sagemaker_session) + + +def test_lineage_artifact_tracker(lineage_artifact_tracker, sagemaker_session): + client = sagemaker_session.sagemaker_client + lineage_artifact_tracker.add_input_artifact( + "input_name", "input_source_uri", "input_etag", "text/plain" + ) + lineage_artifact_tracker.add_output_artifact( + "output_name", "output_source_uri", "output_etag", "text/plain" + ) + client.create_artifact.side_effect = [ + {"ArtifactArn": "created_arn_1"}, + {"ArtifactArn": "created_arn_2"}, + ] + + lineage_artifact_tracker.save() + + expected_calls = [ + call( + ArtifactName="input_name", + ArtifactType="text/plain", + Source={ + "SourceUri": "input_source_uri", + "SourceTypes": [{"SourceIdType": "S3ETag", "Value": "input_etag"}], + }, + ), + call( + ArtifactName="output_name", + ArtifactType="text/plain", + Source={ + "SourceUri": "output_source_uri", + "SourceTypes": [{"SourceIdType": "S3ETag", "Value": "output_etag"}], + }, + ), + ] + assert expected_calls == client.create_artifact.mock_calls + + expected_calls = [ + call( + SourceArn="created_arn_1", + DestinationArn="test_trial_component_arn", + AssociationType="ContributedTo", + ), + call( + SourceArn="test_trial_component_arn", + DestinationArn="created_arn_2", + AssociationType="Produced", + ), + ] + assert expected_calls == client.add_association.mock_calls + + +@pytest.fixture +def 
artifact_uploader(sagemaker_session): + return _ArtifactUploader( + trial_component_name="trial_component_name", + artifact_bucket="artifact_bucket", + artifact_prefix="artifact_prefix", + sagemaker_session=sagemaker_session, + ) + + +@pytest.fixture +def tempdir(): + tmp_dir = tempfile.mkdtemp() + yield tmp_dir + shutil.rmtree(tmp_dir) + + +def test_artifact_uploader_init(artifact_uploader): + assert "trial_component_name" == artifact_uploader.trial_component_name + assert "artifact_bucket" == artifact_uploader.artifact_bucket + assert "artifact_prefix" == artifact_uploader.artifact_prefix + + +def test_artifact_uploader_upload_artifact_file_not_exists(tempdir, artifact_uploader): + not_exist_file = os.path.join(tempdir, "not.exists") + with pytest.raises(ValueError) as error: + artifact_uploader.upload_artifact(not_exist_file) + assert "does not exist or is not a file" in str(error) + + +def test_artifact_uploader_upload_artifact(tempdir, artifact_uploader): + path = os.path.join(tempdir, "exists") + with open(path, "a") as f: + f.write("boo") + + name = resolve_artifact_name(path) + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + + s3_uri, etag = artifact_uploader.upload_artifact(path) + expected_key = "{}/{}/{}".format( + artifact_uploader.artifact_prefix, artifact_uploader.trial_component_name, name + ) + + artifact_uploader._s3_client.upload_file.assert_called_with( + path, artifact_uploader.artifact_bucket, expected_key + ) + + expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key) + assert expected_uri == s3_uri + + +def test_artifact_uploader_upload_object_artifact(tempdir, artifact_uploader): + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + + artifact_name = "my-artifact" + artifact_object = {"key": "value"} + file_extension = ".csv" + s3_uri, etag = artifact_uploader.upload_object_artifact( + artifact_name, artifact_object, file_extension + ) + name = artifact_name + file_extension + expected_key = "{}/{}/{}".format( + artifact_uploader.artifact_prefix, artifact_uploader.trial_component_name, name + ) + + artifact_uploader._s3_client.put_object.assert_called_with( + Body=json.dumps(artifact_object), Bucket=artifact_uploader.artifact_bucket, Key=expected_key + ) + + expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key) + assert expected_uri == s3_uri diff --git a/tests/unit/sagemaker/experiments/test_metrics.py b/tests/unit/sagemaker/experiments/test_metrics.py new file mode 100644 index 0000000000..21556f70fd --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_metrics.py @@ -0,0 +1,178 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import os +import pytest +import tempfile +import shutil +import datetime +import dateutil +import json +import time + +from sagemaker.experiments._metrics import ( + _RawMetricData, + _SageMakerFileMetricsWriter, + SageMakerMetricsWriterException, +) + + +@pytest.fixture +def tempdir(): + dir = tempfile.mkdtemp() + yield dir + shutil.rmtree(dir) + + +@pytest.fixture +def filepath(tempdir): + return os.path.join(tempdir, "foo.json") + + +@pytest.fixture +def timestamp(): + return datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1) + + +def test_raw_metric_data_utc_timestamp(): + utcnow = datetime.datetime.now(datetime.timezone.utc) + assert utcnow.tzinfo + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=utcnow) + assert utcnow.timestamp() == metric.Timestamp + + +def test_raw_metric_data_utc_(): + utcnow = datetime.datetime.now(datetime.timezone.utc) + assert utcnow.tzinfo + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=utcnow) + assert utcnow.timestamp() == metric.Timestamp + + +def test_raw_metric_data_aware_timestamp(): + aware_datetime = datetime.datetime.now(dateutil.tz.gettz("America/Chicago")) + assert aware_datetime.tzinfo + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=aware_datetime) + assert (aware_datetime - aware_datetime.utcoffset()).replace( + tzinfo=datetime.timezone.utc + ).timestamp() == metric.Timestamp + + +def test_raw_metric_data_naive_timestamp(): + naive_datetime = datetime.datetime.now() + assert naive_datetime.tzinfo is None + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=naive_datetime) + local_datetime = naive_datetime.replace(tzinfo=dateutil.tz.tzlocal()) + assert (local_datetime - local_datetime.utcoffset()).replace( + tzinfo=datetime.timezone.utc + ).timestamp() == metric.Timestamp + + +def test_raw_metric_data_number_timestamp(): + time_now = time.time() + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=time_now) + assert time_now == metric.Timestamp + + +def test_raw_metric_data_request_item(): + time_now = time.time() + metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=time_now, step=10) + expected = { + "MetricName": "foo", + "Value": 1.0, + "Timestamp": str(int(time_now)), + "Step": 10, + } + assert expected == metric.to_raw_metric_data() + + +def test_raw_metric_data_invalid_timestamp(): + with pytest.raises(ValueError) as error1: + _RawMetricData(metric_name="IFail", value=100, timestamp=time.time() - 2000000) + assert "Timestamps must be between two weeks before and two hours from now" in str(error1) + + with pytest.raises(ValueError) as error2: + _RawMetricData(metric_name="IFail", value=100, timestamp=time.time() + 10000) + assert "Timestamps must be between two weeks before and two hours from now" in str(error2) + + +def test_file_metrics_writer_log_metric(timestamp, filepath): + now = datetime.datetime.now(datetime.timezone.utc) + writer = _SageMakerFileMetricsWriter(filepath) + writer.log_metric(metric_name="foo", value=1.0) + writer.log_metric(metric_name="foo", value=2.0, step=1) + writer.log_metric(metric_name="foo", value=3.0, timestamp=timestamp) + writer.log_metric(metric_name="foo", value=4.0, timestamp=timestamp, step=2) + writer.close() + + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one, entry_two, entry_three, entry_four] = [json.loads(line) for line in lines] + + assert "foo" == entry_one["MetricName"] + assert 1.0 == entry_one["Value"] + assert 
(now.timestamp() - entry_one["Timestamp"]) < 1 + assert "Step" not in entry_one + + assert 1 == entry_two["Step"] + assert timestamp.timestamp() == entry_three["Timestamp"] + assert 2 == entry_four["Step"] + + +def test_file_metrics_writer_flushes_buffer_every_line_log_metric(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + + writer.log_metric(metric_name="foo", value=1.0) + + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one] = [json.loads(line) for line in lines] + assert "foo" == entry_one["MetricName"] + assert 1.0 == entry_one["Value"] + + writer.log_metric(metric_name="bar", value=2.0) + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one, entry_two] = [json.loads(line) for line in lines] + assert "bar" == entry_two["MetricName"] + assert 2.0 == entry_two["Value"] + + writer.log_metric(metric_name="biz", value=3.0) + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one, entry_two, entry_three] = [json.loads(line) for line in lines] + assert "biz" == entry_three["MetricName"] + assert 3.0 == entry_three["Value"] + + writer.close() + + +def test_file_metrics_writer_context_manager(timestamp, filepath): + with _SageMakerFileMetricsWriter(filepath) as writer: + writer.log_metric("foo", value=1.0, timestamp=timestamp) + entry = json.loads(open(filepath, "r").read().strip()) + assert { + "MetricName": "foo", + "Value": 1.0, + "Timestamp": timestamp.timestamp(), + }.items() <= entry.items() + + +def test_file_metrics_writer_fail_write_on_close(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + writer.log_metric(metric_name="foo", value=1.0) + writer.close() + with pytest.raises(SageMakerMetricsWriterException): + writer.log_metric(metric_name="foo", value=1.0) + + +def test_file_metrics_writer_no_write(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + writer.close() + assert not os.path.exists(filepath) diff --git a/tests/unit/sagemaker/experiments/test_run.py b/tests/unit/sagemaker/experiments/test_run.py new file mode 100644 index 0000000000..0e4ebee181 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_run.py @@ -0,0 +1,941 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import
+
+import datetime
+import unittest
+from math import inf, nan
+from unittest.mock import patch, Mock, MagicMock
+
+import dateutil
+import pytest
+
+from sagemaker.experiments import _environment, SortOrderType
+from sagemaker.experiments._api_types import (
+    TrialComponentArtifact,
+    TrialComponentSummary,
+    TrialComponentStatus,
+    _TrialComponentStatusType,
+    TrialComponentSearchResult,
+)
+from sagemaker.experiments.experiment import _Experiment
+from sagemaker.experiments.run import (
+    TRIAL_NAME_TEMPLATE,
+    MAX_RUN_TC_ARTIFACTS_LEN,
+    MAX_NAME_LEN_IN_BACKEND,
+    EXPERIMENT_NAME,
+    RUN_NAME,
+    TRIAL_NAME,
+    DELIMITER,
+    RUN_TC_TAG,
+    SortByType,
+)
+from sagemaker.experiments import Run, load_run, list_runs
+from sagemaker.experiments.trial import _Trial
+from sagemaker.experiments.trial_component import _TrialComponent
+from tests.unit.sagemaker.experiments.helpers import (
+    mock_trial_load_or_create_func,
+    mock_tc_load_or_create_func,
+    TEST_EXP_NAME,
+    TEST_RUN_NAME,
+)
+
+
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    "sagemaker.experiments.run._Trial._load_or_create",
+    MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._TrialComponent._load_or_create",
+    MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch.object(_TrialComponent, "save")
+def test_run_init(mock_tc_save, sagemaker_session):
+    with Run(
+        experiment_name=TEST_EXP_NAME, run_name=TEST_RUN_NAME, sagemaker_session=sagemaker_session
+    ) as run_obj:
+        assert not run_obj._in_load
+        assert not run_obj._inside_load_context
+        assert run_obj._inside_init_context
+        assert not run_obj._trial_component.parameters
+
+        expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+        assert run_obj.experiment_name == TEST_EXP_NAME
+        assert run_obj.run_name == TEST_RUN_NAME
+        assert run_obj.run_group_name == TRIAL_NAME_TEMPLATE.format(TEST_EXP_NAME)
+        assert run_obj._trial_component.trial_component_name == expected_tc_name
+        assert run_obj._trial.trial_name == TRIAL_NAME_TEMPLATE.format(TEST_EXP_NAME)
+        assert run_obj._experiment.experiment_name == TEST_EXP_NAME
+        assert run_obj.experiment_config == {
+            EXPERIMENT_NAME: TEST_EXP_NAME,
+            TRIAL_NAME: run_obj.run_group_name,
+            RUN_NAME: expected_tc_name,
+        }
+
+    # trial_component.save is called when entering/exiting the with block
+    mock_tc_save.assert_called()
+
+
+def test_run_init_name_length_exceed_limit(sagemaker_session):
+    invalid_name = "x" * MAX_NAME_LEN_IN_BACKEND
+
+    # experiment_name exceeds
+    with pytest.raises(ValueError) as err:
+        Run(
+            experiment_name=invalid_name,
+            run_name=TEST_RUN_NAME,
+            sagemaker_session=sagemaker_session,
+        )
+
+    assert (
+        f"The experiment_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than"
+        in str(err)
+    )
+
+    # run_name exceeds
+    with pytest.raises(ValueError) as err:
+        Run(
+            experiment_name=TEST_EXP_NAME,
+            run_name=invalid_name,
+            sagemaker_session=sagemaker_session,
+        )
+
+    assert f"The run_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than" in str(
+        err
+    )
+
+
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    
"sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), +) +@patch("sagemaker.experiments.run._RunEnvironment") +def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session): + client = sagemaker_session.sagemaker_client + job_name = "my-train-job" + rv = Mock() + rv.source_arn = f"arn:1234/{job_name}" + rv.environment_type = _environment._EnvironmentType.SageMakerTrainingJob + mock_run_env.load.return_value = rv + + expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}" + exp_config = { + EXPERIMENT_NAME: TEST_EXP_NAME, + TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME), + RUN_NAME: expected_tc_name, + } + client.describe_training_job.return_value = { + "TrainingJobName": "train-job-experiments", + # The Run object has been created else where + "ExperimentConfig": exp_config, + } + with load_run(sagemaker_session=sagemaker_session) as run_obj: + assert run_obj._in_load + assert not run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj.run_name == TEST_RUN_NAME + assert run_obj._trial_component.trial_component_name == expected_tc_name + assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME) + assert run_obj._trial + assert run_obj.experiment_name == TEST_EXP_NAME + assert run_obj._experiment + assert run_obj.experiment_config == exp_config + + client.describe_training_job.assert_called_once_with(TrainingJobName=job_name) + + +@patch("sagemaker.experiments.run._RunEnvironment") +def test_run_load_no_run_name_and_in_train_job_but_fail_to_get_exp_cfg( + mock_run_env, sagemaker_session +): + rv = Mock() + rv.source_arn = "arn:1234/my-train-job" + rv.environment_type = _environment._EnvironmentType.SageMakerTrainingJob + mock_run_env.load.return_value = rv + + # No Run object is created else where + sagemaker_session.sagemaker_client.describe_training_job.return_value = { + "TrainingJobName": "train-job-experiments", + } + + with pytest.raises(RuntimeError) as err: + with load_run(sagemaker_session=sagemaker_session): + pass + + assert "Not able to fetch RunName in ExperimentConfig of the sagemaker job" in str(err) + + +def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session): + with run_obj: + with load_run(sagemaker_session=sagemaker_session) as run: + assert run_obj == run + + +def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemaker_session): + with pytest.raises(RuntimeError) as err: + with load_run(sagemaker_session=sagemaker_session): + pass + + assert "Failed to load a Run object" in str(err) + + # experiment_name is given but is not supplied along with the run_name so it's ignored. 
+    with pytest.raises(RuntimeError) as err:
+        with load_run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session):
+            pass
+
+    assert "Failed to load a Run object" in str(err)
+
+
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    "sagemaker.experiments.run._Trial._load_or_create",
+    MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._TrialComponent._load_or_create",
+    MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+def test_run_load_with_run_name_and_exp_name(sagemaker_session):
+    with load_run(
+        run_name=TEST_RUN_NAME,
+        experiment_name=TEST_EXP_NAME,
+        sagemaker_session=sagemaker_session,
+    ) as run_obj:
+        expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+        expected_exp_config = {
+            EXPERIMENT_NAME: TEST_EXP_NAME,
+            TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+            RUN_NAME: expected_tc_name,
+        }
+
+        assert run_obj.run_name == TEST_RUN_NAME
+        assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
+        assert run_obj.experiment_name == TEST_EXP_NAME
+        assert run_obj._trial_component.trial_component_name == expected_tc_name
+        assert run_obj._trial
+        assert run_obj._experiment
+        assert run_obj.experiment_config == expected_exp_config
+
+
+def test_run_load_with_run_name_but_no_exp_name(sagemaker_session):
+    with pytest.raises(ValueError) as err:
+        with load_run(
+            run_name=TEST_RUN_NAME,
+            sagemaker_session=sagemaker_session,
+        ):
+            pass
+
+    assert "Invalid input: experiment_name is missing" in str(err)
+
+
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    "sagemaker.experiments.run._Trial._load_or_create",
+    MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._TrialComponent._load_or_create",
+    MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    job_name = "my-process-job"
+    rv = unittest.mock.Mock()
+    rv.source_arn = f"arn:1234/{job_name}"
+    rv.environment_type = _environment._EnvironmentType.SageMakerProcessingJob
+    mock_run_env.load.return_value = rv
+
+    expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+    exp_config = {
+        EXPERIMENT_NAME: TEST_EXP_NAME,
+        TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+        RUN_NAME: expected_tc_name,
+    }
+    client.describe_processing_job.return_value = {
+        "ProcessingJobName": "process-job-experiments",
+        # The Run object has been created elsewhere
+        "ExperimentConfig": exp_config,
+    }
+
+    with load_run(sagemaker_session=sagemaker_session):
+        pass
+
+    client.describe_processing_job.assert_called_once_with(ProcessingJobName=job_name)
+
+
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
+    # TODO: update this test once we figure out how to get source_arn from the transform job
+    rv = unittest.mock.Mock()
+    rv.environment_type = 
_environment._EnvironmentType.SageMakerTransformJob + rv.source_arn = "" + mock_run_env.load.return_value = rv + + with pytest.raises(RuntimeError) as err: + with load_run(sagemaker_session=sagemaker_session): + pass + + assert ( + "loading experiment config from transform job environment is not currently supported" + ) in str(err) + + +def test_log_parameter_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_parameter("foo", "bar") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_parameter(run_obj): + with run_obj: + run_obj.log_parameter("foo", "bar") + assert run_obj._trial_component.parameters["foo"] == "bar" + run_obj.log_parameter("whizz", 1) + assert run_obj._trial_component.parameters["whizz"] == 1 + + +def test_log_parameter_skip_invalid_value(run_obj): + with run_obj: + run_obj.log_parameter("key", nan) + assert "key" not in run_obj._trial_component.parameters + + +def test_log_parameters_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5}) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_parameters(run_obj): + with run_obj: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5}) + assert run_obj._trial_component.parameters == {"a": "b", "c": "d", "e": 5} + + +def test_log_parameters_skip_invalid_values(run_obj): + with run_obj: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5, "f": nan}) + assert run_obj._trial_component.parameters == {"a": "b", "c": "d", "e": 5} + + +def test_log_input_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_artifact("foo", "baz", "text/text", False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_input(run_obj): + with run_obj: + run_obj.log_artifact("foo", "baz", "text/text", False) + assert run_obj._trial_component.input_artifacts == { + "foo": TrialComponentArtifact(value="baz", media_type="text/text") + } + + +def test_log_output_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_artifact("foo", "baz", "text/text") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_output(run_obj): + with run_obj: + run_obj.log_artifact("foo", "baz", "text/text") + assert run_obj._trial_component.output_artifacts == { + "foo": TrialComponentArtifact(value="baz", media_type="text/text") + } + + +def test_log_metric_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_metric(name="foo", value=1.0, step=1) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_metric(run_obj): + now = datetime.datetime.now() + with run_obj: + run_obj.log_metric(name="foo", value=1.0, step=1, timestamp=now) + run_obj._metrics_manager.log_metric.assert_called_with( + metric_name="foo", value=1.0, step=1, timestamp=now + ) + + +def test_log_metric_skip_invalid_value(run_obj): + with run_obj: + run_obj.log_metric(None, nan, None, None) + assert not run_obj._metrics_manager.log_metric.called + + +def test_log_metric_attribute_error(run_obj): + now = datetime.datetime.now() + with run_obj: + run_obj._metrics_manager.log_metric.side_effect = AttributeError + + with pytest.raises(AttributeError): + run_obj.log_metric("foo", 1.0, 1, now) + + +def test_log_output_artifact_outside_run_context(run_obj): + with 
pytest.raises(RuntimeError) as err: + run_obj.log_file("foo.txt", "name", "whizz/bang") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_output_artifact(run_obj): + run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value") + with run_obj: + run_obj.log_file("foo.txt", "name", "whizz/bang") + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "whizz/bang" == run_obj._trial_component.output_artifacts["name"].media_type + + run_obj.log_file("foo.txt") + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "foo.txt" in run_obj._trial_component.output_artifacts + assert "text/plain" == run_obj._trial_component.output_artifacts["foo.txt"].media_type + + +def test_log_input_artifact_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_input_artifact(run_obj): + run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value") + with run_obj: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "whizz/bang" == run_obj._trial_component.input_artifacts["name"].media_type + + run_obj.log_file("foo.txt", is_output=False) + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "foo.txt" in run_obj._trial_component.input_artifacts + assert "text/plain" == run_obj._trial_component.input_artifacts["foo.txt"].media_type + + +def test_log_multiple_inputs(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._trial_component.input_artifacts[file_path] = { + "foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text") + } + with pytest.raises(ValueError) as error: + run_obj.log_artifact("foo.txt", "name", "whizz/bang", False) + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error) + + +def test_log_multiple_outputs(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._trial_component.output_artifacts[file_path] = { + "foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text") + } + with pytest.raises(ValueError) as error: + run_obj.log_artifact("foo.txt", "name", "whizz/bang") + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error) + + +def test_log_multiple_input_artifacts(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value" + str(index), + "etag_value" + str(index), + ) + run_obj.log_file( + file_path, "name" + str(index), "whizz/bang" + str(index), is_output=False + ) + run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path) + + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + # log an output artifact, should be fine + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=True) + + # log an extra input artifact, should raise exception + with pytest.raises(ValueError) as error: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + assert 
f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error) + + +def test_log_multiple_output_artifacts(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value" + str(index), + "etag_value" + str(index), + ) + run_obj.log_file(file_path, "name" + str(index), "whizz/bang" + str(index)) + run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path) + + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + # log an input artifact, should be fine + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + + # log an extra output artifact, should raise exception + with pytest.raises(ValueError) as error: + run_obj.log_file("foo.txt", "name", "whizz/bang") + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error) + + +def test_log_precision_recall_outside_run_context(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + no_skill = 0.1 + title = "TestPrecisionRecall" + + with pytest.raises(RuntimeError) as err: + run_obj.log_precision_recall( + y_true, y_scores, 0, title=title, no_skill=no_skill, is_output=False + ) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_precision_recall(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + no_skill = 0.1 + title = "TestPrecisionRecall" + + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + with run_obj: + run_obj.log_precision_recall( + y_true, y_scores, 0, title=title, no_skill=no_skill, is_output=False + ) + + expected_data = { + "type": "PrecisionRecallCurve", + "version": 0, + "title": title, + "precision": [0.5, 0.3333333333333333, 0.5, 0.0, 1.0], + "recall": [1.0, 0.5, 0.5, 0.0, 0.0], + "averagePrecisionScore": 0.5, + "noSkill": 0.1, + } + run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + title, expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_input_artifact.assert_called_with( + name=title, + source_uri="s3uri_value", + etag="etag_value", + artifact_type="PrecisionRecallCurve", + ) + + +def test_log_precision_recall_invalid_input(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35] + no_skill = 0.1 + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_precision_recall( + y_true, y_scores, 0, title="TestPrecisionRecall", no_skill=no_skill, is_output=False + ) + assert "Lengths mismatch between true labels and predicted probabilities" in str(error) + + +def test_log_confusion_matrix_outside_run_context(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0, 2] + + with pytest.raises(RuntimeError) as err: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_confusion_matrix(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0, 2] + + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + with run_obj: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + + expected_data = { + "type": "ConfusionMatrix", + "version": 0, + "title": "TestConfusionMatrix", + "confusionMatrix": [[2, 0, 0], [0, 0, 1], [1, 0, 2]], + } + + 
run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + "TestConfusionMatrix", expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_output_artifact.assert_called_with( + name="TestConfusionMatrix", + source_uri="s3uri_value", + etag="etag_value", + artifact_type="ConfusionMatrix", + ) + + +def test_log_confusion_matrix_invalid_input(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0] + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + assert "Lengths mismatch between true labels and predicted labels" in str(error) + + +def test_log_roc_curve_outside_run_context(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + + with pytest.raises(RuntimeError) as err: + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_roc_curve(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + with run_obj: + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + + expected_data = { + "type": "ROCCurve", + "version": 0, + "title": "TestROCCurve", + "falsePositiveRate": [0.0, 0.0, 0.5, 0.5, 1.0], + "truePositiveRate": [0.0, 0.5, 0.5, 1.0, 1.0], + "areaUnderCurve": 0.75, + } + run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + "TestROCCurve", expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_input_artifact.assert_called_with( + name="TestROCCurve", + source_uri="s3uri_value", + etag="etag_value", + artifact_type="ROCCurve", + ) + + +def test_log_roc_curve_invalid_input(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35] + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + assert "Lengths mismatch between true labels and predicted scores" in str(error) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch("sagemaker.experiments.run._TrialComponent._load_or_create") +@patch("sagemaker.experiments.run._TrialComponent.list") +@patch("sagemaker.experiments.run._TrialComponent.search") +def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_session): + start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) + end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2) + creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3) + last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4) + tc_list_len = 20 + tc_list_len_half = int(tc_list_len / 2) + mock_tc_search.side_effect = [ + [ + TrialComponentSearchResult( + trial_component_name=Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + 
datetime.timedelta(hours=i), + last_modified_by={}, + tags=[RUN_TC_TAG] if i < tc_list_len_half else None, + ) + ] + for i in range(tc_list_len) + ] + mock_tc_list.return_value = [ + TrialComponentSummary( + trial_component_name=Run._generate_trial_component_name("A" + str(i), TEST_EXP_NAME), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + datetime.timedelta(hours=i), + last_modified_by={}, + ) + for i in range(tc_list_len) + ] + mock_tc_load.side_effect = [ + ( + _TrialComponent( + trial_component_name=Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + datetime.timedelta(hours=i), + last_modified_by={}, + ), + True, + ) + for i in range(tc_list_len_half) + ] + + run_list = list_runs( + experiment_name=TEST_EXP_NAME, + sort_by=SortByType.CREATION_TIME, + sort_order=SortOrderType.ASCENDING, + sagemaker_session=sagemaker_session, + ) + + mock_tc_list.assert_called_once_with( + experiment_name=TEST_EXP_NAME, + created_before=None, + created_after=None, + sort_by="CreationTime", + sort_order="Ascending", + sagemaker_session=sagemaker_session, + max_results=None, + next_token=None, + ) + assert len(run_list) == tc_list_len_half + for i in range(tc_list_len_half): + run = run_list[i] + assert run.experiment_name == TEST_EXP_NAME + assert run.run_name == "a" + str(i) + assert run._experiment + assert run._trial + assert isinstance(run._trial_component, _TrialComponent) + assert run._trial_component.trial_component_name == Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ) + assert run._in_load is False + assert run._inside_load_context is False + assert run._inside_init_context is False + assert run._artifact_uploader + assert run._lineage_artifact_tracker + assert run._metrics_manager + + +@patch("sagemaker.experiments.run._TrialComponent.list") +def test_list_empty(mock_tc_list, sagemaker_session): + mock_tc_list.return_value = [] + assert [] == list_runs(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch("sagemaker.experiments.run._TrialComponent._load_or_create") +def test_enter_exit_locally(mock_load_tc, sagemaker_session, run_obj): + mock_load_tc.return_value = run_obj._trial_component, False + sagemaker_session.sagemaker_client.update_trial_component.return_value = {} + _verify_tc_status_before_enter_init(run_obj._trial_component) + + with run_obj: + _verify_tc_status_when_entering(run_obj._trial_component) + 
init_start_time = run_obj._trial_component.start_time + + with load_run(sagemaker_session=sagemaker_session): + _verify_tc_status_when_entering( + trial_component=run_obj._trial_component, + init_start_time=init_start_time, + ) + + old_end_time = _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, + ) + + old_end_time = _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, + old_end_time=old_end_time, + ) + + # Re-load to verify: + # 1. if it works when load_run and with are not in one line + # 2. if re-entering the load will change the "Completed" TC status + # to "InProgress" + # 3. when exiting the load, the end_time and status will be overridden again + run_load = load_run( + experiment_name=run_obj.experiment_name, + run_name=run_obj.run_name, + sagemaker_session=sagemaker_session, + ) + with run_load: + _verify_tc_status_when_entering( + trial_component=run_obj._trial_component, + init_start_time=init_start_time, + has_completed=True, + ) + _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, old_end_time=old_end_time + ) + + +def test_exit_fail(sagemaker_session, run_obj): + sagemaker_session.sagemaker_client.update_trial_component.return_value = {} + try: + with run_obj: + raise ValueError("Foo") + except ValueError: + pass + + assert run_obj._trial_component.status.primary_status == _TrialComponentStatusType.Failed.value + assert run_obj._trial_component.status.message + assert isinstance(run_obj._trial_component.end_time, datetime.datetime) + + +@pytest.mark.parametrize( + "metric_value", + [1.3, "nan", "inf", "-inf", None], +) +def test_is_input_valid(run_obj, metric_value): + assert run_obj._is_input_valid("metric", "Name", metric_value) + + +@pytest.mark.parametrize( + "metric_value", + [nan, inf, -inf], +) +def test_is_input_valid_false(run_obj, metric_value): + assert not run_obj._is_input_valid("parameter", "Name", metric_value) + + +def test_generate_trial_name(): + base_name = "x" * MAX_NAME_LEN_IN_BACKEND + trial_name = Run._generate_trial_name(base_name=base_name) + assert len(trial_name) <= MAX_NAME_LEN_IN_BACKEND + + +def test_append_run_tc_label_to_tags(): + expected_tc_tag = RUN_TC_TAG + + tags = None + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 1 + assert expected_tc_tag in ret + + tags = [] + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 1 + assert expected_tc_tag in ret + + tags = [{"Key": "foo", "Value": "bar"}] + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 2 + assert expected_tc_tag in ret + + +def _verify_tc_status_before_enter_init(trial_component): + assert not trial_component.start_time + assert not trial_component.end_time + assert not trial_component.status + + +def _verify_tc_status_when_entering(trial_component, init_start_time=None, has_completed=False): + if not init_start_time: + assert isinstance(trial_component.start_time, datetime.datetime) + now = datetime.datetime.now(dateutil.tz.tzlocal()) + assert (now.timestamp() - trial_component.start_time.timestamp()) < 1 + else: + assert trial_component.start_time == init_start_time + + if not has_completed: + assert not trial_component.end_time + assert trial_component.status.primary_status == _TrialComponentStatusType.InProgress.value + + +def _verify_tc_status_when_successfully_exit(trial_component, old_end_time=None): + assert trial_component.status.primary_status == _TrialComponentStatusType.Completed.value + assert 
isinstance(trial_component.start_time, datetime.datetime) + assert isinstance(trial_component.end_time, datetime.datetime) + if old_end_time: + assert trial_component.end_time > old_end_time + return trial_component.end_time diff --git a/tests/unit/sagemaker/experiments/test_run_context.py b/tests/unit/sagemaker/experiments/test_run_context.py new file mode 100644 index 0000000000..7e068136a1 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_run_context.py @@ -0,0 +1,191 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest.mock import patch, MagicMock + +import pytest + +from sagemaker.estimator import Estimator, _TrainingJob +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.run import _RunContext +from sagemaker.experiments import load_run, Run +from sagemaker.experiments.trial import _Trial +from tests.unit.sagemaker.experiments.helpers import ( + TEST_EXP_NAME, + mock_trial_load_or_create_func, + mock_tc_load_or_create_func, +) + +_bucket = "my-bucket" +_train_input_path = f"s3://{_bucket}/data.csv" +_train_output_path = f"s3://{_bucket}" + + +@patch.object(_TrainingJob, "start_new") +def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_session): + mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job") + with run_obj: + estimator = Estimator( + role="arn:my-role", + image_uri="my-image", + sagemaker_session=sagemaker_session, + output_path=_train_output_path, + ) + estimator.fit( + inputs=_train_input_path, + wait=False, + ) + + assert _RunContext.get_current_run() == run_obj + + expected_exp_config = run_obj.experiment_config + mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config) + + # _RunContext is cleaned up after exiting the with statement + assert not _RunContext.get_current_run() + + +@patch.object(_TrainingJob, "start_new") +def test_user_supply_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_session): + mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job") + supplied_exp_cfg = { + "ExperimentName": "my-supplied-exp-name", + "TrialName": "my-supplied-run-group-name", + "RunName": "my-supplied-run-name", + } + with run_obj: + estimator = Estimator( + role="arn:my-role", + image_uri="my-image", + sagemaker_session=sagemaker_session, + output_path=_train_output_path, + ) + estimator.fit( + experiment_config=supplied_exp_cfg, + inputs=_train_input_path, + wait=False, + ) + + assert _RunContext.get_current_run() == run_obj + + mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg) + + # _RunContext is cleaned up after exiting the with statement + assert not _RunContext.get_current_run() + + +def test_auto_fetch_created_run_obj_from_context(run_obj, sagemaker_session): + assert not run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert not _RunContext.get_current_run() + + def 
train(): + with load_run(sagemaker_session=sagemaker_session) as run_load: + assert run_load == run_obj + assert run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj._in_load + + run_load.log_parameter("foo", "bar") + run_load.log_parameter("whizz", 1) + + with run_obj: + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + train() + + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + run_obj.log_parameters({"a": "b", "c": 2}) + + assert run_obj._trial_component.parameters["foo"] == "bar" + assert run_obj._trial_component.parameters["whizz"] == 1 + assert run_obj._trial_component.parameters["a"] == "b" + assert run_obj._trial_component.parameters["c"] == 2 + + # Verify separating load_run and with statement in different lines still work + run_load2 = load_run(sagemaker_session=sagemaker_session) + with run_load2: + assert run_load2 == run_obj + assert run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj._in_load + + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + assert not run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert not _RunContext.get_current_run() + + +def test_nested_run_init_context_on_same_run_object(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with run_obj: + assert _RunContext.get_current_run() + + with run_obj: + pass + assert "It is not allowed to use nested 'with' statements on the Run" in str(err) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), +) +def test_nested_run_init_context_on_different_run_object(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with Run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session): + assert _RunContext.get_current_run() + + with run_obj: + pass + assert "It is not allowed to use nested 'with' statements on the Run" in str(err) + + +def test_nested_run_load_context(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with run_obj: + assert _RunContext.get_current_run() + + with load_run(): + run_load = load_run() + with run_load: + pass + assert "It is not allowed to use nested 'with' statements on the load_run" in str(err) diff --git a/tests/unit/sagemaker/experiments/test_trial.py b/tests/unit/sagemaker/experiments/test_trial.py new file mode 100644 index 0000000000..f6996fefc3 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_trial.py @@ -0,0 +1,276 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +import datetime + +from unittest.mock import patch + +from sagemaker import Session +from sagemaker.experiments._api_types import TrialSummary +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +@pytest.fixture +def datetime_obj(): + return datetime.datetime(2017, 6, 16, 15, 55, 0) + + +def test_load(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.describe_trial.return_value = {"ExperimentName": "experiment-name-value"} + trial_obj = _Trial.load(trial_name="name-value", sagemaker_session=sagemaker_session) + assert trial_obj.trial_name == "name-value" + assert trial_obj.experiment_name == "experiment-name-value" + client.describe_trial.assert_called_with(TrialName="name-value") + + +def test_create(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial.return_value = { + "Arn": "arn:aws:1234", + "TrialName": "name-value", + } + trial_obj = _Trial.create( + trial_name="name-value", + experiment_name="experiment-name-value", + sagemaker_session=sagemaker_session, + ) + assert trial_obj.trial_name == "name-value" + client.create_trial.assert_called_with( + TrialName="name-value", ExperimentName="experiment-name-value" + ) + + +def test_create_with_tags(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial.return_value = { + "Arn": "arn:aws:1234", + "TrialName": "name-value", + } + tags = [{"Key": "foo", "Value": "bar"}] + trial_obj = _Trial.create( + trial_name="name-value", + experiment_name="experiment-name-value", + sagemaker_session=sagemaker_session, + tags=tags, + ) + assert trial_obj.trial_name == "name-value" + client.create_trial.assert_called_with( + TrialName="name-value", + ExperimentName="experiment-name-value", + Tags=[{"Key": "foo", "Value": "bar"}], + ) + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _Trial(sagemaker_session, trial_name="foo") + client.delete_trial.return_value = {} + obj.delete() + client.delete_trial.assert_called_with(TrialName="foo") + + +def test_save(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _Trial( + sagemaker_session, + trial_name="foo", + experiment_name="whizz", + display_name="bar", + tags=[{"Key": "foo", "Value": "bar"}], + ) + client.update_trial.return_value = {} + obj.save() + + client.update_trial.assert_called_with( + TrialName="foo", + DisplayName="bar", + ) + + +def test_add_trial_component(sagemaker_session): + client = sagemaker_session.sagemaker_client + trial = _Trial(sagemaker_session=sagemaker_session) + trial.trial_name = "bar" + trial.add_trial_component("foo") + client.associate_trial_component.assert_called_with(TrialName="bar", TrialComponentName="foo") + + tc = _TrialComponent(trial_component_name="tc-foo", sagemaker_session=sagemaker_session) + trial.add_trial_component(tc) + client.associate_trial_component.assert_called_with( + TrialName="bar", TrialComponentName=tc.trial_component_name + ) + + +def test_remove_trial_component(sagemaker_session): + client = 
sagemaker_session.sagemaker_client + trial = _Trial(sagemaker_session=sagemaker_session) + trial.trial_name = "bar" + trial.remove_trial_component("foo") + client.disassociate_trial_component.assert_called_with( + TrialName="bar", TrialComponentName="foo" + ) + + tc = _TrialComponent(trial_component_name="tc-foo", sagemaker_session=sagemaker_session) + trial.remove_trial_component(tc) + client.disassociate_trial_component.assert_called_with( + TrialName="bar", TrialComponentName=tc.trial_component_name + ) + + +@patch("sagemaker.experiments.trial._Trial.load") +def test_load_or_create_when_exist(mock_load): + sagemaker_session = Session() + trial_name = "trial_name" + exp_name = "exp_name" + + # The trial exists and experiment matches + mock_load.return_value = _Trial( + trial_name=trial_name, + experiment_name=exp_name, + sagemaker_session=sagemaker_session, + ) + _Trial._load_or_create( + trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + mock_load.assert_called_once_with(trial_name, sagemaker_session) + + # The trial exists but experiment does not match + mock_load.return_value = _Trial( + trial_name=trial_name, + exp_name="another_exp_name", + sagemaker_session=sagemaker_session, + ) + with pytest.raises(ValueError) as err: + _Trial._load_or_create( + trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + assert "The given experiment_name {} does not match that in the loaded trial".format( + exp_name + ) in str(err) + + +@patch("sagemaker.experiments.trial._Trial.load") +@patch("sagemaker.experiments.trial._Trial.create") +def test_load_or_create_when_not_exist(mock_create, mock_load): + sagemaker_session = Session() + client = sagemaker_session.sagemaker_client + trial_name = "trial_name" + exp_name = "exp_name" + not_found_err = client.exceptions.ResourceNotFound( + error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, + operation_name="foo", + ) + mock_load.side_effect = not_found_err + + _Trial._load_or_create( + trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + + mock_create.assert_called_once_with( + trial_name=trial_name, + experiment_name=exp_name, + display_name=None, + tags=None, + sagemaker_session=sagemaker_session, + ) + + +def test_list_trials_without_experiment_name(sagemaker_session, datetime_obj): + client = sagemaker_session.sagemaker_client + client.list_trials.return_value = { + "TrialSummaries": [ + { + "TrialName": "trial-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialName": "trial-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + } + expected = [ + TrialSummary( + trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + assert expected == list(_Trial.list(sagemaker_session=sagemaker_session)) + client.list_trials.assert_called_with(**{}) + + +def test_list_trials_with_experiment_name(sagemaker_session, datetime_obj): + client = sagemaker_session.sagemaker_client + client.list_trials.return_value = { + "TrialSummaries": [ + { + "TrialName": "trial-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialName": "trial-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + } + expected = [ + TrialSummary( + trial_name="trial-1", 
creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + assert expected == list(_Trial.list(experiment_name="foo", sagemaker_session=sagemaker_session)) + client.list_trials.assert_called_with(ExperimentName="foo") + + +def test_list_trials_with_trial_component_name(sagemaker_session, datetime_obj): + client = sagemaker_session.sagemaker_client + client.list_trials.return_value = { + "TrialSummaries": [ + { + "TrialName": "trial-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialName": "trial-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + } + expected = [ + TrialSummary( + trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + assert expected == list( + _Trial.list(trial_component_name="tc-foo", sagemaker_session=sagemaker_session) + ) + client.list_trials.assert_called_with(TrialComponentName="tc-foo") diff --git a/tests/unit/sagemaker/experiments/test_trial_component.py b/tests/unit/sagemaker/experiments/test_trial_component.py new file mode 100644 index 0000000000..c14663893e --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_trial_component.py @@ -0,0 +1,384 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import datetime +import unittest.mock + +from unittest.mock import patch + +from sagemaker import Session +from sagemaker.experiments import _api_types +from sagemaker.experiments._api_types import ( + TrialComponentSearchResult, + Parent, + _TrialComponentStatusType, +) +from sagemaker.experiments.trial_component import _TrialComponent + + +def test_create(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial_component.return_value = { + "TrialComponentArn": "bazz", + } + obj = _TrialComponent.create( + trial_component_name="foo", display_name="bar", sagemaker_session=sagemaker_session + ) + client.create_trial_component.assert_called_with(TrialComponentName="foo", DisplayName="bar") + assert "foo" == obj.trial_component_name + assert "bar" == obj.display_name + assert "bazz" == obj.trial_component_arn + + +def test_create_with_tags(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial_component.return_value = { + "TrialComponentArn": "bazz", + } + tags = [{"Key": "foo", "Value": "bar"}] + _TrialComponent.create( + trial_component_name="foo", + display_name="bar", + sagemaker_session=sagemaker_session, + tags=tags, + ) + client.create_trial_component.assert_called_with( + TrialComponentName="foo", DisplayName="bar", Tags=tags + ) + + +def test_load(sagemaker_session): + now = datetime.datetime.now(datetime.timezone.utc) + client = sagemaker_session.sagemaker_client + client.describe_trial_component.return_value = { + "TrialComponentArn": "A", + "TrialComponentName": "B", + "DisplayName": "C", + "Status": {"PrimaryStatus": _TrialComponentStatusType.InProgress.value, "Message": "D"}, + "Parameters": {"E": {"NumberValue": 1.0}, "F": {"StringValue": "G"}}, + "InputArtifacts": {"H": {"Value": "s3://foo/bar", "MediaType": "text/plain"}}, + "OutputArtifacts": {"I": {"Value": "s3://whizz/bang", "MediaType": "text/plain"}}, + "Metrics": [ + { + "MetricName": "J", + "Count": 1, + "Min": 1.0, + "Max": 2.0, + "Avg": 3.0, + "StdDev": 4.0, + "SourceArn": "K", + "Timestamp": now, + } + ], + } + obj = _TrialComponent.load(trial_component_name="foo", sagemaker_session=sagemaker_session) + client.describe_trial_component.assert_called_with(TrialComponentName="foo") + assert "A" == obj.trial_component_arn + assert "B" == obj.trial_component_name + assert "C" == obj.display_name + assert ( + _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="D" + ) + == obj.status + ) + assert {"E": 1.0, "F": "G"} == obj.parameters + assert {"H": _api_types.TrialComponentArtifact(value="s3://foo/bar", media_type="text/plain")} + assert { + "I": _api_types.TrialComponentArtifact(value="s3://whizz/bang", media_type="text/plain") + } + assert [ + _api_types.TrialComponentMetricSummary( + metric_name="J", + count=1, + min=1.0, + max=2.0, + avg=3.0, + std_dev=4.0, + source_arn="K", + timestamp=now, + ) + ] + + +def test_save(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _TrialComponent( + sagemaker_session, + trial_component_name="foo", + display_name="bar", + parameters_to_remove=["E"], + input_artifacts_to_remove=["F"], + output_artifacts_to_remove=["G"], + ) + client.update_trial_component.return_value = {} + obj.save() + + client.update_trial_component.assert_called_with( + TrialComponentName="foo", + DisplayName="bar", + Parameters={}, + ParametersToRemove=["E"], + InputArtifacts={}, + InputArtifactsToRemove=["F"], + OutputArtifacts={}, 
+ OutputArtifactsToRemove=["G"], + ) + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _TrialComponent(sagemaker_session, trial_component_name="foo", display_name="bar") + client.delete_trial_component.return_value = {} + obj.delete() + client.delete_trial_component.assert_called_with(TrialComponentName="foo") + + +def test_delete_with_force_disassociate(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _TrialComponent(sagemaker_session, trial_component_name="foo", display_name="bar") + client.delete_trial_component.return_value = {} + + client.list_trials.side_effect = [ + {"TrialSummaries": [{"TrialName": "trial-1"}, {"TrialName": "trial-2"}], "NextToken": "a"}, + {"TrialSummaries": [{"TrialName": "trial-3"}, {"TrialName": "trial-4"}]}, + ] + + obj.delete(force_disassociate=True) + expected_calls = [ + unittest.mock.call(TrialName="trial-1", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-2", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-3", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-4", TrialComponentName="foo"), + ] + assert expected_calls == client.disassociate_trial_component.mock_calls + client.delete_trial_component.assert_called_with(TrialComponentName="foo") + + +def test_list(sagemaker_session): + start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) + end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2) + creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3) + last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4) + + client = sagemaker_session.sagemaker_client + client.list_trial_components.side_effect = [ + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "A" + str(i), + "TrialComponentArn": "B" + str(i), + "DisplayName": "C" + str(i), + "SourceArn": "D" + str(i), + "Status": { + "PrimaryStatus": _TrialComponentStatusType.InProgress.value, + "Message": "E" + str(i), + }, + "StartTime": start_time + datetime.timedelta(hours=i), + "EndTime": end_time + datetime.timedelta(hours=i), + "CreationTime": creation_time + datetime.timedelta(hours=i), + "LastModifiedTime": last_modified_time + datetime.timedelta(hours=i), + "LastModifiedBy": {}, + } + for i in range(10) + ], + "NextToken": "100", + }, + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "A" + str(i), + "TrialComponentArn": "B" + str(i), + "DisplayName": "C" + str(i), + "SourceArn": "D" + str(i), + "Status": { + "PrimaryStatus": _TrialComponentStatusType.InProgress.value, + "Message": "E" + str(i), + }, + "StartTime": start_time + datetime.timedelta(hours=i), + "EndTime": end_time + datetime.timedelta(hours=i), + "CreationTime": creation_time + datetime.timedelta(hours=i), + "LastModifiedTime": last_modified_time + datetime.timedelta(hours=i), + "LastModifiedBy": {}, + } + for i in range(10, 20) + ] + }, + ] + + expected = [ + _api_types.TrialComponentSummary( + trial_component_name="A" + str(i), + trial_component_arn="B" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=_api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + 
datetime.timedelta(hours=i), + last_modified_by={}, + ) + for i in range(20) + ] + result = list( + _TrialComponent.list( + sagemaker_session=sagemaker_session, + source_arn="foo", + sort_by="CreationTime", + sort_order="Ascending", + ) + ) + + assert expected == result + expected_calls = [ + unittest.mock.call(SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"), + unittest.mock.call( + NextToken="100", SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo" + ), + ] + assert expected_calls == client.list_trial_components.mock_calls + + +def test_list_empty(sagemaker_session): + sagemaker_session.sagemaker_client.list_trial_components.return_value = { + "TrialComponentSummaries": [] + } + assert [] == list(_TrialComponent.list(sagemaker_session=sagemaker_session)) + + +def test_list_trial_components_call_args(sagemaker_session): + created_before = datetime.datetime(1999, 10, 12, 0, 0, 0) + created_after = datetime.datetime(1990, 10, 12, 0, 0, 0) + trial_name = "foo-trial" + experiment_name = "foo-experiment" + next_token = "thetoken" + max_results = 99 + + client = sagemaker_session.sagemaker_client + client.list_trial_components.return_value = {} + assert [] == list( + _TrialComponent.list( + sagemaker_session=sagemaker_session, + trial_name=trial_name, + experiment_name=experiment_name, + created_before=created_before, + created_after=created_after, + next_token=next_token, + max_results=max_results, + sort_by="CreationTime", + sort_order="Ascending", + ) + ) + + expected_calls = [ + unittest.mock.call( + TrialName="foo-trial", + ExperimentName="foo-experiment", + CreatedBefore=created_before, + CreatedAfter=created_after, + SortBy="CreationTime", + SortOrder="Ascending", + NextToken="thetoken", + MaxResults=99, + ) + ] + assert expected_calls == client.list_trial_components.mock_calls + + +@patch("sagemaker.experiments.trial_component._TrialComponent.load") +def test_load_or_create_when_exist(mock_load, sagemaker_session): + tc_name = "tc_name" + _, is_existed = _TrialComponent._load_or_create( + trial_component_name=tc_name, sagemaker_session=sagemaker_session + ) + assert is_existed + mock_load.assert_called_once_with( + tc_name, + sagemaker_session, + ) + + +@patch("sagemaker.experiments.trial_component._TrialComponent.load") +@patch("sagemaker.experiments.trial_component._TrialComponent.create") +def test_load_or_create_when_not_exist(mock_create, mock_load): + sagemaker_session = Session() + client = sagemaker_session.sagemaker_client + tc_name = "tc_name" + not_found_err = client.exceptions.ResourceNotFound( + error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, + operation_name="foo", + ) + mock_load.side_effect = not_found_err + + _, is_existed = _TrialComponent._load_or_create( + trial_component_name=tc_name, sagemaker_session=sagemaker_session + ) + + assert not is_existed + mock_create.assert_called_once_with( + trial_component_name=tc_name, + display_name=None, + tags=None, + sagemaker_session=sagemaker_session, + ) + + +def test_search(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.search.return_value = { + "Results": [ + { + "TrialComponent": { + "TrialComponentName": "tc-1", + "TrialComponentArn": "arn::tc-1", + "DisplayName": "TC1", + "Parents": [ + { + "ExperimentName": "e-1", + "TrialName": "t-1", + }, + { + "ExperimentName": "e-2", + "TrialName": "t-2", + }, + ], + } + }, + { + "TrialComponent": { + "TrialComponentName": "tc-2", + "TrialComponentArn": "arn::tc-2", + "DisplayName": "TC2", + } + }, 
+ ] + } + expected = [ + TrialComponentSearchResult( + trial_component_name="tc-1", + trial_component_arn="arn::tc-1", + display_name="TC1", + parents=[ + Parent(experiment_name="e-1", trial_name="t-1"), + Parent(experiment_name="e-2", trial_name="t-2"), + ], + ), + TrialComponentSearchResult( + trial_component_name="tc-2", trial_component_arn="arn::tc-2", display_name="TC2" + ), + ] + assert expected == list(_TrialComponent.search(sagemaker_session=sagemaker_session)) diff --git a/tests/unit/sagemaker/experiments/test_utils.py b/tests/unit/sagemaker/experiments/test_utils.py new file mode 100644 index 0000000000..a63c96c0fe --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_utils.py @@ -0,0 +1,36 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from src.sagemaker.experiments._utils import resolve_artifact_name, guess_media_type + + +def test_resolve_artifact_name(): + file_names = { + "a": "a", + "a.txt": "a.txt", + "b.": "b.", + ".c": ".c", + "/x/a/a.txt": "a.txt", + "/a/b/c.": "c.", + "./.a": ".a", + "../b.txt": "b.txt", + "~/a.txt": "a.txt", + "c/d.txt": "d.txt", + } + for file_name, artifact_name in file_names.items(): + assert artifact_name == resolve_artifact_name(file_name) + + +def test_guess_media_type(): + assert "text/plain" == guess_media_type("foo.txt") diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index c391d45382..0088e34c58 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -48,6 +48,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 2e7576421f..fea80b7ea9 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -56,6 +56,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index af46cf4360..d35c0a51dd 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -52,6 +52,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index 5aef9316da..7645c4fe23 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -50,6 +50,7 @@ "ExperimentName": "exp", "TrialName": 
"trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 7517f3a641..1ce58a19b4 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -50,6 +50,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/utilities/test_search_expression.py b/tests/unit/sagemaker/utilities/test_search_expression.py new file mode 100644 index 0000000000..98a52a992a --- /dev/null +++ b/tests/unit/sagemaker/utilities/test_search_expression.py @@ -0,0 +1,80 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +from sagemaker.utilities.search_expression import ( + Filter, + Operator, + NestedFilter, + SearchExpression, + BooleanOperator, +) + + +def test_filters(): + search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1") + + assert { + "Name": "learning_rate", + "Operator": "Equals", + "Value": "0.1", + } == search_filter.to_boto() + + +def test_partial_filters(): + search_filter = Filter(name="learning_rate") + + assert {"Name": "learning_rate"} == search_filter.to_boto() + + +def test_nested_filters(): + search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1") + filters = [search_filter] + nested_filters = NestedFilter(property_name="hyper_param", filters=filters) + + assert { + "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}], + "NestedPropertyName": "hyper_param", + } == nested_filters.to_boto() + + +def test_search_expression(): + search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1") + nested_filter = NestedFilter(property_name="hyper_param", filters=[search_filter]) + search_expression = SearchExpression( + filters=[search_filter], + nested_filters=[nested_filter], + sub_expressions=[], + boolean_operator=BooleanOperator.AND, + ) + + assert { + "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}], + "NestedFilters": [ + { + "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}], + "NestedPropertyName": "hyper_param", + } + ], + "SubExpressions": [], + "Operator": "And", + } == search_expression.to_boto() + + +def test_illegal_search_expression(): + with pytest.raises( + ValueError, match="You must specify at least one subexpression, filter, or nested filter" + ): + SearchExpression() diff --git a/tests/unit/sagemaker/workflow/test_clarify_check_step.py b/tests/unit/sagemaker/workflow/test_clarify_check_step.py index feadaa03dc..54b354b71e 100644 --- a/tests/unit/sagemaker/workflow/test_clarify_check_step.py +++ b/tests/unit/sagemaker/workflow/test_clarify_check_step.py @@ -16,10 +16,6 @@ import re import pytest -import sagemaker - -from mock 
import Mock, PropertyMock - from sagemaker.clarify import ( DataConfig, BiasConfig, @@ -50,46 +46,6 @@ _S3_ANALYSIS_CONFIG_OUTPUT_PATH = "s3://my_bucket/analysis_cfg_output" -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=_ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=_REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. - - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=_DEFAULT_BUCKET, - ) - - _expected_data_bias_dsl = { "Name": "DataBiasCheckStep", "Type": "ClarifyCheck", diff --git a/tests/unit/sagemaker/workflow/test_entities.py b/tests/unit/sagemaker/workflow/test_entities.py index 6f0be2ccca..a36207b241 100644 --- a/tests/unit/sagemaker/workflow/test_entities.py +++ b/tests/unit/sagemaker/workflow/test_entities.py @@ -19,9 +19,6 @@ from enum import Enum -from mock.mock import Mock, PropertyMock - -import sagemaker from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.conditions import ConditionGreaterThan from sagemaker.workflow.entities import ( @@ -58,46 +55,6 @@ def custom_entity_list(): return [CustomEntity(1), CustomEntity(2)] -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value="role") - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name="us-west-2") - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. 
- - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket="my-bucket", - ) - - def test_entity(custom_entity): request_struct = {"foo": 1} assert custom_entity.to_request() == request_struct diff --git a/tests/unit/sagemaker/workflow/test_quality_check_step.py b/tests/unit/sagemaker/workflow/test_quality_check_step.py index b60e2de8fa..dc104d71df 100644 --- a/tests/unit/sagemaker/workflow/test_quality_check_step.py +++ b/tests/unit/sagemaker/workflow/test_quality_check_step.py @@ -15,10 +15,6 @@ import json import pytest -import sagemaker - -from mock import Mock, PropertyMock - from sagemaker.model_monitor import DatasetFormat from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.pipeline import Pipeline @@ -31,49 +27,7 @@ from sagemaker.workflow.steps import CacheConfig from sagemaker.workflow.check_job_config import CheckJobConfig -_REGION = "us-west-2" _ROLE = "DummyRole" -_BUCKET = "my-bucket" - - -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=_ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=_REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. - - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=_BUCKET, - ) _expected_data_quality_dsl = { diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 9887d43078..ba712d11d7 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -16,15 +16,10 @@ import json import pytest -import sagemaker import os import warnings -from mock import ( - Mock, - PropertyMock, - patch, -) +from mock import patch from sagemaker.debugger import ProfilerConfig from sagemaker.estimator import Estimator @@ -94,46 +89,6 @@ def create_predictor(self, endpoint_name): return Predictor(endpoint_name, self.sagemaker_session) -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. 
- - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=BUCKET, - ) - - @pytest.fixture def script_processor(sagemaker_session): return ScriptProcessor( diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 82b154317d..44b5818fc8 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -225,6 +225,9 @@ def test_fit_ndarray(time, sagemaker_session): assert mock_object.put.call_count == 4 + called_args = sagemaker_session.train.call_args + assert not called_args[1]["experiment_config"] + def test_fit_pass_experiment_config(sagemaker_session): kwargs = dict(COMMON_ARGS) @@ -239,12 +242,18 @@ def test_fit_pass_experiment_config(sagemaker_session): labels = [99, 85, 87, 2] pca.fit( pca.record_set(np.array(train), np.array(labels)), - experiment_config={"ExperimentName": "exp"}, + experiment_config={ + "ExperimentName": "exp", + "RunName": "rn", + }, ) called_args = sagemaker_session.train.call_args - assert called_args[1]["experiment_config"] == {"ExperimentName": "exp"} + assert called_args[1]["experiment_config"] == { + "ExperimentName": "exp", + "RunName": "rn", + } def test_build_shards(): diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 34e6a43fcf..868da88d78 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2489,7 +2489,12 @@ def test_start_new(sagemaker_session): hyperparameters=hyperparameters, ) - exp_config = {"ExperimentName": "exp", "TrialName": "t", "TrialComponentDisplayName": "tc"} + exp_config = { + "ExperimentName": "exp", + "TrialName": "t", + "TrialComponentDisplayName": "tc", + "RunName": "rn", + } started_training_job = training_job.start_new(estimator, inputs, experiment_config=exp_config) called_args = sagemaker_session.train.call_args @@ -2680,6 +2685,7 @@ def test_unsupported_type_in_dict(): "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } } ) @@ -2884,6 +2890,7 @@ def test_generic_to_fit_with_experiment_config(time, sagemaker_session): "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", }, ) diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 99b0e839b7..9ba3e17ff3 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -62,6 +62,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } MODEL_PKG_RESPONSE = {"ModelPackageArn": "arn:model-pkg-arn"} diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 082f699d63..c8aad13774 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -54,6 +54,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}} diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 4efc2e5bf8..2035636e76 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -49,6 +49,7 @@ "ExperimentName": "exp", "TrialName": "trial", 
"TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d7c94470f5..ec4a21cbc9 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -588,11 +588,16 @@ def test_user_agent_injected(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" + not in sess.sagemaker_metrics_client._client_config.user_agent + ) def test_user_agent_injected_with_nbi(boto_session): @@ -607,10 +612,14 @@ def test_user_agent_injected_with_nbi(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_metrics_client._client_config.user_agent + ) def test_user_agent_injected_with_nbi_ioerror(boto_session): @@ -625,11 +634,16 @@ def test_user_agent_injected_with_nbi_ioerror(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" + not in sess.sagemaker_metrics_client._client_config.user_agent + ) def test_training_input_all_defaults(): @@ -700,6 +714,7 @@ def test_training_input_all_arguments(): "ExperimentName": "dummyExp", "TrialName": "dummyT", "TrialComponentDisplayName": "dummyTC", + "RunName": "dummyRN", } MODEL_CLIENT_CONFIG = {"InvocationsMaxRetries": 2, "InvocationsTimeoutInSeconds": 60} diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 13cc755336..c3e984e0b7 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -51,6 +51,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 0eb81be584..8bcbed41c2 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -25,10 +25,12 @@ from boto3 import exceptions import botocore import pytest -from mock import call, patch, Mock, MagicMock +from mock import call, patch, Mock, MagicMock, PropertyMock import sagemaker +from sagemaker.experiments._run_context import _RunContext from sagemaker.session_settings import SessionSettings +from sagemaker.utils import retry_with_backoff, check_and_get_run_experiment_config from tests.unit.sagemaker.workflow.helpers import 
CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -795,3 +797,63 @@ def test_start_waiting(capfd): out, _ = capfd.readouterr() assert "." * sagemaker.utils.WAITING_DOT_NUMBER in out + + +def test_retry_with_backoff(): + callable_func = Mock() + + # Invalid input + with pytest.raises(ValueError) as value_err: + retry_with_backoff(callable_func, 0) + assert "The num_attempts must be >= 1" in str(value_err) + callable_func.assert_not_called() + + # All retries fail + run_err_msg = "Test Retry Error" + callable_func.side_effect = RuntimeError(run_err_msg) + with pytest.raises(RuntimeError) as run_err: + retry_with_backoff(callable_func, 2) + assert run_err_msg in str(run_err) + + # One retry passes + func_return_val = "Test Return" + callable_func.side_effect = [RuntimeError(run_err_msg), func_return_val] + assert retry_with_backoff(callable_func, 2) == func_return_val + + # No retry + callable_func.side_effect = None + callable_func.return_value = func_return_val + assert retry_with_backoff(callable_func, 2) == func_return_val + + +def test_check_and_get_run_experiment_config(): + supplied_exp_cfg = {"ExperimentName": "my-supplied-exp-name", "RunName": "my-supplied-run-name"} + run_exp_cfg = {"ExperimentName": "my-run-exp-name", "RunName": "my-run-run-name"} + + # No user supplied exp config and no current Run + assert not _RunContext.get_current_run() + exp_cfg1 = check_and_get_run_experiment_config(None) + assert exp_cfg1 is None + + # With user supplied exp config and no current Run + assert not _RunContext.get_current_run() + exp_cfg2 = check_and_get_run_experiment_config(supplied_exp_cfg) + assert exp_cfg2 == supplied_exp_cfg + + run = Mock() + type(run).experiment_config = PropertyMock(return_value=run_exp_cfg) + _RunContext.add_run_object(run) + + try: + # No user supplied exp config and with current Run + assert _RunContext.get_current_run().experiment_config == run_exp_cfg + exp_cfg3 = check_and_get_run_experiment_config(None) + assert exp_cfg3 == run_exp_cfg + + # With user supplied exp config and current Run + assert _RunContext.get_current_run().experiment_config == run_exp_cfg + exp_cfg4 = check_and_get_run_experiment_config(supplied_exp_cfg) + assert exp_cfg4 == supplied_exp_cfg + finally: + # Clean up the global static variable in case it affects other tests + _RunContext.drop_current_run() diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 82f27c19ae..d58c4992cd 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -54,6 +54,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } From 1cbfc8389f75323d779e560d12b15f163a23c7af Mon Sep 17 00:00:00 2001 From: Tejas Chumbalkar <34728580+tejaschumbalkar@users.noreply.github.com> Date: Wed, 14 Dec 2022 10:22:22 -0800 Subject: [PATCH 058/526] feature: Add support for TF2.9.2 training images (#3178) --- src/sagemaker/fw_utils.py | 1 + src/sagemaker/image_uri_config/tensorflow.json | 4 ++-- tests/unit/test_fw_utils.py | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 5efe530396..3ba918ea2c 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -103,6 +103,7 @@ "2.8.0", "2.9", "2.9.1", + "2.9.2", "2.10", "2.10.0", ], diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index a900aa4fe5..6bb36057fa 100644 --- 
a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -1820,7 +1820,7 @@ "2.6": "2.6.3", "2.7": "2.7.1", "2.8": "2.8.0", - "2.9": "2.9.1", + "2.9": "2.9.2", "2.10": "2.10.0" }, "versions": { @@ -3273,7 +3273,7 @@ }, "repository": "tensorflow-training" }, - "2.9.1": { + "2.9.2": { "py_versions": [ "py39" ], diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 667d115d58..4654abb928 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -883,6 +883,7 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "tensorflow", "2.7", "py38", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.8", "py39", smdataparallel_enabled), + ("ml.p3.16xlarge", "tensorflow", "2.9.2", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.9.1", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.9", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.10.0", "py39", smdataparallel_enabled), @@ -915,6 +916,7 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "tensorflow", "2.7.1", "py38", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.9.1", "py39", smdataparallel_enabled_custom_mpi), + ("ml.p3.16xlarge", "tensorflow", "2.9.2", "py39", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.10.0", "py39", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled_custom_mpi), From 881caecd9a45d0facde0913baa74895938c5e788 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 15 Dec 2022 01:19:26 +0000 Subject: [PATCH 059/526] prepare release v2.123.0 --- CHANGELOG.md | 7 +++++++ VERSION | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index de20a8a0df..a05b64c96f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## v2.123.0 (2022-12-15) + +### Features + + * Add support for TF2.9.2 training images + * Add SageMaker Experiment + ## v2.122.0 (2022-12-14) ### Features diff --git a/VERSION b/VERSION index 6d7f044fa2..bef06dbf6d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.122.1.dev0 +2.123.0 From d543604609f3d0b1f0856d8346c8ecf271203432 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 15 Dec 2022 01:19:27 +0000 Subject: [PATCH 060/526] update development version to v2.123.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index bef06dbf6d..ea5085760e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.123.0 +2.123.1.dev0 From eef679cfc167827b14e47d0e1991e274a16e1ed4 Mon Sep 17 00:00:00 2001 From: Md Mizanur Rahman <105268921+mizanfiu@users.noreply.github.com> Date: Wed, 14 Dec 2022 19:55:56 -0800 Subject: [PATCH 061/526] feature: Added doc update for dataset builder (#3539) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add list_feature_groups API (#647) * feat: Feature/get record api (#650) Co-authored-by: Eric Zou * Add delete_record API (#664) * feat: Add DatasetBuilder class (#667) Co-authored-by: Eric Zou * feat: Add to_csv method in DatasetBuilder (#699) * feat: Add pandas.Dataframe as base case (#708) * feat: Add with_feature_group method 
in DatasetBuilder (#726) * feat: Handle merge and timestamp filters (#727) * feat: Add to_dataframe method in DatasetBuilder (#729) * Address TODOs (#731) * Unit test for DatasetBuilder (#734) * fix: Fix list_feature_groups max_results (#744) * Add integration tests for create_dataset (#743) * feature: Aggregate commits * fix: as_of, event_range, join, default behavior and duplicates… (#764) * Bug fixed - as_of, event_range, join, default behavior and duplicates and tests Bugs: 1. as_of was not working properly on deleted events 2. Same event_time_range 3. Join was not working when including feature names 4. Default sql was returning only most recent, whereas it should all excluding duplicates 5. Include duplicates was not return all non-deleted data 6. instanceof(dataframe) case was also applied to non-df cases while join 7. Include column was returning unnecessary columns. * Fix on pylint error * Fix on include_duplicated_records for panda data frames * Fix format issue for black * Bug fixed related to line break * Bug fix related to dataframe and inclde_deleted_record and include_duplicated_record * Addressed comments and code refactored * changed to_csv to to_csv_file and added error messages for query limit and recent record limit * Revert a change which was not intended * Resolved the leak of feature group deletion in integration test * Added doc update for dataset builder * Fix the issue in doc Co-authored-by: Yiming Zou Co-authored-by: Brandon Chatham Co-authored-by: Eric Zou Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com> --- doc/api/prep_data/feature_store.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/doc/api/prep_data/feature_store.rst b/doc/api/prep_data/feature_store.rst index 1980a0b069..0e9bf25586 100644 --- a/doc/api/prep_data/feature_store.rst +++ b/doc/api/prep_data/feature_store.rst @@ -72,3 +72,11 @@ Inputs .. autoclass:: sagemaker.feature_store.inputs.FeatureValue :members: :show-inheritance: + + +Dataset Builder +*************** + +.. 
autoclass:: sagemaker.feature_store.dataset_builder.DatasetBuilder + :members: + :show-inheritance: From 019d5a4b232cd4d287dff35c6a8ba9681ed4c0ca Mon Sep 17 00:00:00 2001 From: mariumof <99500633+mariumof@users.noreply.github.com> Date: Thu, 15 Dec 2022 12:38:55 -0800 Subject: [PATCH 062/526] feature: Add disable_profiler field in config and propagate changes (#3523) Co-authored-by: Marius Moisescu --- src/sagemaker/debugger/profiler_config.py | 4 + src/sagemaker/estimator.py | 23 +- .../integ/sagemaker/workflow/test_workflow.py | 4 - tests/integ/test_profiler.py | 40 --- .../sagemaker/huggingface/test_estimator.py | 8 +- .../sagemaker/tensorflow/test_estimator.py | 8 +- .../test_huggingface_pytorch_compiler.py | 8 +- .../test_huggingface_tensorflow_compiler.py | 8 +- .../test_pytorch_compiler.py | 12 +- .../test_tensorflow_compiler.py | 8 +- .../workflow/test_step_collections.py | 4 + tests/unit/sagemaker/workflow/test_steps.py | 3 +- .../sagemaker/workflow/test_training_step.py | 25 -- tests/unit/sagemaker/workflow/test_utils.py | 2 + tests/unit/test_chainer.py | 8 +- tests/unit/test_estimator.py | 249 ++++++++++++------ tests/unit/test_mxnet.py | 8 +- tests/unit/test_pytorch.py | 8 +- tests/unit/test_rl.py | 8 +- tests/unit/test_sklearn.py | 8 +- tests/unit/test_xgboost.py | 8 +- 21 files changed, 211 insertions(+), 243 deletions(-) diff --git a/src/sagemaker/debugger/profiler_config.py b/src/sagemaker/debugger/profiler_config.py index 3d4a24e8d1..561de38b9f 100644 --- a/src/sagemaker/debugger/profiler_config.py +++ b/src/sagemaker/debugger/profiler_config.py @@ -32,6 +32,7 @@ def __init__( s3_output_path: Optional[Union[str, PipelineVariable]] = None, system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None, framework_profile_params: Optional[FrameworkProfile] = None, + disable_profiler: Optional[Union[str, PipelineVariable]] = False, ): """Initialize a ``ProfilerConfig`` instance. @@ -78,6 +79,7 @@ class and SageMaker Framework estimators. self.s3_output_path = s3_output_path self.system_monitor_interval_millis = system_monitor_interval_millis self.framework_profile_params = framework_profile_params + self.disable_profiler = disable_profiler def _to_request_dict(self): """Generate a request dictionary using the parameters provided when initializing the object. @@ -91,6 +93,8 @@ def _to_request_dict(self): if self.s3_output_path is not None: profiler_config_request["S3OutputPath"] = self.s3_output_path + profiler_config_request["DisableProfiler"] = self.disable_profiler + if self.system_monitor_interval_millis is not None: profiler_config_request[ "ProfilingIntervalInMilliseconds" diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index e3b06950aa..8ed9b724a5 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -938,26 +938,29 @@ def _prepare_collection_configs(self): def _prepare_profiler_for_training(self): """Set necessary values and do basic validations in profiler config and profiler rules. - When user explicitly set rules to an empty list, default profiler rule won't be enabled. - Default profiler rule will be enabled in supported regions when either: - 1. user doesn't specify any rules, i.e., rules=None; or - 2. user only specify debugger rules, i.e., rules=[Rule.sagemaker(...)] + No default profiler rule will be used. 
The user needs to specify rules explicitly """ if self.disable_profiler: - if self.profiler_config: - raise RuntimeError("profiler_config cannot be set when disable_profiler is True.") + if self.profiler_config and not self.profiler_config.disable_profiler: + raise RuntimeError( + "profiler_config.disable_profiler cannot be False" + + " when disable_profiler is True." + ) if self.profiler_rules: raise RuntimeError("ProfilerRule cannot be set when disable_profiler is True.") elif _region_supports_profiler(self.sagemaker_session.boto_region_name): if self.profiler_config is None: self.profiler_config = ProfilerConfig(s3_output_path=self.output_path) if self.rules is None or (self.rules and not self.profiler_rules): - self.profiler_rules = [get_default_profiler_rule()] + self.profiler_rules = [] if self.profiler_config and not self.profiler_config.s3_output_path: self.profiler_config.s3_output_path = self.output_path self.profiler_rule_configs = self._prepare_profiler_rules() + # if profiler_config is still None, it means the job has profiler disabled + if self.profiler_config is None: + self.profiler_config = ProfilerConfig(disable_profiler=True) def _prepare_profiler_rules(self): """Set any necessary values in profiler rules, if they are provided.""" @@ -1048,7 +1051,7 @@ def latest_job_profiler_artifacts_path(self): error_message="""Cannot get the profiling output artifacts path. The Estimator is not associated with a training job.""" ) - if self.profiler_config is not None: + if self.profiler_config is not None and not self.profiler_config.disable_profiler: return os.path.join( self.profiler_config.s3_output_path, self.latest_training_job.name, @@ -1895,8 +1898,8 @@ def enable_default_profiling(self): else: self.profiler_config = ProfilerConfig(s3_output_path=self.output_path) - self.profiler_rules = [get_default_profiler_rule()] - self.profiler_rule_configs = self._prepare_profiler_rules() + self.profiler_rules = [] + self.profiler_rule_configs = [] _TrainingJob.update( self, self.profiler_rule_configs, self.profiler_config._to_request_dict() diff --git a/tests/integ/sagemaker/workflow/test_workflow.py b/tests/integ/sagemaker/workflow/test_workflow.py index 44f4e2d26e..bd24b653ae 100644 --- a/tests/integ/sagemaker/workflow/test_workflow.py +++ b/tests/integ/sagemaker/workflow/test_workflow.py @@ -1269,8 +1269,6 @@ def test_caching_behavior( # create pipeline pipeline.create(role) definition = json.loads(pipeline.definition()) - # delete profiler config for assertions as it will contain a timestamp - del definition["Steps"][1]["Arguments"]["ProfilerRuleConfigurations"] # verify input path expected_abalone_input_path = f"{pipeline_name}/{step_process.name}" f"/input/abalone_data" @@ -1295,7 +1293,6 @@ def test_caching_behavior( # verify no changes definition2 = json.loads(pipeline.definition()) - del definition2["Steps"][1]["Arguments"]["ProfilerRuleConfigurations"] assert definition == definition2 # add dummy file to source_dir @@ -1306,7 +1303,6 @@ def test_caching_behavior( # verify changes definition3 = json.loads(pipeline.definition()) - del definition3["Steps"][1]["Arguments"]["ProfilerRuleConfigurations"] assert definition != definition3 finally: diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index bddd53e20c..7d3fdb2d7b 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -13,7 +13,6 @@ from __future__ import absolute_import import os -import re import time import uuid @@ -22,7 +21,6 @@ from sagemaker.debugger import ( 
DebuggerHookConfig, FrameworkProfile, - get_rule_container_image_uri, ProfilerConfig, ProfilerRule, Rule, @@ -93,8 +91,6 @@ def test_mxnet_with_default_profiler_config_and_profiler_rule( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert ( job_description["ProfilerConfig"] == ProfilerConfig( @@ -103,13 +99,6 @@ def test_mxnet_with_default_profiler_config_and_profiler_rule( ) assert job_description.get("ProfilingStatus") == "Enabled" - profiler_rule_configuration = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_rule_configuration["RuleConfigurationName"]) - assert profiler_rule_configuration["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - with pytest.raises(ValueError) as error: mx.enable_default_profiling() assert "Debugger monitoring is already enabled." in str(error) @@ -155,18 +144,9 @@ def test_mxnet_with_custom_profiler_config_then_update_rule_and_config( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert job_description.get("ProfilerConfig") == profiler_config._to_request_dict() assert job_description.get("ProfilingStatus") == "Enabled" - profiler_rule_configuration = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_rule_configuration["RuleConfigurationName"]) - assert profiler_rule_configuration["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) mx.update_profiler( @@ -178,13 +158,6 @@ def test_mxnet_with_custom_profiler_config_then_update_rule_and_config( assert job_description["ProfilerConfig"]["S3OutputPath"] == profiler_config.s3_output_path assert job_description["ProfilerConfig"]["ProfilingIntervalInMilliseconds"] == 500 - profiler_report_rule_config = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_report_rule_config["RuleConfigurationName"]) - assert profiler_report_rule_config["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_report_rule_config["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - def test_mxnet_with_built_in_profiler_rule_with_custom_parameters( sagemaker_session, @@ -225,8 +198,6 @@ def test_mxnet_with_built_in_profiler_rule_with_custom_parameters( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert job_description.get("ProfilingStatus") == "Enabled" assert ( job_description.get("ProfilerConfig") @@ -298,8 +269,6 @@ def test_mxnet_with_profiler_and_debugger_then_disable_framework_metrics( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert job_description["ProfilerConfig"] == profiler_config._to_request_dict() assert 
job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict() assert job_description.get("ProfilingStatus") == "Enabled" @@ -387,13 +356,6 @@ def test_mxnet_with_enable_framework_metrics_then_update_framework_metrics( == updated_framework_profile.profiling_parameters ) - profiler_rule_configuration = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_rule_configuration["RuleConfigurationName"]) - assert profiler_rule_configuration["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - def test_mxnet_with_disable_profiler_then_enable_default_profiling( sagemaker_session, @@ -431,12 +393,10 @@ def test_mxnet_with_disable_profiler_then_enable_default_profiling( ) job_description = mx.latest_training_job.describe() - assert job_description.get("ProfilerConfig") is None assert job_description.get("ProfilerRuleConfigurations") is None assert job_description.get("ProfilingStatus") == "Disabled" _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) - mx.enable_default_profiling() job_description = mx.latest_training_job.describe() diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 0088e34c58..072eefeb83 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -143,14 +143,8 @@ def _create_train_job(version, base_framework_version): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index fea80b7ea9..771b18b35a 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -136,14 +136,8 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd "metric_definitions": None, "environment": None, "experiment_config": None, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index d35c0a51dd..656730a47c 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -145,14 +145,8 @@ def _create_train_job( "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": 
"503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index 7645c4fe23..c3684ac649 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -143,14 +143,8 @@ def _create_train_job( "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 0fe2402695..068bb4e4b9 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -137,14 +137,10 @@ def _create_train_job(version, instance_type, training_compiler_config, instance "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], - "profiler_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + "profiler_config": { + "DisableProfiler": False, + "S3OutputPath": "s3://{}/".format(BUCKET_NAME), + }, } diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 1ce58a19b4..a5c14b1626 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -145,14 +145,8 @@ def _create_train_job(framework_version, instance_type, training_compiler_config "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 2bf47a79d0..95738c99ca 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -796,6 +796,7 @@ def test_register_model_with_model_repack_with_estimator( "CollectionConfigurations": [], "S3OutputPath": f"s3://{BUCKET}/", }, + "ProfilerConfig": {"DisableProfiler": True}, "HyperParameters": { "inference_script": '"dummy_script.py"', "dependencies": f'"{dummy_requirements}"', @@ -923,6 +924,7 @@ def 
test_register_model_with_model_repack_with_model(model, model_metrics, drift "CollectionConfigurations": [], "S3OutputPath": f"s3://{BUCKET}/", }, + "ProfilerConfig": {"DisableProfiler": True}, "HyperParameters": { "inference_script": '"dummy_script.py"', "model_archive": '"s3://my-bucket/model.tar.gz"', @@ -1052,6 +1054,7 @@ def test_register_model_with_model_repack_with_pipeline_model( "CollectionConfigurations": [], "S3OutputPath": f"s3://{BUCKET}/", }, + "ProfilerConfig": {"DisableProfiler": True}, "HyperParameters": { "dependencies": "null", "inference_script": '"dummy_script.py"', @@ -1243,6 +1246,7 @@ def test_estimator_transformer_with_model_repack_with_estimator(estimator): "TrainingImage": "246618743249.dkr.ecr.us-west-2.amazonaws.com/" + "sagemaker-scikit-learn:0.23-1-cpu-py3", }, + "ProfilerConfig": {"DisableProfiler": True}, "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"}, "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, "ResourceConfig": { diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index ba712d11d7..f2046cc00f 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -329,6 +329,7 @@ def test_training_step_base_estimator(sagemaker_session): "CollectionConfigurations": [], }, "ProfilerConfig": { + "DisableProfiler": False, "ProfilingIntervalInMilliseconds": 500, "S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}}, }, @@ -438,7 +439,7 @@ def test_training_step_tensorflow(sagemaker_session): "sagemaker_instance_type": {"Get": "Parameters.InstanceType"}, "sagemaker_distributed_dataparallel_custom_mpi_options": '""', }, - "ProfilerConfig": {"S3OutputPath": "s3://my-bucket/"}, + "ProfilerConfig": {"DisableProfiler": False, "S3OutputPath": "s3://my-bucket/"}, }, "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index 3e8b57b069..7f8e6b0c62 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -401,10 +401,6 @@ def test_training_step_with_estimator( } step_definition = json.loads(pipeline.definition())["Steps"][0] - # delete profiler rule configurations because of timestamp collision - del step_definition["Arguments"]["ProfilerRuleConfigurations"] - del step_args["ProfilerRuleConfigurations"] - assert step_definition == { "Name": "MyTrainingStep", "Description": "TrainingStep description", @@ -428,7 +424,6 @@ def test_training_step_with_estimator( # test idempotency step_def2 = json.loads(pipeline.definition())["Steps"][0] - del step_def2["Arguments"]["ProfilerRuleConfigurations"] assert step_definition == step_def2 @@ -537,10 +532,6 @@ def test_training_step_with_framework_estimator( del expected_step_args["OutputDataConfig"]["S3OutputPath"] del step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"] - # delete profiler rule configurations because of timestamp collision - del step_def["Arguments"]["ProfilerRuleConfigurations"] - del expected_step_args["ProfilerRuleConfigurations"] - if "sagemaker_s3_output" in step_args["HyperParameters"]: del expected_step_args["HyperParameters"]["sagemaker_s3_output"] del step_def["Arguments"]["HyperParameters"]["sagemaker_s3_output"] @@ -555,7 +546,6 @@ def test_training_step_with_framework_estimator( step_def2 = json.loads(pipeline.definition())["Steps"][0] del 
step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] del step_def2["Arguments"]["OutputDataConfig"]["S3OutputPath"] - del step_def2["Arguments"]["ProfilerRuleConfigurations"] if "sagemaker_s3_output" in step_def2["Arguments"]["HyperParameters"]: del step_def2["Arguments"]["HyperParameters"]["sagemaker_s3_output"] assert step_def == step_def2 @@ -608,10 +598,6 @@ def test_training_step_with_framework_estimator_local_code( del expected_step_args["OutputDataConfig"]["S3OutputPath"] del step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"] - # delete profiler rule configurations because of timestamp collision - del step_def["Arguments"]["ProfilerRuleConfigurations"] - del expected_step_args["ProfilerRuleConfigurations"] - if "sagemaker_s3_output" in step_args["HyperParameters"]: del expected_step_args["HyperParameters"]["sagemaker_s3_output"] del step_def["Arguments"]["HyperParameters"]["sagemaker_s3_output"] @@ -626,7 +612,6 @@ def test_training_step_with_framework_estimator_local_code( step_def2 = json.loads(pipeline.definition())["Steps"][0] del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] del step_def2["Arguments"]["OutputDataConfig"]["S3OutputPath"] - del step_def2["Arguments"]["ProfilerRuleConfigurations"] if "sagemaker_s3_output" in step_def2["Arguments"]["HyperParameters"]: del step_def2["Arguments"]["HyperParameters"]["sagemaker_s3_output"] assert step_def == step_def2 @@ -701,10 +686,6 @@ def test_training_step_with_algorithm_base(algo_estimator, training_input, pipel del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] - # delete profiler rule configurations because of timestamp collision - del step_def["Arguments"]["ProfilerRuleConfigurations"] - del step_args["ProfilerRuleConfigurations"] - assert step_def == { "Name": "MyTrainingStep", "Type": "Training", @@ -714,7 +695,6 @@ def test_training_step_with_algorithm_base(algo_estimator, training_input, pipel # test idempotency step_def2 = json.loads(pipeline.definition())["Steps"][0] del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] - del step_def2["Arguments"]["ProfilerRuleConfigurations"] assert step_def == step_def2 @@ -789,10 +769,6 @@ def test_training_step_with_algorithm_base_local_code( del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] - # delete profiler rule configurations because of timestamp collision - del step_def["Arguments"]["ProfilerRuleConfigurations"] - del step_args["ProfilerRuleConfigurations"] - assert step_def == { "Name": "MyTrainingStep", "Type": "Training", @@ -802,7 +778,6 @@ def test_training_step_with_algorithm_base_local_code( # test idempotency step_def2 = json.loads(pipeline.definition())["Steps"][0] del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] - del step_def2["Arguments"]["ProfilerRuleConfigurations"] assert step_def == step_def2 diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index c8d86c5866..d1b81f3148 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -107,6 +107,7 @@ def test_repack_model_step(estimator): } ], "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"}, + "ProfilerConfig": {"DisableProfiler": True}, 
"ResourceConfig": { "InstanceCount": 1, "InstanceType": "ml.m5.large", @@ -188,6 +189,7 @@ def test_repack_model_step_with_source_dir(estimator, source_dir): } ], "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"}, + "ProfilerConfig": {"DisableProfiler": True}, "ResourceConfig": { "InstanceCount": 1, "InstanceType": "ml.m5.large", diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 7cc973440f..eca4a9bf80 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -150,14 +150,8 @@ def _create_train_job(version, py_version): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 868da88d78..8b771f9184 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -25,7 +25,10 @@ from botocore.exceptions import ClientError from mock import ANY, MagicMock, Mock, patch, PropertyMock from sagemaker.huggingface.estimator import HuggingFace -from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME +from sagemaker.jumpstart.constants import ( + JUMPSTART_BUCKET_NAME_SET, + JUMPSTART_RESOURCE_BASE_NAME, +) from sagemaker.jumpstart.enums import JumpStartTag import sagemaker.local @@ -106,7 +109,11 @@ "training_steps": "100", }, "RoleArn": "arn:aws:iam::366:role/SageMakerRole", - "ResourceConfig": {"VolumeSizeInGB": 30, "InstanceCount": 1, "InstanceType": "ml.c4.xlarge"}, + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, "EnableNetworkIsolation": False, "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, "TrainingJobName": "neo", @@ -143,7 +150,10 @@ } MOCKED_S3_URI = "s3://mocked_s3_uri_from_source_dir" MOCKED_PIPELINE_CONFIG = _PipelineConfig( - "test-pipeline", "test-training-step", "code-hash-0123456789", "config-hash-0123456789" + "test-pipeline", + "test-training-step", + "code-hash-0123456789", + "config-hash-0123456789", ) @@ -247,7 +257,9 @@ def pipeline_session(): session_mock.resource.return_value = resource_mock session_mock.client.return_value = client_mock return PipelineSession( - boto_session=session_mock, sagemaker_client=client_mock, default_bucket=BUCKET_NAME + boto_session=session_mock, + sagemaker_client=client_mock, + default_bucket=BUCKET_NAME, ) @@ -322,7 +334,11 @@ def test_framework_all_init_args(sagemaker_session): }, "metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}], "encrypt_inter_container_traffic": True, - "environment": {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}, + "environment": { + "env_key1": "env_val1", + "env_key2": "env_val2", + "env_key3": "env_val3", + }, "experiment_config": None, "checkpoint_s3_uri": "s3://bucket/checkpoint", "checkpoint_local_path": "file://local/checkpoint", @@ -379,7 +395,8 @@ def test_framework_with_debugger_and_built_in_rule(sagemaker_session): rule_parameters={"threshold": "120", "stop_training_on_fire": "True"}, collections_to_save=[ CollectionConfig( - name="losses", parameters={"train.save_interval": "50", "eval.save_interval": "10"} + 
name="losses", + parameters={"train.save_interval": "50", "eval.save_interval": "10"}, ) ], ) @@ -405,18 +422,23 @@ def test_framework_with_debugger_and_built_in_rule(sagemaker_session): "CollectionConfigurations": [ { "CollectionName": "losses", - "CollectionParameters": {"train.save_interval": "50", "eval.save_interval": "10"}, + "CollectionParameters": { + "train.save_interval": "50", + "eval.save_interval": "10", + }, } ], } assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } def test_framework_with_debugger_and_custom_rule(sagemaker_session): hook_config = DebuggerHookConfig( - s3_output_path="s3://output", collection_configs=[CollectionConfig(name="weights")] + s3_output_path="s3://output", + collection_configs=[CollectionConfig(name="weights")], ) debugger_custom_rule = Rule.custom( name="CustomRule", @@ -536,7 +558,8 @@ def test_framework_with_debugger_rule_and_multiple_actions(sagemaker_session): def test_framework_with_only_debugger_hook_config(sagemaker_session): hook_config = DebuggerHookConfig( - s3_output_path="s3://output", collection_configs=[CollectionConfig(name="weights")] + s3_output_path="s3://output", + collection_configs=[CollectionConfig(name="weights")], ) f = DummyFramework( entry_point=SCRIPT_PATH, @@ -574,15 +597,9 @@ def test_framework_without_debugger_and_profiler(time, sagemaker_session): } assert "debugger_rule_configs" not in args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] def test_framework_with_debugger_and_profiler_rules(sagemaker_session): @@ -591,7 +608,8 @@ def test_framework_with_debugger_and_profiler_rules(sagemaker_session): rule_parameters={"threshold": "120", "stop_training_on_fire": "True"}, collections_to_save=[ CollectionConfig( - name="losses", parameters={"train.save_interval": "50", "eval.save_interval": "10"} + name="losses", + parameters={"train.save_interval": "50", "eval.save_interval": "10"}, ) ], ) @@ -639,18 +657,25 @@ def test_framework_with_debugger_and_profiler_rules(sagemaker_session): "CollectionConfigurations": [ { "CollectionName": "losses", - "CollectionParameters": {"train.save_interval": "50", "eval.save_interval": "10"}, + "CollectionParameters": { + "train.save_interval": "50", + "eval.save_interval": "10", + }, } ], } assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } assert args["profiler_rule_configs"] == [ { "RuleConfigurationName": "CustomProfilerReportRule", "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport", "CPUBottleneck_threshold": "90"}, + "RuleParameters": { + "rule_to_invoke": "ProfilerReport", + "CPUBottleneck_threshold": "90", + }, }, { "InstanceType": "c4.4xlarge", @@ -679,6 +704,7 @@ def test_framework_with_only_profiler_rule_specified(sagemaker_session): sagemaker_session.train.assert_called_once() _, args = sagemaker_session.train.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } assert args["profiler_rule_configs"] == [ @@ -711,16 +737,10 @@ def 
test_framework_with_profiler_config_without_s3_output_path(time, sagemaker_s sagemaker_session.train.assert_called_once() _, args = sagemaker_session.train.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), "ProfilingIntervalInMilliseconds": 1000, } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] @pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS) @@ -745,7 +765,9 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region): f.fit("s3://mydata") sms.train.assert_called_once() _, args = sms.train.call_args - assert args.get("profiler_config") is None + # assert args.get("profiler_config") == {"DisableProfiler": True} + # temporarily check if "DisableProfiler" flag is true until s3_output is changed to optional in service + assert args.get("profiler_config")["DisableProfiler"] is True assert args.get("profiler_rule_configs") is None @@ -865,7 +887,10 @@ def test_framework_with_profiler_config_and_profiler_disabled(sagemaker_session) disable_profiler=True, ) f.fit("s3://mydata") - assert "profiler_config cannot be set when disable_profiler is True." in str(error) + # assert "profiler_config cannot be set when disable_profiler is True." in str(error) + assert "profiler_config.disable_profiler cannot be False when disable_profiler is True." in str( + error + ) def test_framework_with_profiler_rule_and_profiler_disabled(sagemaker_session): @@ -927,15 +952,9 @@ def test_framework_with_enabling_default_profiling( sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] @patch("time.time", return_value=TIME) @@ -960,15 +979,9 @@ def test_framework_with_enabling_default_profiling_with_existed_s3_output_path( sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://custom/", } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] def test_framework_with_disabling_profiling_when_profiler_is_already_disabled( @@ -1001,7 +1014,9 @@ def test_framework_with_disabling_profiling(sagemaker_session, training_job_desc f.disable_profiling() sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args - assert args["profiler_config"] == {"DisableProfiler": True} + # assert args["profiler_config"] == {"DisableProfiler": True} + # temporarily check if "DisableProfiler" flag is true until s3_output is changed to optional in service + assert args.get("profiler_config")["DisableProfiler"] is True def 
test_framework_with_update_profiler_when_no_training_job(sagemaker_session): @@ -1058,6 +1073,7 @@ def test_framework_with_update_profiler_config(sagemaker_session): sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "ProfilingIntervalInMilliseconds": 1000, } assert "profiler_rule_configs" not in args @@ -1086,7 +1102,7 @@ def test_framework_with_update_profiler_report_rule(sagemaker_session): "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, } ] - assert "profiler_config" not in args + assert args["profiler_config"]["DisableProfiler"] is False def test_framework_with_disable_framework_metrics(sagemaker_session): @@ -1101,11 +1117,16 @@ def test_framework_with_disable_framework_metrics(sagemaker_session): f.update_profiler(disable_framework_metrics=True) sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args - assert args["profiler_config"] == {"ProfilingParameters": {}} + assert args["profiler_config"] == { + "DisableProfiler": False, + "ProfilingParameters": {}, + } assert "profiler_rule_configs" not in args -def test_framework_with_disable_framework_metrics_and_update_system_metrics(sagemaker_session): +def test_framework_with_disable_framework_metrics_and_update_system_metrics( + sagemaker_session, +): f = DummyFramework( entry_point=SCRIPT_PATH, role=ROLE, @@ -1118,13 +1139,16 @@ def test_framework_with_disable_framework_metrics_and_update_system_metrics(sage sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "ProfilingIntervalInMilliseconds": 1000, "ProfilingParameters": {}, } assert "profiler_rule_configs" not in args -def test_framework_with_disable_framework_metrics_and_update_framework_params(sagemaker_session): +def test_framework_with_disable_framework_metrics_and_update_framework_params( + sagemaker_session, +): with pytest.raises(ValueError) as error: f = DummyFramework( entry_point=SCRIPT_PATH, @@ -1160,7 +1184,10 @@ def test_framework_with_update_profiler_config_and_profiler_rule(sagemaker_sessi f.update_profiler(rules=[profiler_custom_rule], system_monitor_interval_millis=1000) sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args - assert args["profiler_config"] == {"ProfilingIntervalInMilliseconds": 1000} + assert args["profiler_config"] == { + "DisableProfiler": False, + "ProfilingIntervalInMilliseconds": 1000, + } assert args["profiler_rule_configs"] == [ { "InstanceType": "c4.4xlarge", @@ -1659,7 +1686,10 @@ def test_start_new_wait_called(strftime, sagemaker_session): def test_attach_framework(sagemaker_session, training_job_description): - training_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + training_job_description["VpcConfig"] = { + "Subnets": ["foo"], + "SecurityGroupIds": ["bar"], + } training_job_description["EnableNetworkIsolation"] = True framework_estimator = DummyFramework.attach( @@ -1753,7 +1783,8 @@ def test_attach_framework_with_inter_container_traffic_encryption_flag( def test_attach_framework_base_from_generated_name(sagemaker_session, training_job_description): base_job_name = "neo" framework_estimator = DummyFramework.attach( - training_job_name=utils.name_from_base("neo"), sagemaker_session=sagemaker_session + 
training_job_name=utils.name_from_base("neo"), + sagemaker_session=sagemaker_session, ) assert framework_estimator.base_job_name == base_job_name @@ -1948,7 +1979,8 @@ def test_git_support_bad_repo_url_format(sagemaker_session): @patch( "sagemaker.git_utils.git_clone_repo", side_effect=subprocess.CalledProcessError( - returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir" + returncode=1, + cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir", ), ) def test_git_support_git_clone_fail(git_clone_repo, sagemaker_session): @@ -1973,7 +2005,11 @@ def test_git_support_git_clone_fail(git_clone_repo, sagemaker_session): ), ) def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session): - git_config = {"repo": GIT_REPO, "branch": "branch-that-does-not-exist", "commit": COMMIT} + git_config = { + "repo": GIT_REPO, + "branch": "branch-that-does-not-exist", + "commit": COMMIT, + } fw = DummyFramework( entry_point="entry_point", git_config=git_config, @@ -1994,7 +2030,11 @@ def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session): ), ) def test_git_support_commit_not_exist(git_clone_repo, sagemaker_session): - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": "commit-sha-that-does-not-exist"} + git_config = { + "repo": GIT_REPO, + "branch": BRANCH, + "commit": "commit-sha-that-does-not-exist", + } fw = DummyFramework( entry_point="entry_point", git_config=git_config, @@ -2137,7 +2177,11 @@ def test_git_support_with_token_2fa(git_clone_repo, sagemaker_session): }, ) def test_git_support_ssh_no_passphrase_needed(git_clone_repo, sagemaker_session): - git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + git_config = { + "repo": PRIVATE_GIT_REPO_SSH, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + } entry_point = "entry_point" fw = DummyFramework( entry_point=entry_point, @@ -2159,7 +2203,11 @@ def test_git_support_ssh_no_passphrase_needed(git_clone_repo, sagemaker_session) ), ) def test_git_support_ssh_passphrase_required(git_clone_repo, sagemaker_session): - git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + git_config = { + "repo": PRIVATE_GIT_REPO_SSH, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + } entry_point = "entry_point" fw = DummyFramework( entry_point=entry_point, @@ -2457,7 +2505,9 @@ def test_estimator_transformer_creation_with_optional_params(create_model, sagem ) create_model.assert_called_with( - vpc_config_override=new_vpc_config, model_kms_key=kms_key, enable_network_isolation=True + vpc_config_override=new_vpc_config, + model_kms_key=kms_key, + enable_network_isolation=True, ) assert transformer.strategy == strategy @@ -2635,14 +2685,7 @@ def test_unsupported_type_in_dict(): "input_config": None, "input_mode": "File", "output_config": {"S3OutputPath": OUTPUT_PATH}, - "profiler_config": {"S3OutputPath": OUTPUT_PATH}, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], + "profiler_config": {"DisableProfiler": False, "S3OutputPath": OUTPUT_PATH}, "resource_config": { "InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE, @@ -2749,7 +2792,11 @@ def test_fit_deploy_tags_in_estimator(name_from_base, sagemaker_session): @patch("sagemaker.estimator.name_from_base") def 
test_fit_deploy_tags(name_from_base, sagemaker_session): estimator = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) estimator.fit() @@ -3197,7 +3244,10 @@ def test_generic_training_job_analytics(sagemaker_session): "TrainingInputMode": "File", "MetricDefinitions": [ {"Name": "train:loss", "Regex": "train_loss=([0-9]+\\.[0-9]+)"}, - {"Name": "validation:loss", "Regex": "valid_loss=([0-9]+\\.[0-9]+)"}, + { + "Name": "validation:loss", + "Regex": "valid_loss=([0-9]+\\.[0-9]+)", + }, ], }, }, @@ -3228,7 +3278,11 @@ def test_generic_create_model_vpc_config_override(sagemaker_session): vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]} e = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) e.fit({"train": "s3://bucket/training-prefix"}) assert e.get_vpc_config() is None @@ -3254,7 +3308,11 @@ def test_generic_deploy_vpc_config_override(sagemaker_session): vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]} e = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) e.fit({"train": "s3://bucket/training-prefix"}) e.deploy(INSTANCE_COUNT, INSTANCE_TYPE) @@ -3274,7 +3332,11 @@ def test_generic_deploy_vpc_config_override(sagemaker_session): def test_generic_deploy_accelerator_type(sagemaker_session): e = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) e.fit({"train": "s3://bucket/training-prefix"}) e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) @@ -3617,7 +3679,13 @@ def test_file_output_path_not_supported_outside_local_mode(session_class): session_class.return_value = session with pytest.raises(RuntimeError): - Estimator(IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path="file:///tmp/model") + Estimator( + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path="file:///tmp/model", + ) def test_prepare_init_params_from_job_description_with_image_training_job(): @@ -3726,7 +3794,10 @@ def test_prepare_for_training_with_name_based_on_image(sagemaker_session): @patch("sagemaker.algorithm.AlgorithmEstimator.validate_train_spec", Mock()) -@patch("sagemaker.algorithm.AlgorithmEstimator._parse_hyperparameters", Mock(return_value={})) +@patch( + "sagemaker.algorithm.AlgorithmEstimator._parse_hyperparameters", + Mock(return_value={}), +) def test_prepare_for_training_with_name_based_on_algorithm(sagemaker_session): estimator = AlgorithmEstimator( algorithm_arn="arn:aws:sagemaker:us-west-2:1234:algorithm/scikit-decision-trees-1542410022", @@ -3741,7 +3812,9 @@ def test_prepare_for_training_with_name_based_on_algorithm(sagemaker_session): @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -def test_prepare_for_training_with_pipeline_name_in_s3_path_no_source_dir(pipeline_session): +def test_prepare_for_training_with_pipeline_name_in_s3_path_no_source_dir( + pipeline_session, +): # script_uri is NOT provided -> use new cache key behavior that builds path using pipeline name + code_hash image_uri = 
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38" model_uri = "s3://someprefix2/models/model.tar.gz" @@ -4211,7 +4284,10 @@ def test_script_mode_estimator_tags_jumpstart_models_with_no_estimator_js_tags( @patch("sagemaker.model.Model._upload_code") @patch("sagemaker.utils.repack_model") def test_all_framework_estimators_add_jumpstart_tags( - patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_repack_model, + patched_upload_code, + patched_tar_and_upload_dir, + sagemaker_session, ): sagemaker_session.boto_region_name = REGION @@ -4240,13 +4316,20 @@ def test_all_framework_estimators_add_jumpstart_tags( "transformers_version": "4.6.1", "instance_type": "ml.p2.xlarge", }, - MXNet: {"framework_version": "1.7.0", "py_version": "py3", "instance_type": "ml.p2.xlarge"}, + MXNet: { + "framework_version": "1.7.0", + "py_version": "py3", + "instance_type": "ml.p2.xlarge", + }, SKLearn: {"framework_version": "0.23-1", "instance_type": "ml.m2.xlarge"}, XGBoost: {"framework_version": "1.3-1", "instance_type": "ml.m2.xlarge"}, } jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz" - for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items(): + for ( + framework_estimator_class, + kwargs, + ) in framework_estimator_classes_to_kwargs.items(): estimator = framework_estimator_class( entry_point=ENTRY_POINT, role=ROLE, @@ -4362,7 +4445,10 @@ def test_script_mode_estimator_uses_jumpstart_base_name_with_js_models( @patch("sagemaker.model.Model._upload_code") @patch("sagemaker.utils.repack_model") def test_all_framework_estimators_add_jumpstart_base_name( - patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_repack_model, + patched_upload_code, + patched_tar_and_upload_dir, + sagemaker_session, ): sagemaker_session.boto_region_name = REGION @@ -4391,13 +4477,20 @@ def test_all_framework_estimators_add_jumpstart_base_name( "transformers_version": "4.6.1", "instance_type": "ml.p2.xlarge", }, - MXNet: {"framework_version": "1.7.0", "py_version": "py3", "instance_type": "ml.p2.xlarge"}, + MXNet: { + "framework_version": "1.7.0", + "py_version": "py3", + "instance_type": "ml.p2.xlarge", + }, SKLearn: {"framework_version": "0.23-1", "instance_type": "ml.m2.xlarge"}, XGBoost: {"framework_version": "1.3-1", "instance_type": "ml.m2.xlarge"}, } jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz" - for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items(): + for ( + framework_estimator_class, + kwargs, + ) in framework_estimator_classes_to_kwargs.items(): estimator = framework_estimator_class( entry_point=ENTRY_POINT, role=ROLE, diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 9ba3e17ff3..f12d8e160f 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -160,14 +160,8 @@ def _get_train_args(job_name): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.0-cpu-py3", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], 
"profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index c8aad13774..5691834c3a 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -158,14 +158,8 @@ def _create_train_job(version, py_version): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 2035636e76..0c0a9c6d64 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -153,14 +153,8 @@ def _create_train_job(toolkit, toolkit_version, framework): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, "retry_strategy": None, diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index c3e984e0b7..430cb484b4 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -140,14 +140,8 @@ def _create_train_job(version): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index d58c4992cd..87a853d5d0 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -154,14 +154,8 @@ def _create_train_job(version, instance_count=1, instance_type="ml.c4.4xlarge"): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } From 097e82947590cc9b2c68d01f155c4bb486e526b8 Mon Sep 17 00:00:00 2001 From: Shreya Pandit Date: Thu, 15 Dec 2022 12:39:35 -0800 Subject: [PATCH 063/526] Use Async Inference Config when available for endpoint update (#3124) Co-authored-by: Navin Soni --- src/sagemaker/session.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index ce6a3b99cd..602cd1fd9f 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3324,6 +3324,11 @@ def create_endpoint_config_from_existing( if request_data_capture_config_dict is not None: request["DataCaptureConfig"] = request_data_capture_config_dict + if 
existing_endpoint_config_desc.get("AsyncInferenceConfig") is not None: + request["AsyncInferenceConfig"] = existing_endpoint_config_desc.get( + "AsyncInferenceConfig", None + ) + self.sagemaker_client.create_endpoint_config(**request) def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True): From be6111b011b1045e68b18ec1bc84c0dbd9f8fb6a Mon Sep 17 00:00:00 2001 From: Carolyn Wang <32006339+carolynwang@users.noreply.github.com> Date: Thu, 15 Dec 2022 15:43:11 -0500 Subject: [PATCH 064/526] feature: Add p4de to smddp supported instance types (#3483) --- src/sagemaker/fw_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 3ba918ea2c..a91aff1761 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -80,6 +80,7 @@ "ml.p3.16xlarge", "ml.p3dn.24xlarge", "ml.p4d.24xlarge", + "ml.p4de.24xlarge", "local_gpu", ) SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = { From a0258bbaa715295ae15f2bf4c59cbe6eed054b07 Mon Sep 17 00:00:00 2001 From: Miyoung Date: Thu, 15 Dec 2022 13:08:57 -0800 Subject: [PATCH 065/526] documentation: smdistributed libraries release notes (#3543) --- doc/api/training/sdp_versions/latest.rst | 4 +- .../smd_data_parallel_change_log.rst | 50 +++++++++++++--- .../smd_model_parallel_change_log.rst | 60 ++++++++++++++++--- doc/api/training/smp_versions/latest.rst | 4 +- 4 files changed, 100 insertions(+), 18 deletions(-) diff --git a/doc/api/training/sdp_versions/latest.rst b/doc/api/training/sdp_versions/latest.rst index c3fcc5f78e..461f58998f 100644 --- a/doc/api/training/sdp_versions/latest.rst +++ b/doc/api/training/sdp_versions/latest.rst @@ -26,8 +26,8 @@ depending on the version of the library you use. `_ for more information. -Version 1.4.0, 1.4.1, 1.5.0 (Latest) -==================================== +Version 1.4.0, 1.4.1, 1.5.0, 1.6.0 (Latest) +=========================================== .. toctree:: :maxdepth: 1 diff --git a/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst b/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst index 05eb7220e0..8ff7fabf1c 100644 --- a/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst +++ b/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst @@ -7,9 +7,51 @@ Release Notes New features, bug fixes, and improvements are regularly made to the SageMaker distributed data parallel library. -SageMaker Distributed Data Parallel 1.5.0 Release Notes +SageMaker Distributed Data Parallel 1.6.0 Release Notes ======================================================= +*Date: Dec. 15. 2022* + +**New Features** + +* New optimized SMDDP AllGather collective to complement the sharded data parallelism technique + in the SageMaker model parallelism library. For more information, see `Sharded data parallelism with SMDDP Collectives + `_ + in the *Amazon SageMaker Developer Guide*. +* Added support for Amazon EC2 ``ml.p4de.24xlarge`` instances. You can run data parallel training jobs + on ``ml.p4de.24xlarge`` instances with the SageMaker data parallelism library’s AllReduce collective. + +**Improvements** + +* General performance improvements of the SMDDP AllReduce collective communication operation. + +**Migration to AWS Deep Learning Containers** + +This version passed benchmark testing and is migrated to the following AWS Deep Learning Containers (DLC): + +- SageMaker training container for PyTorch v1.12.1 + + .. 
code:: + + 763104351884.dkr.ecr..amazonaws.com/pytorch-training:1.12.1-gpu-py38-cu113-ubuntu20.04-sagemaker + + +Binary file of this version of the library for `custom container +`_ users: + + .. code:: + + https://smdataparallel.s3.amazonaws.com/binary/pytorch/1.12.1/cu113/2022-12-05/smdistributed_dataparallel-1.6.0-cp38-cp38-linux_x86_64.whl + + +---- + +Release History +=============== + +SageMaker Distributed Data Parallel 1.5.0 Release Notes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + *Date: Jul. 26. 2022* **Currency Updates** @@ -38,12 +80,6 @@ Binary file of this version of the library for `custom container https://smdataparallel.s3.amazonaws.com/binary/pytorch/1.12.0/cu113/2022-07-01/smdistributed_dataparallel-1.5.0-cp38-cp38-linux_x86_64.whl - ----- - -Release History -=============== - SageMaker Distributed Data Parallel 1.4.1 Release Notes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst b/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst index 6f89fa45a5..92ccc8c407 100644 --- a/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst +++ b/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst @@ -6,9 +6,60 @@ New features, bug fixes, and improvements are regularly made to the SageMaker distributed model parallel library. -SageMaker Distributed Model Parallel 1.11.0 Release Notes +SageMaker Distributed Model Parallel 1.13.0 Release Notes ========================================================= +*Date: Dec. 15. 2022* + +**New Features** + +* Sharded data parallelism now supports a new backend for collectives called *SMDDP Collectives*. + For supported scenarios, SMDDP Collectives are on by default for the AllGather operation. + For more information, see + `Sharded data parallelism with SMDDP Collectives + `_ + in the *Amazon SageMaker Developer Guide*. +* Introduced FlashAttention for DistributedTransformer to improve memory usage and computational + performance of models such as GPT2, GPTNeo, GPTJ, GPTNeoX, BERT, and RoBERTa. + +**Bug Fixes** + +* Fixed initialization of ``lm_head`` in DistributedTransformer to use a provided range + for initialization, when weights are not tied with the embeddings. + +**Improvements** + +* When a module has no parameters, we have introduced an optimization to execute + such a module on the same rank as its parent during pipeline parallelism. + +**Migration to AWS Deep Learning Containers** + +This version passed benchmark testing and is migrated to the following AWS Deep Learning Containers (DLC): + +- SageMaker training container for PyTorch v1.12.1 + + .. code:: + + 763104351884.dkr.ecr..amazonaws.com/pytorch-training:1.12.1-gpu-py38-cu113-ubuntu20.04-sagemaker + + +Binary file of this version of the library for `custom container +`_ users: + +- For PyTorch 1.12.0 + + .. code:: + + https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.12.1/build-artifacts/2022-12-08-21-34/smdistributed_modelparallel-1.13.0-cp38-cp38-linux_x86_64.whl + +---- + +Release History +=============== + +SageMaker Distributed Model Parallel 1.11.0 Release Notes +--------------------------------------------------------- + *Date: August. 17. 2022* **New Features** @@ -41,12 +92,7 @@ Binary file of this version of the library for `custom container .. 
code:: - https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.12.0/build-artifacts/2022-08-12-16-58/smdistributed_modelparallel-1.11.0-cp38-cp38-linux_x86_64.whl - ----- - -Release History -=============== + https://sagemaker-distribu SageMaker Distributed Model Parallel 1.10.1 Release Notes --------------------------------------------------------- diff --git a/doc/api/training/smp_versions/latest.rst b/doc/api/training/smp_versions/latest.rst index 1a2032c9ed..1eb358b2a3 100644 --- a/doc/api/training/smp_versions/latest.rst +++ b/doc/api/training/smp_versions/latest.rst @@ -10,8 +10,8 @@ depending on which version of the library you need to use. To use the library, reference the **Common API** documentation alongside the framework specific API documentation. -Version 1.11.0 (Latest) -=========================================== +Version 1.11.0, 1.13.0 (Latest) +=============================== To use the library, reference the Common API documentation alongside the framework specific API documentation. From 442227bdfcd852e07f0574dd94ad0b6614b12a08 Mon Sep 17 00:00:00 2001 From: Md Mizanur Rahman <105268921+mizanfiu@users.noreply.github.com> Date: Thu, 15 Dec 2022 13:22:09 -0800 Subject: [PATCH 066/526] feature: Doc update for TableFormatEnum (#3542) * Updated doc for table format Enum --- doc/api/prep_data/feature_store.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/api/prep_data/feature_store.rst b/doc/api/prep_data/feature_store.rst index 0e9bf25586..838558c0a4 100644 --- a/doc/api/prep_data/feature_store.rst +++ b/doc/api/prep_data/feature_store.rst @@ -73,6 +73,10 @@ Inputs :members: :show-inheritance: +.. autoclass:: sagemaker.feature_store.inputs.TableFormatEnum + :members: + :show-inheritance: + Dataset Builder *************** From 146f6bbcc5ddec990e90fba6fcd4548781b7d994 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 16 Dec 2022 00:23:36 +0000 Subject: [PATCH 067/526] prepare release v2.124.0 --- CHANGELOG.md | 17 +++++++++++++++++ VERSION | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a05b64c96f..e5cd9826ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ # Changelog +## v2.124.0 (2022-12-16) + +### Features + + * Doc update for TableFormatEnum + * Add p4de to smddp supported instance types + * Add disable_profiler field in config and propagate changes + * Added doc update for dataset builder + +### Bug Fixes and Other Changes + + * Use Async Inference Config when available for endpoint update + +### Documentation Changes + + * smdistributed libraries release notes + ## v2.123.0 (2022-12-15) ### Features diff --git a/VERSION b/VERSION index ea5085760e..67d5c2730e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.123.1.dev0 +2.124.0 From e07f94414385f6b513e249a61ab64b1664d49b42 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 16 Dec 2022 00:23:37 +0000 Subject: [PATCH 068/526] update development version to v2.124.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 67d5c2730e..97d160799c 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.124.0 +2.124.1.dev0 From d1f965b5dc33526c93547c8380c77ba6ae0c32d5 Mon Sep 17 00:00:00 2001 From: "jose-juan.pena-gomez@capgemini.com" Date: Fri, 16 Dec 2022 11:09:28 +0100 Subject: [PATCH 069/526] fix: LF in all files + flake8 --- tests/integ/test_feature_store.py | 9 +- .../feature_store/test_feature_store.py | 139 ------------------ 2 files changed, 2 insertions(+), 146 deletions(-) diff 
--git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index 3922de6350..e0230ca1e5 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -12,9 +12,9 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import +import datetime import json import time -import datetime from contextlib import contextmanager import boto3 @@ -23,18 +23,13 @@ import pytest from pandas import DataFrame +from sagemaker.feature_group_utils import get_feature_group_as_dataframe from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition from sagemaker.feature_store.feature_group import FeatureGroup -from sagemaker.feature_store.inputs import ( - FeatureValue, - FeatureParameter, - TableFormatEnum, -) from sagemaker.feature_store.feature_store import FeatureStore from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter, TableFormatEnum from sagemaker.session import get_execution_role, Session from tests.integ.timeout import timeout -from sagemaker.feature_group_utils import get_feature_group_as_dataframe BUCKET_POLICY = { "Version": "2012-10-17", diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py index c3ae6a2400..073daca9ea 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_store.py +++ b/tests/unit/sagemaker/feature_store/test_feature_store.py @@ -10,8 +10,6 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -# language governing permissions and limitations under the License. -"""Test for Feature Store""" from __future__ import absolute_import import datetime @@ -24,22 +22,6 @@ DATAFRAME = pd.DataFrame({"feature_1": [420, 380, 390], "feature_2": [50, 40, 45]}) -class PicklableMock(Mock): - """Mock class use for tests""" - - def __reduce__(self): - return (Mock, ()) - - -@pytest.fixture -def role_arn(): - return "arn:role" - - -@pytest.fixture -def s3_uri(): - return "s3://some/uri" - @pytest.fixture def sagemaker_session_mock(): @@ -152,124 +134,3 @@ def test_list_feature_groups_with_all_filters(sagemaker_session_mock): max_results=50, next_token="token", ) - - try: - manager.run(df) - except Exception as e: - assert "The config profile (non_exist) could not be found" in str(e) - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", - PicklableMock(return_value=[1]), -) -def test_ingestion_manager_run_multi_process_failure(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=None, - max_workers=2, - max_processes=2, - ) - - with pytest.raises(IngestionError) as error: - manager.run(df) - - assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) - assert error.value.failed_rows == [1, 1, 1, 1] - assert manager.failed_rows == [1, 1, 1, 1] - - -@pytest.fixture -def query(sagemaker_session_mock): - return AthenaQuery( - catalog="catalog", - database="database", - table_name="table_name", - sagemaker_session=sagemaker_session_mock, - ) - - -def test_athena_query_run(sagemaker_session_mock, query): - WORKGROUP = "workgroup" - sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"} - query.run( - 
query_string="query", - output_location="s3://some-bucket/some-path", - workgroup=WORKGROUP, - ) - sagemaker_session_mock.start_query_execution.assert_called_with( - catalog="catalog", - database="database", - query_string="query", - output_location="s3://some-bucket/some-path", - kms_key=None, - workgroup=WORKGROUP, - ) - assert "some-bucket" == query._result_bucket - assert "some-path" == query._result_file_prefix - assert "query_id" == query._current_query_execution_id - - -def test_athena_query_wait(sagemaker_session_mock, query): - query._current_query_execution_id = "query_id" - query.wait() - sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id") - - -def test_athena_query_get_query_execution(sagemaker_session_mock, query): - query._current_query_execution_id = "query_id" - query.get_query_execution() - sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id") - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -@patch("pandas.read_csv") -def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "SUCCEEDED"}} - } - query._current_query_execution_id = "query_id" - query._result_bucket = "bucket" - query._result_file_prefix = "prefix" - query.as_dataframe() - sagemaker_session_mock.download_athena_query_result.assert_called_with( - bucket="bucket", - prefix="prefix", - query_execution_id="query_id", - filename="tmp/query_id.csv", - ) - read_csv.assert_called_with("tmp/query_id.csv", delimiter=",") - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_failed(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "FAILED"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert "Failed to execute query query_id" in str(error) - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_queued(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "QUEUED"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert "Current query query_id is still being executed" in str(error) - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_running(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "RUNNING"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert "Current query query_id is still being executed" in str(error) From f7407dfe307176a97baff993c15c9ce3fc7699b8 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 21 Dec 2022 07:24:02 +0000 Subject: [PATCH 070/526] prepare release v2.119.0 --- CHANGELOG.md | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++++ VERSION | 2 +- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 37b3440f69..9d295148aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,63 @@ # Changelog +## v2.119.0 (2022-12-21) + +### Features + + * add RandomSeed to support reproducible HPO + * Doc update for TableFormatEnum + * Add 
p4de to smddp supported instance types + * Add disable_profiler field in config and propagate changes + * Added doc update for dataset builder + * Add support for TF2.9.2 training images + * Add SageMaker Experiment + * Feature Store dataset builder, delete_record, get_record, list_feature_group + * Add OSU region to frameworks for DLC + * Algorithms Region Expansion OSU/DXB + * Add Neo image uri config for Pytorch 1.12 + * Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 + * Update registries with new region account number mappings. + * Add DXB region to frameworks by DLC + * Add Code Owners file + * Added transform with monitoring pipeline step in transformer + * Update TF 2.9 and TF 2.10 inference DLCs + * make estimator accept json file as modelparallel config + * SageMaker Training Compiler does not support p4de instances + * Add support for SparkML v3.3 + +### Bug Fixes and Other Changes + + * Do not specify S3 path for disabled profiler + * Correct SageMaker Clarify API docstrings by changing JSONPath to JMESPath + * Use Async Inference Config when available for endpoint update + * the Hyperband support fix for the HPO + * unpin packaging version + * Remove content type image/jpg from analysis configuration schema + * Update for Tensorflow Serving 2.11 inference DLCs + * Skip Bad Transform Test + * Pop out ModelPackageName from pipeline definition + * Fix failing jumpstart cache unit tests + * FrameworkProcessor S3 uploads + * Add constraints file for apache-airflow + * support idempotency for framework and spark processors + * Fix bug forcing uploaded tar to be named sourcedir + * Update local_requirements.txt PyYAML version + * refactoring : using with statement + * Allow Py 3.7 for MMS Test Docker env + * Return ARM XGB/SKLearn tags if `image_scope` is `inference_graviton` + * Update scipy to 1.7.3 to support M1 development envs + * Fixing type hints for Spark processor that has instance type/count params in reverse order + * Add DeepAR ap-northeast-3 repository. 
+ * Fix AsyncInferenceConfig documentation typo + * fix ml_inf to ml_inf1 in Neo multi-version support + * Fix type annotations + * add neo mvp region accounts + +### Documentation Changes + + * fix the incorrect property reference + * smdistributed libraries release notes + ## v2.125.0 (2022-12-19) ### Features diff --git a/VERSION b/VERSION index 1e80f372b6..23fe2bf317 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.125.1.dev0 +2.119.0 From 047d062638ee18ca28b11c5c88ff83fd42383dbe Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 21 Dec 2022 07:24:03 +0000 Subject: [PATCH 071/526] update development version to v2.119.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 23fe2bf317..dda4128cf2 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.119.0 +2.119.1.dev0 From 48c13e94d449e1e58450bb71555a5f5b0860eb7d Mon Sep 17 00:00:00 2001 From: Alexander Shirkov <10080307+gradientsky@users.noreply.github.com> Date: Wed, 21 Dec 2022 05:42:15 -0800 Subject: [PATCH 072/526] feature: AutoGluon 0.6.1 image_uris (#3544) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: type hint of PySparkProcessor __init__ (#3297) * fix: fix PySparkProcessor __init__ params type (#3354) * fix: Revert "fix: type hint of PySparkProcessor __init__" (#3524) * fix: type hint of PySparkProcessor __init__ (#3297) * fix: fix PySparkProcessor __init__ params type (#3354) * fix: Revert "fix: type hint of PySparkProcessor __init__" (#3524) * feature: AutoGluon 0.6.1 image_uris * feature: AutoGluon 0.6.1 image_uris * Added processors for training images for consistency * Formatting * Formatting Co-authored-by: Kevin Co-authored-by: André Perez Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> --- src/sagemaker/image_uri_config/autogluon.json | 76 ++++++++++++++++++- .../sagemaker/image_uris/test_autogluon.py | 15 +++- 2 files changed, 88 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 590b6e5f82..2fab95eefc 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -4,7 +4,8 @@ "version_aliases": { "0.3": "0.3.2", "0.4": "0.4.3", - "0.5": "0.5.2" + "0.5": "0.5.2", + "0.6": "0.6.1" }, "versions": { "0.3.1": { @@ -101,6 +102,7 @@ "us-west-2": "763104351884" }, "repository": "autogluon-training", + "processors": ["cpu", "gpu"], "py_versions": ["py38"] }, "0.4.2": { @@ -133,6 +135,7 @@ "us-west-2": "763104351884" }, "repository": "autogluon-training", + "processors": ["cpu", "gpu"], "py_versions": ["py38"] }, "0.4.3": { @@ -165,6 +168,7 @@ "us-west-2": "763104351884" }, "repository": "autogluon-training", + "processors": ["cpu", "gpu"], "py_versions": ["py38"] }, "0.5.2": { @@ -197,6 +201,39 @@ "us-west-2": "763104351884" }, "repository": "autogluon-training", + "processors": ["cpu", "gpu"], + "py_versions": ["py38"] + }, + "0.6.1": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + 
"eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "autogluon-training", + "processors": ["cpu", "gpu"], "py_versions": ["py38"] } } @@ -205,7 +242,8 @@ "version_aliases": { "0.3": "0.3.2", "0.4": "0.4.3", - "0.5": "0.5.2" + "0.5": "0.5.2", + "0.6": "0.6.1" }, "versions": { "0.3.1": { @@ -435,6 +473,40 @@ "repository": "autogluon-inference", "processors": ["cpu", "gpu"], "py_versions": ["py38"] + }, + "0.6.1": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "autogluon-inference", + "processors": ["cpu", "gpu"], + "py_versions": ["py38"] } } } diff --git a/tests/unit/sagemaker/image_uris/test_autogluon.py b/tests/unit/sagemaker/image_uris/test_autogluon.py index 7f7aea2850..d4b9690505 100644 --- a/tests/unit/sagemaker/image_uris/test_autogluon.py +++ b/tests/unit/sagemaker/image_uris/test_autogluon.py @@ -37,12 +37,25 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", "us-west-2": "763104351884", } -VERSIONS = ["0.3.1", "0.3.2", "0.4.0", "0.4.2", "0.4.3", "0.3", "0.4", "0.5.2", "0.5"] +VERSIONS = [ + "0.3.1", + "0.3.2", + "0.4.0", + "0.4.2", + "0.4.3", + "0.3", + "0.4", + "0.5.2", + "0.5", + "0.6.1", + "0.6", +] SCOPES = ["training", "inference"] PROCESSORS = ["cpu", "gpu"] From a14ebc39697b6ba9f6506c4456794253b80d63bf Mon Sep 17 00:00:00 2001 From: "jose-juan.pena-gomez@capgemini.com" Date: Wed, 21 Dec 2022 16:14:38 +0100 Subject: [PATCH 073/526] change: merged with latest master --- CHANGELOG.md | 68 ----------------- VERSION | 2 +- ...azon_sagemaker_model_building_pipeline.rst | 8 +- src/sagemaker/clarify.py | 30 ++++---- src/sagemaker/debugger/profiler_config.py | 6 +- src/sagemaker/image_uri_config/autogluon.json | 76 +------------------ .../model_monitor/clarify_model_monitoring.py | 23 +++--- .../model_monitor/model_monitoring.py | 16 ++-- src/sagemaker/session.py | 12 --- src/sagemaker/tuner.py | 18 ----- src/sagemaker/workflow/clarify_check_step.py | 4 +- tests/integ/test_feature_store.py | 6 -- .../sagemaker/image_uris/test_autogluon.py | 15 +--- tests/unit/test_session.py | 6 -- tests/unit/test_tuner.py | 1 - tests/unit/tuner_test_utils.py | 1 - 16 files changed, 43 insertions(+), 249 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 
9d295148aa..e5cd9826ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,73 +1,5 @@ # Changelog -## v2.119.0 (2022-12-21) - -### Features - - * add RandomSeed to support reproducible HPO - * Doc update for TableFormatEnum - * Add p4de to smddp supported instance types - * Add disable_profiler field in config and propagate changes - * Added doc update for dataset builder - * Add support for TF2.9.2 training images - * Add SageMaker Experiment - * Feature Store dataset builder, delete_record, get_record, list_feature_group - * Add OSU region to frameworks for DLC - * Algorithms Region Expansion OSU/DXB - * Add Neo image uri config for Pytorch 1.12 - * Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 - * Update registries with new region account number mappings. - * Add DXB region to frameworks by DLC - * Add Code Owners file - * Added transform with monitoring pipeline step in transformer - * Update TF 2.9 and TF 2.10 inference DLCs - * make estimator accept json file as modelparallel config - * SageMaker Training Compiler does not support p4de instances - * Add support for SparkML v3.3 - -### Bug Fixes and Other Changes - - * Do not specify S3 path for disabled profiler - * Correct SageMaker Clarify API docstrings by changing JSONPath to JMESPath - * Use Async Inference Config when available for endpoint update - * the Hyperband support fix for the HPO - * unpin packaging version - * Remove content type image/jpg from analysis configuration schema - * Update for Tensorflow Serving 2.11 inference DLCs - * Skip Bad Transform Test - * Pop out ModelPackageName from pipeline definition - * Fix failing jumpstart cache unit tests - * FrameworkProcessor S3 uploads - * Add constraints file for apache-airflow - * support idempotency for framework and spark processors - * Fix bug forcing uploaded tar to be named sourcedir - * Update local_requirements.txt PyYAML version - * refactoring : using with statement - * Allow Py 3.7 for MMS Test Docker env - * Return ARM XGB/SKLearn tags if `image_scope` is `inference_graviton` - * Update scipy to 1.7.3 to support M1 development envs - * Fixing type hints for Spark processor that has instance type/count params in reverse order - * Add DeepAR ap-northeast-3 repository. 
- * Fix AsyncInferenceConfig documentation typo - * fix ml_inf to ml_inf1 in Neo multi-version support - * Fix type annotations - * add neo mvp region accounts - -### Documentation Changes - - * fix the incorrect property reference - * smdistributed libraries release notes - -## v2.125.0 (2022-12-19) - -### Features - - * add RandomSeed to support reproducible HPO - -### Bug Fixes and Other Changes - - * Correct SageMaker Clarify API docstrings by changing JSONPath to JMESPath - ## v2.124.0 (2022-12-16) ### Features diff --git a/VERSION b/VERSION index dda4128cf2..97d160799c 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.119.1.dev0 +2.124.1.dev0 diff --git a/doc/amazon_sagemaker_model_building_pipeline.rst b/doc/amazon_sagemaker_model_building_pipeline.rst index e3548f80f2..c1441d6340 100644 --- a/doc/amazon_sagemaker_model_building_pipeline.rst +++ b/doc/amazon_sagemaker_model_building_pipeline.rst @@ -453,7 +453,7 @@ Example: str_outputParam, int_outputParam, bool_outputParam, float_outputParam ], ) - output_ref = step_lambda.properties.Outputs["output1"] + output_ref = step_lambda.OutputParameters["output1"] Where the lambda function with :code:`arn arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda` should output like this: @@ -479,7 +479,7 @@ Note that the output parameters can not be nested. Otherwise, the value will be } } -This will be resolved as :code:`{"output1": "{\"nested_output1\":\"my-output\"}"}` by which if you refer :code:`step_lambda.properties.Outputs["output1"]["nested_output1"]` later, a non-retryable client error will be thrown. +This will be resolved as :code:`{"output1": "{\"nested_output1\":\"my-output\"}"}` by which if you refer :code:`step_lambda.OutputParameters["output1"]["nested_output1"]` later, a non-retryable client error will be thrown. CallbackStep ````````````` @@ -503,7 +503,7 @@ Example: inputs={"arg1": "foo", "arg2": 5, "arg3": param}, outputs=[outputParam], ) - output_ref = step_callback.properties.Outputs["output1] + output_ref = step_callback.OutputParameters["output1] The output parameters cannot be nested. If the values are nested, they will be treated as a single string value. For example, a nested output value of @@ -515,7 +515,7 @@ The output parameters cannot be nested. If the values are nested, they will be t } } -is resolved as :code:`{"output1": "{\"nested_output1\":\"my-output\"}"}`. If you try to refer to :code:`step_callback.properties.Outputs["output1"]["nested_output1"]` this will throw a non-retryable client error. +is resolved as :code:`{"output1": "{\"nested_output1\":\"my-output\"}"}`. If you try to refer to :code:`step_callback.OutputParameters["output1"]["nested_output1"]` this will throw a non-retryable client error. QualityCheckStep diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 18fed12042..f082679401 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -330,11 +330,11 @@ def __init__( s3_analysis_config_output_path (str): S3 prefix to store the analysis config output. If this field is None, then the ``s3_output_path`` will be used to store the ``analysis_config`` output. - label (str): Target attribute of the model required by bias metrics. Specified as - column name or index for CSV dataset or as JMESPath expression for JSONLines. + label (str): Target attribute of the model required by bias metrics. + Specified as column name or index for CSV dataset or as JSONPath for JSONLines. 
*Required parameter* except for when the input dataset does not contain the label. - features (List[str]): JMESPath expression to locate the feature columns for - bias metrics if the dataset format is JSONLines. + features (List[str]): JSONPath for locating the feature columns for bias metrics if the + dataset format is JSONLines. dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV, ``"application/jsonlines"`` for JSONLines, and ``"application/x-parquet"`` for Parquet. @@ -716,11 +716,11 @@ def __init__( ``label_headers=['cat','dog','fish']`` and infer the predicted label to be ``'fish'``. Args: - label (str or int): Index or JMESPath expression to locate the prediction - in the model output. In case, this is a predicted label of the same type - as the label in the dataset, no further arguments need to be specified. - probability (str or int): Index or JMESPath expression to locate the predicted score(s) - in the model output. + label (str or int): Index or JSONPath location in the model output for the prediction. + In case, this is a predicted label of the same type as the label in the dataset, + no further arguments need to be specified. + probability (str or int): Index or JSONPath location in the model output + for the predicted score(s). probability_threshold (float): An optional value for binary prediction tasks in which the model returns a probability, to indicate the threshold to convert the prediction to a boolean value. Default is ``0.5``. @@ -1645,9 +1645,9 @@ def run_explainability( You can request multiple methods at once by passing in a list of `~sagemaker.clarify.ExplainabilityConfig`. model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`): - Index or JMESPath expression to locate the predicted scores in the model output. - This is not required if the model output is a single score. Alternatively, - it can be an instance of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` + Index or JSONPath to locate the predicted scores in the model output. This is not + required if the model output is a single score. Alternatively, it can be an instance + of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` to provide more parameters like ``label_headers``. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. @@ -1774,9 +1774,9 @@ def run_bias_and_explainability( str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig` ): - Index or JMESPath expression to locate the predicted scores in the model output. - This is not required if the model output is a single score. Alternatively, - it can be an instance of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` + Index or JSONPath to locate the predicted scores in the model output. This is not + required if the model output is a single score. Alternatively, it can be an instance + of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` to provide more parameters like ``label_headers``. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. 
diff --git a/src/sagemaker/debugger/profiler_config.py b/src/sagemaker/debugger/profiler_config.py index 3d29e15cdb..561de38b9f 100644 --- a/src/sagemaker/debugger/profiler_config.py +++ b/src/sagemaker/debugger/profiler_config.py @@ -90,11 +90,7 @@ def _to_request_dict(self): """ profiler_config_request = {} - if ( - self.s3_output_path is not None - and self.disable_profiler is not None - and self.disable_profiler is False - ): + if self.s3_output_path is not None: profiler_config_request["S3OutputPath"] = self.s3_output_path profiler_config_request["DisableProfiler"] = self.disable_profiler diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 2fab95eefc..590b6e5f82 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -4,8 +4,7 @@ "version_aliases": { "0.3": "0.3.2", "0.4": "0.4.3", - "0.5": "0.5.2", - "0.6": "0.6.1" + "0.5": "0.5.2" }, "versions": { "0.3.1": { @@ -102,7 +101,6 @@ "us-west-2": "763104351884" }, "repository": "autogluon-training", - "processors": ["cpu", "gpu"], "py_versions": ["py38"] }, "0.4.2": { @@ -135,7 +133,6 @@ "us-west-2": "763104351884" }, "repository": "autogluon-training", - "processors": ["cpu", "gpu"], "py_versions": ["py38"] }, "0.4.3": { @@ -168,7 +165,6 @@ "us-west-2": "763104351884" }, "repository": "autogluon-training", - "processors": ["cpu", "gpu"], "py_versions": ["py38"] }, "0.5.2": { @@ -201,39 +197,6 @@ "us-west-2": "763104351884" }, "repository": "autogluon-training", - "processors": ["cpu", "gpu"], - "py_versions": ["py38"] - }, - "0.6.1": { - "registries": { - "af-south-1": "626614931356", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ca-central-1": "763104351884", - "eu-central-1": "763104351884", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-west-1": "763104351884", - "us-west-2": "763104351884" - }, - "repository": "autogluon-training", - "processors": ["cpu", "gpu"], "py_versions": ["py38"] } } @@ -242,8 +205,7 @@ "version_aliases": { "0.3": "0.3.2", "0.4": "0.4.3", - "0.5": "0.5.2", - "0.6": "0.6.1" + "0.5": "0.5.2" }, "versions": { "0.3.1": { @@ -473,40 +435,6 @@ "repository": "autogluon-inference", "processors": ["cpu", "gpu"], "py_versions": ["py38"] - }, - "0.6.1": { - "registries": { - "af-south-1": "626614931356", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-north-1": "763104351884", - "eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "eu-south-1": "692866216735", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - 
"us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-west-1": "763104351884", - "us-west-2": "763104351884" - }, - "repository": "autogluon-inference", - "processors": ["cpu", "gpu"], - "py_versions": ["py38"] } } } diff --git a/src/sagemaker/model_monitor/clarify_model_monitoring.py b/src/sagemaker/model_monitor/clarify_model_monitoring.py index 030de7c6db..1a788a0d53 100644 --- a/src/sagemaker/model_monitor/clarify_model_monitoring.py +++ b/src/sagemaker/model_monitor/clarify_model_monitoring.py @@ -842,8 +842,8 @@ def __init__(self, bias_config, headers=None, label=None): bias_config (sagemaker.clarify.BiasConfig): Config object related to bias configurations. headers (list[str]): A list of column names in the input dataset. - label (str): Target attribute for the model required by bias metrics. Specified as - column name or index for CSV dataset, or as JMESPath expression for JSONLines. + label (str): Target attribute for the model required by bias metrics. + Specified as column name or index for CSV dataset, or as JSONPath for JSONLines. """ self.analysis_config = bias_config.get_config() if headers is not None: @@ -889,10 +889,9 @@ def suggest_baseline( model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its endpoint to be created. model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`): - Index or JMESPath expression to locate the predicted scores in the model output. - This is not required if the model output is a single score. Alternatively, - it can be an instance of ModelPredictedLabelConfig to provide more parameters - like label_headers. + Index or JSONPath to locate the predicted scores in the model output. This is not + required if the model output is a single score. Alternatively, it can be an instance + of ModelPredictedLabelConfig to provide more parameters like label_headers. wait (bool): Whether the call should wait until the job completes (default: False). logs (bool): Whether to show the logs produced by the job. Only meaningful when wait is True (default: False). @@ -1303,12 +1302,12 @@ def __init__( Args: analysis_config (BiasAnalysisConfig or ExplainabilityAnalysisConfig): analysis config from configurations of the baselining job. - features_attribute (str): JMESPath expression to locate features in predictor request - payload. Only required when predictor content type is JSONlines. - inference_attribute (str): Index, header or JMESPath expression to locate predicted - label in predictor response payload. - probability_attribute (str): Index or JMESPath expression to locate probabilities or - scores in the model output for computing feature attribution. + features_attribute (str): JSONpath to locate features in predictor request payload. + Only required when predictor content type is JSONlines. + inference_attribute (str): Index, header or JSONpath to locate predicted label in + predictor response payload. + probability_attribute (str): Index or JSONpath location in the model output for + probabilities or scores to be used for explainability. probability_threshold_attribute (float): Value to indicate the threshold to select the binary label in the case of binary classification. Default is 0.5. 
""" diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 2f8266a43a..817d951255 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -1061,13 +1061,12 @@ def _generate_env_map( dataset_format (dict): The format of the baseline_dataset. dataset_source_container_path (str): The path to the dataset source. inference_attribute (str): Index or JSONpath to locate predicted label(s). - Only used for ModelQualityMonitor. + Only used for ModelQualityMonitor, ModelBiasMonitor, and ModelExplainabilityMonitor probability_attribute (str or int): Index or JSONpath to locate probabilities. - Only used for ModelQualityMonitor. - ground_truth_attribute (str): Index to locate actual label(s). - Only used for ModelQualityMonitor. + Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor + ground_truth_attribute (str): Index or JSONpath to locate actual label(s). probability_threshold_attribute (float): threshold to convert probabilities to binaries - Only used for ModelQualityMonitor. + Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor Returns: dict: Dictionary of environment keys and values. @@ -2601,13 +2600,10 @@ def suggest_baseline( problem_type (str): The type of problem of this model quality monitoring. Valid values are "Regression", "BinaryClassification", "MulticlassClassification". inference_attribute (str): Index or JSONpath to locate predicted label(s). - Only used for ModelQualityMonitor. probability_attribute (str or int): Index or JSONpath to locate probabilities. - Only used for ModelQualityMonitor. - ground_truth_attribute (str): Index to locate actual label(s). - Only used for ModelQualityMonitor. + ground_truth_attribute (str): Index or JSONpath to locate actual label(s). probability_threshold_attribute (float): threshold to convert probabilities to binaries - Only used for ModelQualityMonitor. + Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor post_analytics_processor_script (str): The path to the record post-analytics processor script. This can be a local path or an S3 uri. output_s3_uri (str): Desired S3 destination Destination of the constraint_violations diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 5404978200..602cd1fd9f 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2146,7 +2146,6 @@ def tune( # noqa: C901 use_spot_instances=False, checkpoint_s3_uri=None, checkpoint_local_path=None, - random_seed=None, ): """Create an Amazon SageMaker hyperparameter tuning job. @@ -2227,9 +2226,6 @@ def tune( # noqa: C901 started. If the path is unset then SageMaker assumes the checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). - random_seed (int): An initial value used to initialize a pseudo-random number generator. - Setting a random seed will make the hyperparameter tuning search strategies to - produce more consistent configurations for the same tuning job. (default: ``None``). 
""" tune_request = { @@ -2242,7 +2238,6 @@ def tune( # noqa: C901 objective_metric_name=objective_metric_name, parameter_ranges=parameter_ranges, early_stopping_type=early_stopping_type, - random_seed=random_seed, strategy_config=strategy_config, ), "TrainingJobDefinition": self._map_training_config( @@ -2399,7 +2394,6 @@ def _map_tuning_config( objective_type=None, objective_metric_name=None, parameter_ranges=None, - random_seed=None, strategy_config=None, ): """Construct tuning job configuration dictionary. @@ -2418,9 +2412,6 @@ def _map_tuning_config( objective_metric_name (str): Name of the metric for evaluating training jobs. parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be one of three types: Continuous, Integer, or Categorical. - random_seed (int): An initial value used to initialize a pseudo-random number generator. - Setting a random seed will make the hyperparameter tuning search strategies to - produce more consistent configurations for the same tuning job. strategy_config (dict): A configuration for the hyperparameter tuning job optimisation strategy. @@ -2439,9 +2430,6 @@ def _map_tuning_config( "TrainingJobEarlyStoppingType": early_stopping_type, } - if random_seed is not None: - tuning_config["RandomSeed"] = random_seed - tuning_objective = cls._map_tuning_objective(objective_type, objective_metric_name) if tuning_objective is not None: tuning_config["HyperParameterTuningJobObjective"] = tuning_objective diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 45a6467c1f..9a694cbec9 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -413,7 +413,6 @@ def __init__( strategy_config: Optional[StrategyConfig] = None, early_stopping_type: Union[str, PipelineVariable] = "Off", estimator_name: Optional[str] = None, - random_seed: Optional[int] = None, ): """Creates a ``HyperparameterTuner`` instance. @@ -471,9 +470,6 @@ def __init__( estimator_name (str): A unique name to identify an estimator within the hyperparameter tuning job, when more than one estimator is used with the same tuning job (default: None). - random_seed (int): An initial value used to initialize a pseudo-random number generator. - Setting a random seed will make the hyperparameter tuning search strategies to - produce more consistent configurations for the same tuning job. 
""" if hyperparameter_ranges is None or len(hyperparameter_ranges) == 0: raise ValueError("Need to specify hyperparameter ranges") @@ -520,7 +516,6 @@ def __init__( self.latest_tuning_job = None self.warm_start_config = warm_start_config self.early_stopping_type = early_stopping_type - self.random_seed = random_seed def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False): """Prepare the tuner instance for tuning (fit).""" @@ -1227,9 +1222,6 @@ def _prepare_init_params_from_job_description(cls, job_details): "base_tuning_job_name": base_from_name(job_details["HyperParameterTuningJobName"]), } - if "RandomSeed" in tuning_config: - params["random_seed"] = tuning_config["RandomSeed"] - if "HyperParameterTuningJobObjective" in tuning_config: params["objective_metric_name"] = tuning_config["HyperParameterTuningJobObjective"][ "MetricName" @@ -1491,7 +1483,6 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato warm_start_type=warm_start_type, parents=all_parents ), early_stopping_type=self.early_stopping_type, - random_seed=self.random_seed, ) if len(self.estimator_dict) > 1: @@ -1517,7 +1508,6 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato max_parallel_jobs=self.max_parallel_jobs, warm_start_config=WarmStartConfig(warm_start_type=warm_start_type, parents=all_parents), early_stopping_type=self.early_stopping_type, - random_seed=self.random_seed, ) @classmethod @@ -1536,7 +1526,6 @@ def create( tags=None, warm_start_config=None, early_stopping_type="Off", - random_seed=None, ): """Factory method to create a ``HyperparameterTuner`` instance. @@ -1597,9 +1586,6 @@ def create( Can be either 'Auto' or 'Off' (default: 'Off'). If set to 'Off', early stopping will not be attempted. If set to 'Auto', early stopping of some training jobs may happen, but is not guaranteed to. - random_seed (int): An initial value used to initialize a pseudo-random number generator. - Setting a random seed will make the hyperparameter tuning search strategies to - produce more consistent configurations for the same tuning job. Returns: sagemaker.tuner.HyperparameterTuner: a new ``HyperparameterTuner`` object that can @@ -1638,7 +1624,6 @@ def create( tags=tags, warm_start_config=warm_start_config, early_stopping_type=early_stopping_type, - random_seed=random_seed, ) for estimator_name in estimator_names[1:]: @@ -1790,9 +1775,6 @@ def _get_tuner_args(cls, tuner, inputs): "early_stopping_type": tuner.early_stopping_type, } - if tuner.random_seed is not None: - tuning_config["random_seed"] = tuner.random_seed - if tuner.strategy_config is not None: tuning_config["strategy_config"] = tuner.strategy_config.to_input_req() diff --git a/src/sagemaker/workflow/clarify_check_step.py b/src/sagemaker/workflow/clarify_check_step.py index 22b6fc2051..9d350b01f3 100644 --- a/src/sagemaker/workflow/clarify_check_step.py +++ b/src/sagemaker/workflow/clarify_check_step.py @@ -132,8 +132,8 @@ class ModelExplainabilityCheckConfig(ClarifyCheckConfig): model_config (ModelConfig): Config of the model and its endpoint to be created. explainability_config (SHAPConfig): Config of the specific explainability method. Currently, only SHAP is supported. - model_scores (str or int or ModelPredictedLabelConfig): Index or JMESPath expression - to locate the predicted scores in the model output (default: None). + model_scores (str or int or ModelPredictedLabelConfig): Index or JSONPath location + in the model output for the predicted scores to be explained (default: None). 
This is not required if the model output is a single score. Alternatively, an instance of ModelPredictedLabelConfig can be provided but this field CANNOT be any type of the `PipelineVariable`. diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index 8d7eb6c932..e0230ca1e5 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -15,7 +15,6 @@ import datetime import json import time -import datetime from contextlib import contextmanager import boto3 @@ -89,11 +88,6 @@ def base_name(): return f"my-base-{int(time.time() * 10**7)}" -@pytest.fixture -def base_name(): - return f"my-base-{int(time.time() * 10**7)}" - - @pytest.fixture def offline_store_s3_uri(feature_store_session, region_name): bucket = f"sagemaker-test-featurestore-{region_name}-{feature_store_session.account_id()}" diff --git a/tests/unit/sagemaker/image_uris/test_autogluon.py b/tests/unit/sagemaker/image_uris/test_autogluon.py index d4b9690505..7f7aea2850 100644 --- a/tests/unit/sagemaker/image_uris/test_autogluon.py +++ b/tests/unit/sagemaker/image_uris/test_autogluon.py @@ -37,25 +37,12 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", "us-west-2": "763104351884", } -VERSIONS = [ - "0.3.1", - "0.3.2", - "0.4.0", - "0.4.2", - "0.4.3", - "0.3", - "0.4", - "0.5.2", - "0.5", - "0.6.1", - "0.6", -] +VERSIONS = ["0.3.1", "0.3.2", "0.4.0", "0.4.2", "0.4.3", "0.3", "0.4", "0.5.2", "0.5"] SCOPES = ["training", "inference"] PROCESSORS = ["cpu", "gpu"] diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 119d08cef4..ec4a21cbc9 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -897,7 +897,6 @@ def test_train_pack_to_request(sagemaker_session): "ResourceLimits": {"MaxNumberOfTrainingJobs": 100, "MaxParallelTrainingJobs": 5}, "ParameterRanges": SAMPLE_PARAM_RANGES, "TrainingJobEarlyStoppingType": "Off", - "RandomSeed": 0, }, "TrainingJobDefinition": { "StaticHyperParameters": STATIC_HPs, @@ -990,7 +989,6 @@ def assert_create_tuning_job_request(**kwrags): sagemaker_session.tune( job_name="dummy-tuning-1", strategy="Bayesian", - random_seed=0, objective_type="Maximize", objective_metric_name="val-score", max_jobs=100, @@ -1082,7 +1080,6 @@ def assert_create_tuning_job_request(**kwrags): "max_jobs": 100, "max_parallel_jobs": 5, "parameter_ranges": SAMPLE_PARAM_RANGES, - "random_seed": 0, }, training_config={ "static_hyperparameters": STATIC_HPs, @@ -1173,7 +1170,6 @@ def assert_create_tuning_job_request(**kwrags): sagemaker_session.tune( job_name="dummy-tuning-1", strategy="Bayesian", - random_seed=0, objective_type="Maximize", objective_metric_name="val-score", max_jobs=100, @@ -1250,7 +1246,6 @@ def assert_create_tuning_job_request(**kwrags): sagemaker_session.tune( job_name="dummy-tuning-1", strategy="Bayesian", - random_seed=0, objective_type="Maximize", objective_metric_name="val-score", max_jobs=100, @@ -1294,7 +1289,6 @@ def assert_create_tuning_job_request(**kwargs): sagemaker_session.tune( job_name="dummy-tuning-1", strategy="Bayesian", - random_seed=0, objective_type="Maximize", objective_metric_name="val-score", max_jobs=100, diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 7e556c7d23..9bbc882dfa 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -545,7 +545,6 @@ def 
test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session assert tuner.strategy == "Bayesian" assert tuner.objective_type == "Minimize" assert tuner.early_stopping_type == "Off" - assert tuner.random_seed == 0 assert isinstance(tuner.estimator, PCA) assert tuner.estimator.role == ROLE diff --git a/tests/unit/tuner_test_utils.py b/tests/unit/tuner_test_utils.py index 5cf7ba2fc2..be0dba2ccd 100644 --- a/tests/unit/tuner_test_utils.py +++ b/tests/unit/tuner_test_utils.py @@ -112,7 +112,6 @@ ], }, "TrainingJobEarlyStoppingType": "Off", - "RandomSeed": 0, }, "HyperParameterTuningJobName": JOB_NAME, "TrainingJobDefinition": { From 8fbc1c62770837bcd2e87b8815f04b33300dc773 Mon Sep 17 00:00:00 2001 From: hballuru <113142824+hballuru@users.noreply.github.com> Date: Wed, 21 Dec 2022 17:19:26 -0600 Subject: [PATCH 074/526] tensorflow inference 2.10.1 release (#3547) --- CHANGELOG.md | 58 +++++++++++++++++++ VERSION | 2 +- .../image_uri_config/tensorflow.json | 38 +++++++++++- 3 files changed, 96 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 37b3440f69..9d295148aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,63 @@ # Changelog +## v2.119.0 (2022-12-21) + +### Features + + * add RandomSeed to support reproducible HPO + * Doc update for TableFormatEnum + * Add p4de to smddp supported instance types + * Add disable_profiler field in config and propagate changes + * Added doc update for dataset builder + * Add support for TF2.9.2 training images + * Add SageMaker Experiment + * Feature Store dataset builder, delete_record, get_record, list_feature_group + * Add OSU region to frameworks for DLC + * Algorithms Region Expansion OSU/DXB + * Add Neo image uri config for Pytorch 1.12 + * Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 + * Update registries with new region account number mappings. + * Add DXB region to frameworks by DLC + * Add Code Owners file + * Added transform with monitoring pipeline step in transformer + * Update TF 2.9 and TF 2.10 inference DLCs + * make estimator accept json file as modelparallel config + * SageMaker Training Compiler does not support p4de instances + * Add support for SparkML v3.3 + +### Bug Fixes and Other Changes + + * Do not specify S3 path for disabled profiler + * Correct SageMaker Clarify API docstrings by changing JSONPath to JMESPath + * Use Async Inference Config when available for endpoint update + * the Hyperband support fix for the HPO + * unpin packaging version + * Remove content type image/jpg from analysis configuration schema + * Update for Tensorflow Serving 2.11 inference DLCs + * Skip Bad Transform Test + * Pop out ModelPackageName from pipeline definition + * Fix failing jumpstart cache unit tests + * FrameworkProcessor S3 uploads + * Add constraints file for apache-airflow + * support idempotency for framework and spark processors + * Fix bug forcing uploaded tar to be named sourcedir + * Update local_requirements.txt PyYAML version + * refactoring : using with statement + * Allow Py 3.7 for MMS Test Docker env + * Return ARM XGB/SKLearn tags if `image_scope` is `inference_graviton` + * Update scipy to 1.7.3 to support M1 development envs + * Fixing type hints for Spark processor that has instance type/count params in reverse order + * Add DeepAR ap-northeast-3 repository. 
+ * Fix AsyncInferenceConfig documentation typo + * fix ml_inf to ml_inf1 in Neo multi-version support + * Fix type annotations + * add neo mvp region accounts + +### Documentation Changes + + * fix the incorrect property reference + * smdistributed libraries release notes + ## v2.125.0 (2022-12-19) ### Features diff --git a/VERSION b/VERSION index 1e80f372b6..dda4128cf2 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.125.1.dev0 +2.119.1.dev0 diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 6bb36057fa..cb206c31a4 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -307,7 +307,7 @@ "2.7": "2.7.0", "2.8": "2.8.0", "2.9": "2.9.2", - "2.10": "2.10.0", + "2.10": "2.10.1", "2.11": "2.11.0" }, "versions": { @@ -1707,6 +1707,41 @@ }, "repository": "tensorflow-inference" }, + "2.10.1": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, "2.11.0": { "registries": { "af-south-1": "626614931356", @@ -3347,3 +3382,4 @@ } } } + From 507fada9de376a54f32927c59f91eeba7949088a Mon Sep 17 00:00:00 2001 From: "jose-juan.pena-gomez@capgemini.com" Date: Wed, 18 Jan 2023 15:54:13 +0100 Subject: [PATCH 075/526] fix: unexpected unindent in feature_group.as_dataframe --- src/sagemaker/feature_store/feature_group.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index fe0dcd7d39..6995a81694 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -134,8 +134,8 @@ def as_dataframe(self, **pandas_read_csv_kwargs) -> DataFrame: Args: pandas_read_csv_kwargs: key arguments used for the method pandas.read_csv - to be able to have a better tuning on data. For more info - about this methods visit: + to be able to have a better tuning on data. For more info + about this methods visit: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html Returns: A pandas DataFrame contains the query result. 
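
A minimal sketch of how the pandas_read_csv_kwargs pass-through documented above can be used; the feature group name, the Athena output location on S3, and the column names below are placeholders, and the snippet assumes the feature group already exists and has been populated:

    import sagemaker
    from sagemaker.feature_store.feature_group import FeatureGroup

    session = sagemaker.Session()
    fg = FeatureGroup(name="my-feature-group", sagemaker_session=session)

    # Build and run an Athena query against the feature group's offline store.
    query = fg.athena_query()
    query.run(
        query_string=f'SELECT * FROM "{query.table_name}"',
        output_location="s3://my-athena-results-bucket/feature-store/",
    )
    query.wait()

    # Keyword arguments are forwarded to pandas.read_csv when the query
    # result CSV is loaded, so column types can be fixed at read time.
    df = query.as_dataframe(dtype={"customer_id": str}, parse_dates=["event_time"])

Any keyword accepted by pandas.read_csv (for example dtype, parse_dates, usecols or nrows) can be forwarded this way, avoiding a second pass over the downloaded result to coerce column types.
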
From 11e2ee016690ea8532d8f901a5ceb6ee0ab428ec Mon Sep 17 00:00:00 2001 From: "jose-juan.pena-gomez@capgemini.com" Date: Tue, 24 Jan 2023 18:02:21 +0100 Subject: [PATCH 076/526] fix: clean old sphinx executions --- doc/api/utility/featuregroup_utils.rst | 7 + doc/doc_utils/pretrainedmodels.rst | 2396 ------------------------ src/sagemaker/feature_group_utils.py | 27 +- 3 files changed, 23 insertions(+), 2407 deletions(-) create mode 100644 doc/api/utility/featuregroup_utils.rst diff --git a/doc/api/utility/featuregroup_utils.rst b/doc/api/utility/featuregroup_utils.rst new file mode 100644 index 0000000000..a41fc6ee9d --- /dev/null +++ b/doc/api/utility/featuregroup_utils.rst @@ -0,0 +1,7 @@ +FeatureGroup Utilities +---------------------- + +.. automodule:: sagemaker.feature_group_utils + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/doc/doc_utils/pretrainedmodels.rst b/doc/doc_utils/pretrainedmodels.rst index bfefc56c81..e69de29bb2 100644 --- a/doc/doc_utils/pretrainedmodels.rst +++ b/doc/doc_utils/pretrainedmodels.rst @@ -1,2396 +0,0 @@ -.. _all-pretrained-models: - -.. |external-link| raw:: html - - - -================================================ -Built-in Algorithms with pre-trained Model Table -================================================ - - The SageMaker Python SDK uses model IDs and model versions to access the necessary - utilities for pre-trained models. This table serves to provide the core material plus - some extra information that can be useful in selecting the correct model ID and - corresponding parameters. - - If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute. - We highly suggest pinning an exact model version however. - - These models are also available through the - `JumpStart UI in SageMaker Studio `__ - -.. list-table:: Available Models - :widths: 50 20 20 20 30 20 - :header-rows: 1 - :class: datatable - - * - Model ID - - Fine Tunable? 
- - Latest Version - - Min SDK Version - - Problem Type - - Source - * - autogluon-classification-ensemble - - True - - 1.1.1 - - 2.103.0 - - Classification - - `GluonCV `__ |external-link| - * - autogluon-regression-ensemble - - True - - 1.1.1 - - 2.103.0 - - Regression - - `GluonCV `__ |external-link| - * - catboost-classification-model - - True - - 1.2.7 - - 2.75.0 - - Classification - - `Catboost `__ |external-link| - * - catboost-regression-model - - True - - 1.2.7 - - 2.75.0 - - Regression - - `Catboost `__ |external-link| - * - huggingface-eqa-bert-base-cased - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-bert-base-multilingual-cased - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-bert-base-multilingual-uncased - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-bert-base-uncased - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-bert-large-cased - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-bert-large-cased-whole-word-masking - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-bert-large-uncased - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-bert-large-uncased-whole-word-masking - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-distilbert-base-cased - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-distilbert-base-multilingual-cased - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-distilbert-base-uncased - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-distilroberta-base - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-roberta-base - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-roberta-base-openai-detector - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-eqa-roberta-large - - True - - 1.0.2 - - 2.75.0 - - Question Answering - - `HuggingFace `__ |external-link| - * - huggingface-ner-distilbert-base-cased-finetuned-conll03-english - - False - - 1.1.0 - - 2.75.0 - - Named Entity Recognition - - `HuggingFace `__ |external-link| - * - huggingface-ner-distilbert-base-uncased-finetuned-conll03-english - - False - - 1.1.0 - - 2.75.0 - - Named Entity Recognition - - `HuggingFace `__ |external-link| - * - huggingface-spc-bert-base-cased - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-bert-base-multilingual-cased - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-bert-base-multilingual-uncased - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-bert-base-uncased - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-bert-large-cased - - True - - 1.2.3 - - 
2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-bert-large-cased-whole-word-masking - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-bert-large-uncased - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-bert-large-uncased-whole-word-masking - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-distilbert-base-cased - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-distilbert-base-multilingual-cased - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-distilbert-base-uncased - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-distilroberta-base - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-roberta-base - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-roberta-base-openai-detector - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-roberta-large - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-roberta-large-openai-detector - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-xlm-clm-ende-1024 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-xlm-mlm-ende-1024 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-xlm-mlm-enro-1024 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-xlm-mlm-tlm-xnli15-1024 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-spc-xlm-mlm-xnli15-1024 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `HuggingFace `__ |external-link| - * - huggingface-summarization-bart-large-cnn-samsum - - False - - 1.1.0 - - 2.75.0 - - Text Summarization - - `HuggingFace `__ |external-link| - * - huggingface-summarization-bert-small2bert-small-finetuned-cnn-daily-mail-summarization - - False - - 1.1.0 - - 2.75.0 - - Text Summarization - - `HuggingFace `__ |external-link| - * - huggingface-summarization-bigbird-pegasus-large-arxiv - - False - - 1.1.0 - - 2.75.0 - - Text Summarization - - `HuggingFace `__ |external-link| - * - huggingface-summarization-bigbird-pegasus-large-pubmed - - False - - 1.1.0 - - 2.75.0 - - Text Summarization - - `HuggingFace `__ |external-link| - * - huggingface-summarization-distilbart-cnn-12-6 - - False - - 1.1.0 - - 2.75.0 - - Text Summarization - - `HuggingFace `__ |external-link| - * - huggingface-summarization-distilbart-cnn-6-6 - - False - - 1.1.0 - - 2.75.0 - - Text Summarization - - `HuggingFace `__ |external-link| - * - huggingface-summarization-distilbart-xsum-1-1 - - False - - 1.1.0 - - 2.75.0 - - Text Summarization - - `HuggingFace `__ |external-link| - * - huggingface-summarization-distilbart-xsum-12-3 - - False - - 1.1.0 - - 2.75.0 - - 
Text Summarization - - `HuggingFace `__ |external-link| - * - huggingface-textgeneration-bloom-1b1 - - False - - 1.0.1 - - 2.75.0 - - Text Generation - - `HuggingFace `__ |external-link| - * - huggingface-textgeneration-bloom-1b7 - - False - - 1.0.1 - - 2.75.0 - - Text Generation - - `HuggingFace `__ |external-link| - * - huggingface-textgeneration-bloom-560m - - False - - 1.0.1 - - 2.75.0 - - Text Generation - - `HuggingFace `__ |external-link| - * - huggingface-textgeneration-distilgpt2 - - False - - 1.2.1 - - 2.75.0 - - Text Generation - - `HuggingFace `__ |external-link| - * - huggingface-textgeneration-gpt2 - - False - - 1.2.1 - - 2.75.0 - - Text Generation - - `HuggingFace `__ |external-link| - * - huggingface-translation-opus-mt-en-es - - False - - 1.1.0 - - 2.75.0 - - Machine Translation - - `HuggingFace `__ |external-link| - * - huggingface-translation-opus-mt-en-vi - - False - - 1.1.0 - - 2.75.0 - - Machine Translation - - `HuggingFace `__ |external-link| - * - huggingface-translation-t5-base - - False - - 1.1.0 - - 2.75.0 - - Machine Translation - - `HuggingFace `__ |external-link| - * - huggingface-translation-t5-large - - False - - 1.1.0 - - 2.75.0 - - Machine Translation - - `HuggingFace `__ |external-link| - * - huggingface-translation-t5-small - - False - - 1.1.0 - - 2.75.0 - - Machine Translation - - `HuggingFace `__ |external-link| - * - huggingface-txt2img-stable-diffusion-v1-4 - - False - - 1.0.1 - - 2.75.0 - - Source - - `HuggingFace `__ |external-link| - * - lightgbm-classification-model - - True - - 1.2.6 - - 2.75.0 - - Classification - - `LightGBM `__ |external-link| - * - lightgbm-regression-model - - True - - 1.2.6 - - 2.75.0 - - Regression - - `LightGBM `__ |external-link| - * - model-txt2img-stabilityai-stable-diffusion-v1-4 - - False - - 1.0.0 - - 2.75.0 - - Source - - `HuggingFace `__ |external-link| - * - mxnet-is-mask-rcnn-fpn-resnet101-v1d-coco - - False - - 1.2.1 - - 2.100.0 - - Instance Segmentation - - `GluonCV `__ |external-link| - * - mxnet-is-mask-rcnn-fpn-resnet18-v1b-coco - - False - - 1.2.1 - - 2.100.0 - - Instance Segmentation - - `GluonCV `__ |external-link| - * - mxnet-is-mask-rcnn-fpn-resnet50-v1b-coco - - False - - 1.2.1 - - 2.100.0 - - Instance Segmentation - - `GluonCV `__ |external-link| - * - mxnet-is-mask-rcnn-resnet18-v1b-coco - - False - - 1.2.1 - - 2.100.0 - - Instance Segmentation - - `GluonCV `__ |external-link| - * - mxnet-od-faster-rcnn-fpn-resnet101-v1d-coco - - False - - 1.2.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-faster-rcnn-fpn-resnet50-v1b-coco - - False - - 1.2.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-faster-rcnn-resnet101-v1d-coco - - False - - 1.2.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-faster-rcnn-resnet50-v1b-coco - - False - - 1.2.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-faster-rcnn-resnet50-v1b-voc - - False - - 1.2.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-ssd-300-vgg16-atrous-coco - - True - - 1.3.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-ssd-300-vgg16-atrous-voc - - True - - 1.3.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-ssd-512-mobilenet1-0-coco - - True - - 1.3.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-ssd-512-mobilenet1-0-voc - - True - - 1.3.1 - - 2.100.0 - - Object Detection - - 
`GluonCV `__ |external-link| - * - mxnet-od-ssd-512-resnet50-v1-coco - - True - - 1.3.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-ssd-512-resnet50-v1-voc - - True - - 1.3.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-ssd-512-vgg16-atrous-coco - - True - - 1.3.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-ssd-512-vgg16-atrous-voc - - True - - 1.3.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-yolo3-darknet53-coco - - False - - 1.2.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-yolo3-darknet53-voc - - False - - 1.2.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-yolo3-mobilenet1-0-coco - - False - - 1.2.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-od-yolo3-mobilenet1-0-voc - - False - - 1.2.1 - - 2.100.0 - - Object Detection - - `GluonCV `__ |external-link| - * - mxnet-semseg-fcn-resnet101-ade - - True - - 1.4.1 - - 2.100.0 - - Semantic Segmentation - - `GluonCV `__ |external-link| - * - mxnet-semseg-fcn-resnet101-coco - - True - - 1.4.1 - - 2.100.0 - - Semantic Segmentation - - `GluonCV `__ |external-link| - * - mxnet-semseg-fcn-resnet101-voc - - True - - 1.4.1 - - 2.100.0 - - Semantic Segmentation - - `GluonCV `__ |external-link| - * - mxnet-semseg-fcn-resnet50-ade - - True - - 1.4.1 - - 2.100.0 - - Semantic Segmentation - - `GluonCV `__ |external-link| - * - mxnet-tcembedding-robertafin-base-uncased - - False - - 1.2.1 - - 2.100.0 - - Text Embedding - - `GluonCV `__ |external-link| - * - mxnet-tcembedding-robertafin-base-wiki-uncased - - False - - 1.2.1 - - 2.100.0 - - Text Embedding - - `GluonCV `__ |external-link| - * - mxnet-tcembedding-robertafin-large-uncased - - False - - 1.2.1 - - 2.100.0 - - Text Embedding - - `GluonCV `__ |external-link| - * - mxnet-tcembedding-robertafin-large-wiki-uncased - - False - - 1.2.1 - - 2.100.0 - - Text Embedding - - `GluonCV `__ |external-link| - * - pytorch-eqa-bert-base-cased - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-bert-base-multilingual-cased - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-bert-base-multilingual-uncased - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-bert-base-uncased - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-bert-large-cased - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-bert-large-cased-whole-word-masking - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-bert-large-cased-whole-word-masking-finetuned-squad - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-bert-large-uncased - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-bert-large-uncased-whole-word-masking - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-bert-large-uncased-whole-word-masking-finetuned-squad - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-distilbert-base-cased - - True - - 1.2.1 - - 2.75.0 - - Question Answering - 
- `Pytorch Hub `__ |external-link| - * - pytorch-eqa-distilbert-base-multilingual-cased - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-distilbert-base-uncased - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-distilroberta-base - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-roberta-base - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-roberta-base-openai-detector - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-roberta-large - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-eqa-roberta-large-openai-detector - - True - - 1.2.1 - - 2.75.0 - - Question Answering - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-alexnet - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-densenet121 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-densenet161 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-densenet169 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-densenet201 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-googlenet - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-mobilenet-v2 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-resnet101 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-resnet152 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-resnet18 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-resnet34 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-resnet50 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-resnext101-32x8d - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-resnext50-32x4d - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-shufflenet-v2-x1-0 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-squeezenet1-0 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-squeezenet1-1 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-vgg11 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-vgg11-bn - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-vgg13 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-vgg13-bn - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-vgg16 - - True - - 2.2.4 - - 2.75.0 - - Image 
Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-vgg16-bn - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-vgg19 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-vgg19-bn - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-wide-resnet101-2 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-ic-wide-resnet50-2 - - True - - 2.2.4 - - 2.75.0 - - Image Classification - - `Pytorch Hub `__ |external-link| - * - pytorch-od-nvidia-ssd - - False - - 1.0.2 - - 2.75.0 - - Object Detection - - `Pytorch Hub `__ |external-link| - * - pytorch-od1-fasterrcnn-mobilenet-v3-large-320-fpn - - False - - 1.0.0 - - 2.75.0 - - Object Detection - - `Pytorch Hub `__ |external-link| - * - pytorch-od1-fasterrcnn-mobilenet-v3-large-fpn - - False - - 1.0.0 - - 2.75.0 - - Object Detection - - `Pytorch Hub `__ |external-link| - * - pytorch-od1-fasterrcnn-resnet50-fpn - - True - - 1.3.2 - - 2.75.0 - - Object Detection - - `Pytorch Hub `__ |external-link| - * - pytorch-tabtransformerclassification-model - - True - - 1.0.4 - - 2.75.0 - - Source - - `Source `__ |external-link| - * - pytorch-tabtransformerregression-model - - True - - 1.0.3 - - 2.75.0 - - Source - - `Source `__ |external-link| - * - pytorch-textgeneration1-alexa20b - - False - - 1.0.0 - - 2.116.0 - - Source - - `Source `__ |external-link| - * - sklearn-classification-linear - - True - - 1.1.2 - - 2.75.0 - - Classification - - `ScikitLearn `__ |external-link| - * - sklearn-regression-linear - - True - - 1.1.2 - - 2.75.0 - - Regression - - `ScikitLearn `__ |external-link| - * - tensorflow-audioembedding-frill-1 - - False - - 1.0.1 - - 2.80.0 - - Source - - `Tensorflow Hub `__ |external-link| - * - tensorflow-audioembedding-trill-3 - - False - - 1.0.1 - - 2.80.0 - - Source - - `Tensorflow Hub `__ |external-link| - * - tensorflow-audioembedding-trill-distilled-3 - - False - - 1.0.1 - - 2.80.0 - - Source - - `Tensorflow Hub `__ |external-link| - * - tensorflow-audioembedding-trillsson1-1 - - False - - 1.0.1 - - 2.80.0 - - Source - - `Tensorflow Hub `__ |external-link| - * - tensorflow-audioembedding-trillsson2-1 - - False - - 1.0.1 - - 2.80.0 - - Source - - `Tensorflow Hub `__ |external-link| - * - tensorflow-audioembedding-trillsson3-1 - - False - - 1.0.1 - - 2.80.0 - - Source - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-m-r101x1-imagenet21k-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-m-r101x3-ilsvrc2012-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-m-r101x3-imagenet21k-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-m-r50x1-ilsvrc2012-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-m-r50x1-imagenet21k-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-m-r50x3-ilsvrc2012-classification-1 - - True - - 2.0.5 - - 
2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-m-r50x3-imagenet21k-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-s-r101x1-ilsvrc2012-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-s-r101x3-ilsvrc2012-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-s-r50x1-ilsvrc2012-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-bit-s-r50x3-ilsvrc2012-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-b0-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-b1-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-b2-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-b3-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-b4-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-b5-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-b6-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-b7-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-lite0-classification-2 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-lite1-classification-2 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-lite2-classification-2 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-lite3-classification-2 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-efficientnet-lite4-classification-2 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-inception-resnet-v2-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-inception-v1-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-inception-v2-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-inception-v3-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-025-128-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub 
`__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-025-160-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-025-192-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-025-224-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-050-128-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-050-160-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-050-192-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-050-224-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-075-128-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-075-160-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-075-192-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-075-224-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-100-128-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-100-160-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-100-192-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v1-100-224-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v2-035-224-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v2-050-224-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v2-075-224-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v2-100-224-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v2-130-224-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-mobilenet-v2-140-224-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-resnet-v1-101-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow 
Hub `__ |external-link| - * - tensorflow-ic-imagenet-resnet-v1-152-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-resnet-v1-50-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-resnet-v2-101-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-resnet-v2-152-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-imagenet-resnet-v2-50-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-resnet-50-classification-1 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-tf2-preview-inception-v3-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-ic-tf2-preview-mobilenet-v2-classification-4 - - True - - 2.0.5 - - 2.80.0 - - Image Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-bit-m-r101x1-ilsvrc2012-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-bit-m-r101x3-imagenet21k-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-bit-m-r50x1-ilsvrc2012-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-bit-m-r50x3-imagenet21k-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-bit-s-r101x1-ilsvrc2012-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-bit-s-r101x3-ilsvrc2012-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-bit-s-r50x1-ilsvrc2012-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-bit-s-r50x3-ilsvrc2012-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-efficientnet-b0-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-efficientnet-b1-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-efficientnet-b2-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-efficientnet-b3-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-efficientnet-b6-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-efficientnet-lite0-featurevector-2 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-efficientnet-lite1-featurevector-2 - - False - - 2.0.2 - - 2.80.0 - - Image 
Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-efficientnet-lite2-featurevector-2 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-efficientnet-lite3-featurevector-2 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-efficientnet-lite4-featurevector-2 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-inception-v1-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-inception-v2-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-inception-v3-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-025-128-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-025-160-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-025-192-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-025-224-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-050-128-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-050-160-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-050-192-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-050-224-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-075-128-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-075-160-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-075-192-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-075-224-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-100-128-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-100-160-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-100-192-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v1-100-224-featurevector-4 
- - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v2-035-224-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v2-050-224-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v2-075-224-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v2-100-224-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v2-130-224-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-mobilenet-v2-140-224-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-resnet-v1-101-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-resnet-v1-152-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-resnet-v1-50-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-resnet-v2-101-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-resnet-v2-152-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-imagenet-resnet-v2-50-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-resnet-50-featurevector-1 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-tf2-preview-inception-v3-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-icembedding-tf2-preview-mobilenet-v2-featurevector-4 - - False - - 2.0.2 - - 2.80.0 - - Image Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-centernet-hourglass-1024x1024-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-centernet-hourglass-1024x1024-kpts-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-centernet-hourglass-512x512-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-centernet-hourglass-512x512-kpts-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-centernet-resnet101v1-fpn-512x512-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-centernet-resnet50v1-fpn-512x512-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-centernet-resnet50v1-fpn-512x512-kpts-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - 
tensorflow-od-centernet-resnet50v2-512x512-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-centernet-resnet50v2-512x512-kpts-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-efficientdet-d0-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-efficientdet-d1-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-efficientdet-d2-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-efficientdet-d3-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-efficientdet-d4-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-efficientdet-d5-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-faster-rcnn-inception-resnet-v2-1024x1024-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-faster-rcnn-inception-resnet-v2-640x640-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-faster-rcnn-resnet101-v1-1024x1024-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-faster-rcnn-resnet101-v1-640x640-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-faster-rcnn-resnet101-v1-800x1333-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-faster-rcnn-resnet152-v1-1024x1024-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-faster-rcnn-resnet152-v1-640x640-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-faster-rcnn-resnet152-v1-800x1333-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-faster-rcnn-resnet50-v1-1024x1024-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-faster-rcnn-resnet50-v1-640x640-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-faster-rcnn-resnet50-v1-800x1333-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-retinanet-resnet101-v1-fpn-1024x1024-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-retinanet-resnet101-v1-fpn-640x640-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-retinanet-resnet152-v1-fpn-1024x1024-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-retinanet-resnet152-v1-fpn-640x640-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-retinanet-resnet50-v1-fpn-1024x1024-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-retinanet-resnet50-v1-fpn-640x640-1 - - False - - 2.0.2 - - 2.80.0 - - Object 
Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-ssd-mobilenet-v1-fpn-640x640-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-ssd-mobilenet-v2-2 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-ssd-mobilenet-v2-fpnlite-320x320-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od-ssd-mobilenet-v2-fpnlite-640x640-1 - - False - - 2.0.2 - - 2.80.0 - - Object Detection - - `Tensorflow Hub `__ |external-link| - * - tensorflow-od1-ssd-efficientdet-d0-512x512-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-efficientdet-d1-640x640-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-efficientdet-d2-768x768-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-efficientdet-d3-896x896-coco17-tpu-32 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-mobilenet-v1-fpn-640x640-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-mobilenet-v2-fpnlite-320x320-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-mobilenet-v2-fpnlite-640x640-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-resnet101-v1-fpn-1024x1024-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-resnet101-v1-fpn-640x640-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-resnet152-v1-fpn-1024x1024-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-resnet152-v1-fpn-640x640-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-resnet50-v1-fpn-1024x1024-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-od1-ssd-resnet50-v1-fpn-640x640-coco17-tpu-8 - - True - - 1.0.2 - - 2.75.0 - - Object Detection - - `Source `__ |external-link| - * - tensorflow-spc-bert-en-cased-L-12-H-768-A-12-2 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-spc-bert-en-uncased-L-12-H-768-A-12-2 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-spc-bert-en-uncased-L-24-H-1024-A-16-2 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-spc-bert-en-wwm-cased-L-24-H-1024-A-16-2 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-spc-bert-en-wwm-uncased-L-24-H-1024-A-16-2 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-spc-bert-multi-cased-L-12-H-768-A-12-2 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-spc-electra-base-1 - - True - - 1.2.3 - - 2.75.0 - - 
Sentence Pair Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-spc-electra-small-1 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-spc-experts-bert-pubmed-1 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-spc-experts-bert-wiki-books-1 - - True - - 1.2.3 - - 2.75.0 - - Sentence Pair Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-albert-en-base - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-bert-en-cased-L-12-H-768-A-12-2 - - True - - 2.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-bert-en-cased-L-24-H-1024-A-16-2 - - True - - 2.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2 - - True - - 2.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-bert-en-uncased-L-24-H-1024-A-16-2 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-bert-en-wwm-cased-L-24-H-1024-A-16-2 - - True - - 2.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-bert-en-wwm-uncased-L-24-H-1024-A-16-2 - - True - - 2.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-bert-multi-cased-L-12-H-768-A-12-2 - - True - - 2.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-electra-base-1 - - True - - 2.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-electra-small-1 - - True - - 2.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-experts-bert-pubmed-1 - - True - - 2.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-experts-bert-wiki-books-1 - - True - - 2.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-128-A-2 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-256-A-4 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-512-A-8 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-10-H-768-A-12 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-128-A-2 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-256-A-4 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-512-A-8 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-12-H-768-A-12 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-128-A-2 - - True - - 1.0.1 - - 2.80.0 - 
- Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-256-A-4 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-512-A-8 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-2-H-768-A-12 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-128-A-2 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-256-A-4 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-512-A-8 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-4-H-768-A-12 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-128-A-2 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-256-A-4 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-512-A-8 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-6-H-768-A-12 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-8-H-128-A-2 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-8-H-256-A-4 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-8-H-512-A-8 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-small-bert-bert-en-uncased-L-8-H-768-A-12 - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-talking-heads-base - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tc-talking-heads-large - - True - - 1.0.1 - - 2.80.0 - - Text Classification - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-10-H-128-A-2-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-10-H-256-A-4-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-10-H-512-A-8-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-10-H-768-A-12-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-12-H-128-A-2-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-12-H-256-A-4 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ 
|external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-12-H-512-A-8-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-12-H-768-A-12-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-12-H-768-A-12-4 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-2-H-128-A-2-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-2-H-256-A-4 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-2-H-512-A-8-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-2-H-768-A-12-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-4-H-128-A-2-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-4-H-256-A-4-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-4-H-512-A-8-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-4-H-768-A-12-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-6-H-128-A-2-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-6-H-256-A-4 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-6-H-512-A-8-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-6-H-768-A-12-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-8-H-256-A-4-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-8-H-512-A-8-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-en-uncased-L-8-H-768-A-12-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-wiki-books-mnli-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-bert-wiki-books-sst2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-talkheads-ggelu-bert-en-base-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-talkheads-ggelu-bert-en-large-2 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-universal-sentence-encoder-cmlm-en-base-1 - - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - tensorflow-tcembedding-universal-sentence-encoder-cmlm-en-large-1 
- - False - - 1.1.1 - - 2.75.0 - - Text Embedding - - `Tensorflow Hub `__ |external-link| - * - xgboost-classification-model - - True - - 1.2.3 - - 2.75.0 - - Classification - - `XGBoost `__ |external-link| - * - xgboost-regression-model - - True - - 1.2.3 - - 2.75.0 - - Regression - - `XGBoost `__ |external-link| diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index af3bf6327d..7f4112b224 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -29,8 +29,8 @@ logger = logging.getLogger(__name__) -def _get_session_from_role(role: str, region: str): - """Method use to get the sagemaker session from a role and a region. +def _get_session_from_role(role: str, region: str) -> Session: + """Method use to get the :class:~`sagemaker.session.Session` from a role and a region. Helpful in case it's invoke from a session with a role without permission it can assume another role temporarily to perform certain taks. @@ -89,7 +89,8 @@ def get_feature_group_as_dataframe( verbose: bool = True, **pandas_read_csv_kwargs, ) -> DataFrame: - """Get a feature group as a pandas.DataFrame + """Get a :class:~`sagemaker.feature_store.feature_group.FeatureGroup` + as a pandas.DataFrame Description: Method to run an athena query over a Feature Group in a Feature Store @@ -98,16 +99,17 @@ def get_feature_group_as_dataframe( with the data. Args: - region (str): region of the target feature store + region (str): region of the target Feature Store feature_group_name (str): feature store name query (str): query to run. By default, it will take the latest ingest with data that wasn't deleted. If latest_ingestion is False it will take all the data in the feature group that wasn't deleted. It needs to use the keyword - "#{table}" to refer to the table. e.g.: - 'SELECT * FROM "sagemaker_featurestore"."#{table}"' - athena_bucket (str): S3 bucket for running the query + "#{table}" to refer to the FeatureGroup name. e.g.: + 'SELECT * FROM "sagemaker_featurestore"."#{table}"' + athena_bucket (str): Amazon S3 bucket for running the query role (str): role of the account used to extract data from feature store - session (str): session of SageMaker used to work with the feature store + session (str): :class:~`sagemaker.session.Session` + of SageMaker used to work with the feature store event_time_feature_name (str): eventTimeId feature. Mandatory only if the latest ingestion is True latest_ingestion (bool): if True it will get the data only from the latest ingestion. 
@@ -173,7 +175,8 @@ def get_feature_group_as_dataframe( def _format_column_names(data: pandas.DataFrame) -> pandas.DataFrame: - """Format the column names for a FeatureGroup + """Formats the column names in a valid way for + :class:~`sagemaker.feature_store.feature_group.FeatureGroup` Module to format correctly the name of the columns of a DataFrame to later generate the features names of a Feature Group @@ -216,7 +219,8 @@ def prepare_fg_from_dataframe_or_file( verbose: bool = False, **pandas_read_csv_kwargs, ) -> FeatureGroup: - """Module to prepare a dataframe before creating Feature + """Module to prepare a dataframe before creating a + :class:~`sagemaker.feature_store.feature_group.FeatureGroup` Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame or a path to a file with proper dtypes, feature names and mandatory features (record_id, @@ -240,7 +244,8 @@ def prepare_fg_from_dataframe_or_file( session (str): session of SageMaker used to work with the feature store Returns: - FeatureGroup: FG prepared with all the methods and definitions properly defined + :class:~`sagemaker.feature_store.feature_group.FeatureGroup`: + FG prepared with all the methods and definitions properly defined """ logger.setLevel(logging.WARNING) From a7309e336a36d2d4bfab990031615c27e311e386 Mon Sep 17 00:00:00 2001 From: "jose-juan.pena-gomez@capgemini.com" Date: Tue, 24 Jan 2023 18:52:18 +0100 Subject: [PATCH 077/526] fix: sphinx class reference --- src/sagemaker/feature_group_utils.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index 7f4112b224..b54f225ba5 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -30,7 +30,7 @@ def _get_session_from_role(role: str, region: str) -> Session: - """Method use to get the :class:~`sagemaker.session.Session` from a role and a region. + """Method use to get the ":class:~`sagemaker.session.Session`" from a role and a region. Helpful in case it's invoke from a session with a role without permission it can assume another role temporarily to perform certain taks. @@ -89,8 +89,7 @@ def get_feature_group_as_dataframe( verbose: bool = True, **pandas_read_csv_kwargs, ) -> DataFrame: - """Get a :class:~`sagemaker.feature_store.feature_group.FeatureGroup` - as a pandas.DataFrame + """Get a ":class:~`sagemaker.feature_store.feature_group.FeatureGroup`" as a pandas.DataFrame Description: Method to run an athena query over a Feature Group in a Feature Store @@ -108,7 +107,7 @@ def get_feature_group_as_dataframe( 'SELECT * FROM "sagemaker_featurestore"."#{table}"' athena_bucket (str): Amazon S3 bucket for running the query role (str): role of the account used to extract data from feature store - session (str): :class:~`sagemaker.session.Session` + session (str): ":class:"~`sagemaker.session.Session`" of SageMaker used to work with the feature store event_time_feature_name (str): eventTimeId feature. 
Mandatory only if the latest ingestion is True @@ -175,8 +174,7 @@ def get_feature_group_as_dataframe( def _format_column_names(data: pandas.DataFrame) -> pandas.DataFrame: - """Formats the column names in a valid way for - :class:~`sagemaker.feature_store.feature_group.FeatureGroup` + """Formats the column names in a valid way for ":class:~`sagemaker.feature_store.feature_group.FeatureGroup`" Module to format correctly the name of the columns of a DataFrame to later generate the features names of a Feature Group @@ -220,7 +218,7 @@ def prepare_fg_from_dataframe_or_file( **pandas_read_csv_kwargs, ) -> FeatureGroup: """Module to prepare a dataframe before creating a - :class:~`sagemaker.feature_store.feature_group.FeatureGroup` + ":class:"~`sagemaker.feature_store.feature_group.FeatureGroup`" Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame or a path to a file with proper dtypes, feature names and mandatory features (record_id, @@ -244,8 +242,8 @@ def prepare_fg_from_dataframe_or_file( session (str): session of SageMaker used to work with the feature store Returns: - :class:~`sagemaker.feature_store.feature_group.FeatureGroup`: - FG prepared with all the methods and definitions properly defined + ":class:"~`sagemaker.feature_store.feature_group.FeatureGroup`": FG prepared with all + the methods and definitions properly defined """ logger.setLevel(logging.WARNING) From c22540da775092fe568d575589724b5bdbbcfde1 Mon Sep 17 00:00:00 2001 From: JoseJuan98 Date: Tue, 24 Jan 2023 19:33:09 +0100 Subject: [PATCH 078/526] fix: unexpected indent --- src/sagemaker/feature_group_utils.py | 29 +++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index b54f225ba5..334f263b3f 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -30,7 +30,7 @@ def _get_session_from_role(role: str, region: str) -> Session: - """Method use to get the ":class:~`sagemaker.session.Session`" from a role and a region. + """Method use to get the :class:`sagemaker.session.Session` from a role and a region. Helpful in case it's invoke from a session with a role without permission it can assume another role temporarily to perform certain taks. @@ -89,7 +89,7 @@ def get_feature_group_as_dataframe( verbose: bool = True, **pandas_read_csv_kwargs, ) -> DataFrame: - """Get a ":class:~`sagemaker.feature_store.feature_group.FeatureGroup`" as a pandas.DataFrame + """Get a :class:`sagemaker.feature_store.feature_group.FeatureGroup` as a pandas.DataFrame Description: Method to run an athena query over a Feature Group in a Feature Store @@ -107,7 +107,7 @@ def get_feature_group_as_dataframe( 'SELECT * FROM "sagemaker_featurestore"."#{table}"' athena_bucket (str): Amazon S3 bucket for running the query role (str): role of the account used to extract data from feature store - session (str): ":class:"~`sagemaker.session.Session`" + session (str): :class:`sagemaker.session.Session` of SageMaker used to work with the feature store event_time_feature_name (str): eventTimeId feature. 
Mandatory only if the latest ingestion is True @@ -174,10 +174,12 @@ def get_feature_group_as_dataframe( def _format_column_names(data: pandas.DataFrame) -> pandas.DataFrame: - """Formats the column names in a valid way for ":class:~`sagemaker.feature_store.feature_group.FeatureGroup`" + """Formats the column names in a valid way for + :class:`sagemaker.feature_store.feature_group.FeatureGroup` - Module to format correctly the name of the columns of a DataFrame - to later generate the features names of a Feature Group + Description: + Module to format correctly the name of the columns of a DataFrame + to later generate the features names of a Feature Group Args: data (pandas.DataFrame): dataframe used @@ -218,13 +220,14 @@ def prepare_fg_from_dataframe_or_file( **pandas_read_csv_kwargs, ) -> FeatureGroup: """Module to prepare a dataframe before creating a - ":class:"~`sagemaker.feature_store.feature_group.FeatureGroup`" + :class:`sagemaker.feature_store.feature_group.FeatureGroup` - Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame - or a path to a file with proper dtypes, feature names and mandatory features (record_id, - event_id). It needs the sagemaker.Session linked to a role or the role and region used - to work Feature Stores. If record_id or event_id are not specified it will create ones - by default with the names 'record_id' and 'data_as_of_date'. + Description: + Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame + or a path to a file with proper dtypes, feature names and mandatory features (record_id, + event_id). It needs the sagemaker.Session linked to a role or the role and region used + to work Feature Stores. If record_id or event_id are not specified it will create ones + by default with the names 'record_id' and 'data_as_of_date'. 
Args: feature_group_name (str): feature group name @@ -242,7 +245,7 @@ def prepare_fg_from_dataframe_or_file( session (str): session of SageMaker used to work with the feature store Returns: - ":class:"~`sagemaker.feature_store.feature_group.FeatureGroup`": FG prepared with all + :class:`sagemaker.feature_store.feature_group.FeatureGroup`: FG prepared with all the methods and definitions properly defined """ From dcea6d49da0fedc017d6e36d9cc383d8637ea736 Mon Sep 17 00:00:00 2001 From: JoseJuan98 Date: Tue, 24 Jan 2023 19:51:00 +0100 Subject: [PATCH 079/526] fix: docstyle D205 error --- src/sagemaker/feature_group_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index 334f263b3f..1b381fc6ae 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -174,8 +174,7 @@ def get_feature_group_as_dataframe( def _format_column_names(data: pandas.DataFrame) -> pandas.DataFrame: - """Formats the column names in a valid way for - :class:`sagemaker.feature_store.feature_group.FeatureGroup` + """Formats the column names for :class:`sagemaker.feature_store.feature_group.FeatureGroup` Description: Module to format correctly the name of the columns of a DataFrame @@ -219,8 +218,7 @@ def prepare_fg_from_dataframe_or_file( verbose: bool = False, **pandas_read_csv_kwargs, ) -> FeatureGroup: - """Module to prepare a dataframe before creating a - :class:`sagemaker.feature_store.feature_group.FeatureGroup` + """Prepares a dataframe to create a :class:`sagemaker.feature_store.feature_group.FeatureGroup` Description: Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame From e1f6a61cad5d58c5b5f2705edbb48b1e167cf89e Mon Sep 17 00:00:00 2001 From: jpenagom Date: Mon, 13 Feb 2023 11:40:30 +0100 Subject: [PATCH 080/526] change: line separator LF --- tests/integ/test_feature_store.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index f6472d0817..449178ddc3 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -935,8 +935,7 @@ def test_get_feature_group_with_session( latest_ingestion=True, athena_bucket=f"{offline_store_s3_uri}/query", low_memory=False, - ) # Using kwargs to pass a parameter to - # pandas.read_csv + ) # Using kwargs to pass a parameter to pandas.read_csv assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}") assert not dataset.empty From ea2ba6c8b33e61974b3b3acfcbf93934168db295 Mon Sep 17 00:00:00 2001 From: jpenagom Date: Mon, 13 Feb 2023 11:42:39 +0100 Subject: [PATCH 081/526] change: removed unused param --- src/sagemaker/session.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index b5efcf6858..21715e5820 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2199,7 +2199,6 @@ def tune( # noqa: C901 use_spot_instances=False, checkpoint_s3_uri=None, checkpoint_local_path=None, - random_seed=None, environment=None, ): """Create an Amazon SageMaker hyperparameter tuning job. @@ -2281,9 +2280,6 @@ def tune( # noqa: C901 started. If the path is unset then SageMaker assumes the checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). - random_seed (int): An initial value used to initialize a pseudo-random number generator. 
- Setting a random seed will make the hyperparameter tuning search strategies to - produce more consistent configurations for the same tuning job. (default: ``None``). environment (dict[str, str]) : Environment variables to be set for use during training jobs (default: ``None``) """ From cf88f5337e641d4274f468cb6f7ed0d1cfea0ec6 Mon Sep 17 00:00:00 2001 From: Jose Juan Pena Date: Fri, 3 Mar 2023 12:38:53 +0100 Subject: [PATCH 082/526] fix: athena query to Feature Group --- src/sagemaker/feature_group_utils.py | 87 ++++++++++++++-------------- 1 file changed, 43 insertions(+), 44 deletions(-) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index 1b381fc6ae..1626b5cbdd 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -29,37 +29,38 @@ logger = logging.getLogger(__name__) -def _get_session_from_role(role: str, region: str) -> Session: +def _get_session_from_role(region: str, role: str = None) -> Session: """Method use to get the :class:`sagemaker.session.Session` from a role and a region. Helpful in case it's invoke from a session with a role without permission it can assume - another role temporarily to perform certain taks. - + another role temporarily to perform certain tasks. Args: role: role name region: region name Returns: - + :class:`sagemaker.session` """ boto_session = boto3.Session(region_name=region) - sts = boto_session.client( - "sts", region_name=region, endpoint_url="https://sts.eu-west-1.amazonaws.com" - ) + # It will try to assume the role specified + if role: + sts = boto_session.client( + "sts", region_name=region, endpoint_url="https://sts.eu-west-1.amazonaws.com" + ) - metadata = sts.assume_role(RoleArn=role, RoleSessionName="SagemakerExecution") + metadata = sts.assume_role(RoleArn=role, RoleSessionName="SagemakerExecution") - access_key_id = metadata["Credentials"]["AccessKeyId"] - secret_access_key = metadata["Credentials"]["SecretAccessKey"] - session_token = metadata["Credentials"]["SessionToken"] + access_key_id = metadata["Credentials"]["AccessKeyId"] + secret_access_key = metadata["Credentials"]["SecretAccessKey"] + session_token = metadata["Credentials"]["SessionToken"] - boto_session = boto3.session.Session( - region_name=region, - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - aws_session_token=session_token, - ) + boto_session = boto3.session.Session( + region_name=region, + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + aws_session_token=session_token, + ) # Sessions sagemaker_client = boto_session.client("sagemaker") @@ -76,18 +77,16 @@ def _get_session_from_role(role: str, region: str) -> Session: def get_feature_group_as_dataframe( - feature_group_name: str, - athena_bucket: str, - query: str = str( - "SELECT * FROM " '"sagemaker_featurestore"."#{table}" ' "WHERE is_deleted=False" - ), - role: str = None, - region: str = None, - session=None, - event_time_feature_name: str = None, - latest_ingestion: bool = True, - verbose: bool = True, - **pandas_read_csv_kwargs, + feature_group_name: str, + athena_bucket: str, + query: str = """SELECT * FROM "sagemaker_featurestore"."#{table}" WHERE is_deleted=False """, + role: str = None, + region: str = None, + session=None, + event_time_feature_name: str = None, + latest_ingestion: bool = True, + verbose: bool = True, + **pandas_read_csv_kwargs, ) -> DataFrame: """Get a :class:`sagemaker.feature_store.feature_group.FeatureGroup` as a pandas.DataFrame @@ -127,9 +126,9 @@ 
def get_feature_group_as_dataframe( if latest_ingestion: if event_time_feature_name is not None: query += str( - f"AND {event_time_feature_name}=(SELECT " - f"MAX({event_time_feature_name}) FROM " - + f'"sagemaker_featurestore"."{feature_group_name}")' + f"AND {event_time_feature_name}=(SELECT " + + f"MAX({event_time_feature_name}) FROM " + + '"sagemaker_featurestore"."#{table}")' ) else: exc = Exception( @@ -143,7 +142,7 @@ def get_feature_group_as_dataframe( if session is not None: sagemaker_session = session elif role is not None and region is not None: - sagemaker_session = _get_session_from_role(role=role, region=region) + sagemaker_session = _get_session_from_role(region=region) else: exc = Exception("Argument Session or role and region must be specified.") logger.exception(exc) @@ -208,15 +207,15 @@ def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: def prepare_fg_from_dataframe_or_file( - dataframe_or_path: Union[str, Path, pandas.DataFrame], - feature_group_name: str, - role: str = None, - region: str = None, - session=None, - record_id: str = "record_id", - event_id: str = "data_as_of_date", - verbose: bool = False, - **pandas_read_csv_kwargs, + dataframe_or_path: Union[str, Path, pandas.DataFrame], + feature_group_name: str, + role: str = None, + region: str = None, + session=None, + record_id: str = "record_id", + event_id: str = "data_as_of_date", + verbose: bool = False, + **pandas_read_csv_kwargs ) -> FeatureGroup: """Prepares a dataframe to create a :class:`sagemaker.feature_store.feature_group.FeatureGroup` @@ -228,6 +227,7 @@ def prepare_fg_from_dataframe_or_file( by default with the names 'record_id' and 'data_as_of_date'. Args: + **pandas_read_csv_kwargs (object): feature_group_name (str): feature group name dataframe_or_path (str, Path, pandas.DataFrame) : pandas.DataFrame or path to the data verbose (bool) : True for displaying messages, False for silent method. 
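For context, a minimal usage sketch of get_feature_group_as_dataframe as it stands after this patch. The feature group name, role ARN, region and bucket below are placeholders rather than values from this changeset; any extra keyword arguments are forwarded to pandas.read_csv.

from sagemaker.feature_group_utils import get_feature_group_as_dataframe

# All identifiers below are illustrative placeholders.
df = get_feature_group_as_dataframe(
    feature_group_name="customers-fg",
    athena_bucket="s3://example-bucket/athena-query-results",
    role="arn:aws:iam::111122223333:role/ExampleSageMakerRole",
    region="eu-west-1",
    event_time_feature_name="data_as_of_date",  # required because latest_ingestion=True
    latest_ingestion=True,
    low_memory=False,                           # forwarded to pandas.read_csv
)
print(df.shape)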
@@ -292,13 +292,12 @@ def prepare_fg_from_dataframe_or_file( import time current_time_sec = int(round(time.time())) - data[event_id] = Series([current_time_sec] * lg_id, dtype="float64") if session is not None: sagemaker_session = session elif role is not None and region is not None: - sagemaker_session = _get_session_from_role(role=role, region=region) + sagemaker_session = _get_session_from_role(region=region) else: exc = Exception("Argument Session or role and region must be specified.") logger.exception(exc) From 6b2c69d3a689a27cc153345a7ee20096fc7e95ef Mon Sep 17 00:00:00 2001 From: Jose Juan Pena Date: Fri, 3 Mar 2023 13:33:05 +0100 Subject: [PATCH 083/526] fix: linting and tox tests --- src/sagemaker/feature_group_utils.py | 43 ++++++++++++++-------------- src/sagemaker/session.py | 6 ++-- src/sagemaker/tuner.py | 8 +++--- 3 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index 1626b5cbdd..f37c70f648 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -77,16 +77,17 @@ def _get_session_from_role(region: str, role: str = None) -> Session: def get_feature_group_as_dataframe( - feature_group_name: str, - athena_bucket: str, - query: str = """SELECT * FROM "sagemaker_featurestore"."#{table}" WHERE is_deleted=False """, - role: str = None, - region: str = None, - session=None, - event_time_feature_name: str = None, - latest_ingestion: bool = True, - verbose: bool = True, - **pandas_read_csv_kwargs, + feature_group_name: str, + athena_bucket: str, + query: str = """SELECT * FROM "sagemaker_featurestore"."#{table}" + WHERE is_deleted=False """, + role: str = None, + region: str = None, + session=None, + event_time_feature_name: str = None, + latest_ingestion: bool = True, + verbose: bool = True, + **pandas_read_csv_kwargs, ) -> DataFrame: """Get a :class:`sagemaker.feature_store.feature_group.FeatureGroup` as a pandas.DataFrame @@ -126,8 +127,8 @@ def get_feature_group_as_dataframe( if latest_ingestion: if event_time_feature_name is not None: query += str( - f"AND {event_time_feature_name}=(SELECT " + - f"MAX({event_time_feature_name}) FROM " + + f"AND {event_time_feature_name}=(SELECT " + f"MAX({event_time_feature_name}) FROM " '"sagemaker_featurestore"."#{table}")' ) else: @@ -207,15 +208,15 @@ def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: def prepare_fg_from_dataframe_or_file( - dataframe_or_path: Union[str, Path, pandas.DataFrame], - feature_group_name: str, - role: str = None, - region: str = None, - session=None, - record_id: str = "record_id", - event_id: str = "data_as_of_date", - verbose: bool = False, - **pandas_read_csv_kwargs + dataframe_or_path: Union[str, Path, pandas.DataFrame], + feature_group_name: str, + role: str = None, + region: str = None, + session=None, + record_id: str = "record_id", + event_id: str = "data_as_of_date", + verbose: bool = False, + **pandas_read_csv_kwargs, ) -> FeatureGroup: """Prepares a dataframe to create a :class:`sagemaker.feature_store.feature_group.FeatureGroup` diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 57f277f139..e74b2f0f35 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2479,6 +2479,7 @@ def _map_tuning_config( objective_metric_name=None, parameter_ranges=None, strategy_config=None, + random_seed=None, completion_criteria_config=None, ): """Construct tuning job configuration dictionary. 
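Stepping back to the feature-group helpers touched earlier in this patch, a hedged sketch of prepare_fg_from_dataframe_or_file; the CSV path, feature group name and role ARN are invented for illustration, and creating the feature group and ingesting data remain the caller's responsibility.

from sagemaker.feature_group_utils import prepare_fg_from_dataframe_or_file

# Path, names and role below are placeholders.
feature_group = prepare_fg_from_dataframe_or_file(
    dataframe_or_path="data/customers.csv",
    feature_group_name="customers-fg",
    role="arn:aws:iam::111122223333:role/ExampleSageMakerRole",
    region="eu-west-1",
    record_id="record_id",       # default column names; per the docstring they are
    event_id="data_as_of_date",  # created automatically when not already present
)
# feature_group.create(...) and feature_group.ingest(...) are still up to the caller.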
@@ -2518,14 +2519,13 @@ def _map_tuning_config( }, "TrainingJobEarlyStoppingType": early_stopping_type, } - - + if max_runtime_in_seconds is not None: tuning_config["ResourceLimits"]["MaxRuntimeInSeconds"] = max_runtime_in_seconds if random_seed is not None: tuning_config["RandomSeed"] = random_seed - + tuning_objective = cls._map_tuning_objective(objective_type, objective_metric_name) if tuning_objective is not None: tuning_config["HyperParameterTuningJobObjective"] = tuning_objective diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 6e7e75ffca..e7ce8be7e4 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -608,6 +608,7 @@ def __init__( completion_criteria_config: Optional[TuningJobCompletionCriteriaConfig] = None, early_stopping_type: Union[str, PipelineVariable] = "Off", estimator_name: Optional[str] = None, + random_seed: Optional[int] = None, ): """Creates a ``HyperparameterTuner`` instance. @@ -1448,7 +1449,7 @@ def _prepare_init_params_from_job_description(cls, job_details): "early_stopping_type": tuning_config["TrainingJobEarlyStoppingType"], "base_tuning_job_name": base_from_name(job_details["HyperParameterTuningJobName"]), } - + if "TuningJobCompletionCriteria" in tuning_config: params["completion_criteria_config"] = TuningJobCompletionCriteriaConfig.from_job_desc( tuning_config["TuningJobCompletionCriteria"] @@ -1461,7 +1462,7 @@ def _prepare_init_params_from_job_description(cls, job_details): if "RandomSeed" in tuning_config: params["random_seed"] = tuning_config["RandomSeed"] - + if "HyperParameterTuningJobObjective" in tuning_config: params["objective_metric_name"] = tuning_config["HyperParameterTuningJobObjective"][ "MetricName" @@ -2026,13 +2027,12 @@ def _get_tuner_args(cls, tuner, inputs): "early_stopping_type": tuner.early_stopping_type, } - if tuner.max_runtime_in_seconds is not None: tuning_config["max_runtime_in_seconds"] = tuner.max_runtime_in_seconds if tuner.random_seed is not None: tuning_config["random_seed"] = tuner.random_seed - + if tuner.strategy_config is not None: tuning_config["strategy_config"] = tuner.strategy_config.to_input_req() From dfaea77f152a0e6ecc489403046a500e105d0e86 Mon Sep 17 00:00:00 2001 From: JoseJuan98 Date: Sat, 4 Mar 2023 13:32:12 +0100 Subject: [PATCH 084/526] feature: feature group utils to facilitate development --- src/sagemaker/feature_group_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/feature_group_utils.py b/src/sagemaker/feature_group_utils.py index f37c70f648..4956336a21 100644 --- a/src/sagemaker/feature_group_utils.py +++ b/src/sagemaker/feature_group_utils.py @@ -29,13 +29,13 @@ logger = logging.getLogger(__name__) -def _get_session_from_role(region: str, role: str = None) -> Session: +def _get_session_from_role(region: str, assume_role: str = None) -> Session: """Method use to get the :class:`sagemaker.session.Session` from a role and a region. Helpful in case it's invoke from a session with a role without permission it can assume another role temporarily to perform certain tasks. 
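For users, the random_seed plumbing restored in session.py and tuner.py above surfaces roughly as follows; the estimator, metric definition and parameter range are placeholders, only random_seed is the point of the sketch.

from sagemaker.tuner import HyperparameterTuner, ContinuousParameter

# my_estimator, the objective metric and the range are illustrative only.
tuner = HyperparameterTuner(
    estimator=my_estimator,
    objective_metric_name="validation:auc",
    metric_definitions=[{"Name": "validation:auc", "Regex": "auc: ([0-9\\.]+)"}],
    hyperparameter_ranges={"eta": ContinuousParameter(0.1, 0.5)},
    max_jobs=10,
    max_parallel_jobs=2,
    random_seed=42,  # emitted as RandomSeed in the tuning job config for more repeatable search
)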
Args: - role: role name + assume_role: role name region: region name Returns: @@ -44,12 +44,12 @@ def _get_session_from_role(region: str, role: str = None) -> Session: boto_session = boto3.Session(region_name=region) # It will try to assume the role specified - if role: + if assume_role: sts = boto_session.client( "sts", region_name=region, endpoint_url="https://sts.eu-west-1.amazonaws.com" ) - metadata = sts.assume_role(RoleArn=role, RoleSessionName="SagemakerExecution") + metadata = sts.assume_role(RoleArn=assume_role, RoleSessionName="SagemakerExecution") access_key_id = metadata["Credentials"]["AccessKeyId"] secret_access_key = metadata["Credentials"]["SecretAccessKey"] @@ -80,7 +80,7 @@ def get_feature_group_as_dataframe( feature_group_name: str, athena_bucket: str, query: str = """SELECT * FROM "sagemaker_featurestore"."#{table}" - WHERE is_deleted=False """, + WHERE is_deleted=False """, role: str = None, region: str = None, session=None, From eface0f9a21abcae2f7174f61d492ea6cedc83d3 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Thu, 16 Jun 2022 11:08:03 -0700 Subject: [PATCH 085/526] fix: integs fallback from p3 to p2 instance (#3168) --- tests/conftest.py | 1 + tests/integ/__init__.py | 1 + tests/integ/test_horovod.py | 6 ++++-- tests/integ/test_horovod_mx.py | 6 ++++-- tests/integ/test_huggingface.py | 5 +++-- tests/integ/test_tf.py | 38 +++++++++++++++++++++++++++++---- 6 files changed, 47 insertions(+), 10 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5e57eeb2a3..8ccf443133 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,6 +46,7 @@ "ca-central-1", # it has p3, but not enough "eu-central-1", # it has p3, but not enough "eu-north-1", + "eu-west-1", # it has p3, but not enough "eu-west-2", # it has p3, but not enough "eu-west-3", "eu-south-1", diff --git a/tests/integ/__init__.py b/tests/integ/__init__.py index a4f0ac71b9..feca522eb5 100644 --- a/tests/integ/__init__.py +++ b/tests/integ/__init__.py @@ -74,6 +74,7 @@ "ca-central-1", # it has p3, but not enough "eu-central-1", # it has p3, but not enough "eu-north-1", + "eu-west-1", # it has p3, but not enough "eu-west-2", # it has p3, but not enough "eu-west-3", "eu-south-1", diff --git a/tests/integ/test_horovod.py b/tests/integ/test_horovod.py index 1615d3998e..7be3bc1abd 100644 --- a/tests/integ/test_horovod.py +++ b/tests/integ/test_horovod.py @@ -22,6 +22,7 @@ import sagemaker.utils import tests.integ as integ +from tests.integ.utils import gpu_list, retry_with_instance_list from sagemaker.tensorflow import TensorFlow from tests.integ import timeout @@ -51,18 +52,19 @@ def test_hvd_cpu( and integ.test_region() in integ.TRAINING_NO_P3_REGIONS, reason="no ml.p2 or ml.p3 instances in this region", ) +@retry_with_instance_list(gpu_list(integ.test_region())) def test_hvd_gpu( sagemaker_session, tensorflow_training_latest_version, tensorflow_training_latest_py_version, - gpu_instance_type, tmpdir, + **kwargs, ): _create_and_fit_estimator( sagemaker_session, tensorflow_training_latest_version, tensorflow_training_latest_py_version, - gpu_instance_type, + kwargs["instance_type"], tmpdir, ) diff --git a/tests/integ/test_horovod_mx.py b/tests/integ/test_horovod_mx.py index 1272690e1b..eba48b2f8d 100644 --- a/tests/integ/test_horovod_mx.py +++ b/tests/integ/test_horovod_mx.py @@ -24,6 +24,7 @@ import tests.integ as integ from sagemaker.mxnet import MXNet from tests.integ import timeout +from tests.integ.utils import gpu_list, 
retry_with_instance_list horovod_dir = os.path.join(os.path.dirname(__file__), "..", "data", "horovod") @@ -51,18 +52,19 @@ def test_hvd_cpu( and integ.test_region() in integ.TRAINING_NO_P3_REGIONS, reason="no ml.p2 or ml.p3 instances in this region", ) +@retry_with_instance_list(gpu_list(integ.test_region())) def test_hvd_gpu( mxnet_training_latest_version, mxnet_training_latest_py_version, sagemaker_session, - gpu_instance_type, tmpdir, + **kwargs, ): _create_and_fit_estimator( mxnet_training_latest_version, mxnet_training_latest_py_version, sagemaker_session, - gpu_instance_type, + kwargs["instance_type"], tmpdir, ) diff --git a/tests/integ/test_huggingface.py b/tests/integ/test_huggingface.py index 3d52ca44ea..5e9151fe86 100644 --- a/tests/integ/test_huggingface.py +++ b/tests/integ/test_huggingface.py @@ -69,12 +69,13 @@ def test_framework_processing_job_with_deps( and integ.test_region() in integ.TRAINING_NO_P3_REGIONS, reason="no ml.p2 or ml.p3 instances in this region", ) +@retry_with_instance_list(gpu_list(integ.test_region())) def test_huggingface_training( sagemaker_session, - gpu_instance_type, huggingface_training_latest_version, huggingface_training_pytorch_latest_version, huggingface_pytorch_latest_training_py_version, + **kwargs, ): with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): data_path = os.path.join(DATA_DIR, "huggingface") @@ -86,7 +87,7 @@ def test_huggingface_training( transformers_version=huggingface_training_latest_version, pytorch_version=huggingface_training_pytorch_latest_version, instance_count=1, - instance_type=gpu_instance_type, + instance_type=kwargs["instance_type"], hyperparameters={ "model_name_or_path": "distilbert-base-cased", "task_name": "wnli", diff --git a/tests/integ/test_tf.py b/tests/integ/test_tf.py index 30b9940e5d..86ac20e9bf 100644 --- a/tests/integ/test_tf.py +++ b/tests/integ/test_tf.py @@ -182,12 +182,42 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v @pytest.mark.release -def test_mnist_distributed( +def test_mnist_distributed_cpu( sagemaker_session, - instance_type, + cpu_instance_type, tensorflow_training_latest_version, tensorflow_training_latest_py_version, ): + _create_and_fit_estimator( + sagemaker_session, + tensorflow_training_latest_version, + tensorflow_training_latest_py_version, + cpu_instance_type, + ) + + +@pytest.mark.release +@pytest.mark.skipif( + tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS + and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS, + reason="no ml.p2 or ml.p3 instances in this region", +) +@retry_with_instance_list(gpu_list(tests.integ.test_region())) +def test_mnist_distributed_gpu( + sagemaker_session, + tensorflow_training_latest_version, + tensorflow_training_latest_py_version, + **kwargs, +): + _create_and_fit_estimator( + sagemaker_session, + tensorflow_training_latest_version, + tensorflow_training_latest_py_version, + kwargs["instance_type"], + ) + + +def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instance_type): estimator = TensorFlow( entry_point=SCRIPT, source_dir=MNIST_RESOURCE_PATH, @@ -195,8 +225,8 @@ def test_mnist_distributed( instance_count=2, instance_type=instance_type, sagemaker_session=sagemaker_session, - framework_version=tensorflow_training_latest_version, - py_version=tensorflow_training_latest_py_version, + framework_version=tf_version, + py_version=py_version, distribution=PARAMETER_SERVER_DISTRIBUTION, disable_profiler=True, ) From 8da818a2444ba8ebed5016275cbb3d42d59bac6a Mon 
Sep 17 00:00:00 2001 From: ci Date: Thu, 16 Jun 2022 18:37:43 +0000 Subject: [PATCH 086/526] prepare release v2.95.0 --- CHANGELOG.md | 18 ++++++++++++++++++ VERSION | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 976b95c0ea..fc2356c3f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Changelog +## v2.95.0 (2022-06-16) + +### Features + + * Adding Training Compiler support for TensorFlow estimator starting TF 2.9 + * Add support for TF 2.9 training + +### Bug Fixes and Other Changes + + * integs fallback from p3 to p2 instance + * bucket exists check for session.default_bucket + * make instance type fields as optional + +### Documentation Changes + + * improvements on the docstring of ModelStep + * Add XGBoostProcessor + ## v2.94.0 (2022-06-07) ### Features diff --git a/VERSION b/VERSION index 237ee75884..f3071c00ff 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.94.1.dev0 +2.95.0 From 1be7024b330862ede0ace701037f1283b00eebb8 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 16 Jun 2022 18:37:44 +0000 Subject: [PATCH 087/526] update development version to v2.95.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index f3071c00ff..b5d39344d2 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.95.0 +2.95.1.dev0 From 946009f10171c6f872f664ed3641e2c379e615e0 Mon Sep 17 00:00:00 2001 From: Namrata Madan Date: Fri, 17 Jun 2022 11:25:22 -0700 Subject: [PATCH 088/526] feature: Add helper method to generate pipeline adjacency list (#3128) Co-authored-by: Namrata Madan --- src/sagemaker/workflow/_utils.py | 2 +- src/sagemaker/workflow/callback_step.py | 5 +- src/sagemaker/workflow/clarify_check_step.py | 7 +- src/sagemaker/workflow/condition_step.py | 15 +- src/sagemaker/workflow/conditions.py | 42 +++ src/sagemaker/workflow/emr_step.py | 2 +- src/sagemaker/workflow/entities.py | 5 + src/sagemaker/workflow/execution_variables.py | 6 + src/sagemaker/workflow/functions.py | 14 + src/sagemaker/workflow/lambda_step.py | 5 +- src/sagemaker/workflow/parameters.py | 5 + src/sagemaker/workflow/pipeline.py | 111 ++++++ src/sagemaker/workflow/properties.py | 68 ++-- src/sagemaker/workflow/quality_check_step.py | 11 +- src/sagemaker/workflow/step_collections.py | 3 +- src/sagemaker/workflow/steps.py | 92 ++++- tests/unit/sagemaker/workflow/helpers.py | 22 +- .../sagemaker/workflow/test_callback_step.py | 12 +- .../sagemaker/workflow/test_condition_step.py | 42 ++- .../sagemaker/workflow/test_conditions.py | 2 +- .../unit/sagemaker/workflow/test_emr_step.py | 12 +- .../unit/sagemaker/workflow/test_entities.py | 2 +- .../unit/sagemaker/workflow/test_fail_step.py | 5 +- .../unit/sagemaker/workflow/test_functions.py | 2 +- .../sagemaker/workflow/test_lambda_step.py | 20 +- .../sagemaker/workflow/test_model_step.py | 88 ++++- .../unit/sagemaker/workflow/test_pipeline.py | 41 +-- .../sagemaker/workflow/test_pipeline_graph.py | 342 ++++++++++++++++++ .../workflow/test_processing_step.py | 22 +- .../sagemaker/workflow/test_properties.py | 18 +- tests/unit/sagemaker/workflow/test_steps.py | 26 +- .../sagemaker/workflow/test_training_step.py | 8 +- .../sagemaker/workflow/test_transform_step.py | 4 +- .../sagemaker/workflow/test_tuning_step.py | 7 +- tests/unit/sagemaker/workflow/test_utils.py | 4 +- 35 files changed, 925 insertions(+), 147 deletions(-) create mode 100644 tests/unit/sagemaker/workflow/test_pipeline_graph.py diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py 
index 8a0523c73a..7b8a3cdc25 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -369,7 +369,7 @@ def __init__( self.kwargs = kwargs self.container_def_list = container_def_list - self._properties = Properties(path=f"Steps.{name}", shape_name="DescribeModelPackageOutput") + self._properties = Properties(step_name=name, shape_name="DescribeModelPackageOutput") @property def arguments(self) -> RequestType: diff --git a/src/sagemaker/workflow/callback_step.py b/src/sagemaker/workflow/callback_step.py index cd0b63f433..03903ef908 100644 --- a/src/sagemaker/workflow/callback_step.py +++ b/src/sagemaker/workflow/callback_step.py @@ -112,13 +112,12 @@ def __init__( self.cache_config = cache_config self.inputs = inputs - root_path = f"Steps.{name}" - root_prop = Properties(path=root_path) + root_prop = Properties(step_name=name) property_dict = {} for output in outputs: property_dict[output.output_name] = Properties( - f"{root_path}.OutputParameters['{output.output_name}']" + step_name=name, path=f"OutputParameters['{output.output_name}']" ) root_prop.__dict__["Outputs"] = property_dict diff --git a/src/sagemaker/workflow/clarify_check_step.py b/src/sagemaker/workflow/clarify_check_step.py index defb687e1f..f5c1193be8 100644 --- a/src/sagemaker/workflow/clarify_check_step.py +++ b/src/sagemaker/workflow/clarify_check_step.py @@ -236,13 +236,12 @@ def __init__( self._generate_processing_job_analysis_config(), self._baselining_processor ) - root_path = f"Steps.{name}" - root_prop = Properties(path=root_path) + root_prop = Properties(step_name=name) root_prop.__dict__["CalculatedBaselineConstraints"] = Properties( - f"{root_path}.CalculatedBaselineConstraints" + step_name=name, path="CalculatedBaselineConstraints" ) root_prop.__dict__["BaselineUsedForDriftCheckConstraints"] = Properties( - f"{root_path}.BaselineUsedForDriftCheckConstraints" + step_name=name, path="BaselineUsedForDriftCheckConstraints" ) self._properties = root_prop diff --git a/src/sagemaker/workflow/condition_step.py b/src/sagemaker/workflow/condition_step.py index 1bac8353c0..e5797b5b63 100644 --- a/src/sagemaker/workflow/condition_step.py +++ b/src/sagemaker/workflow/condition_step.py @@ -77,9 +77,8 @@ def __init__( self.if_steps = if_steps or [] self.else_steps = else_steps or [] - root_path = f"Steps.{name}" - root_prop = Properties(path=root_path) - root_prop.__dict__["Outcome"] = Properties(f"{root_path}.Outcome") + root_prop = Properties(step_name=name) + root_prop.__dict__["Outcome"] = Properties(step_name=name, path="Outcome") self._properties = root_prop @property @@ -91,6 +90,11 @@ def arguments(self) -> RequestType: ElseSteps=list_to_request(self.else_steps), ) + @property + def step_only_arguments(self): + """Argument dict pertaining to the step only, and not the `if_steps` or `else_steps`.""" + return self.conditions + @property def properties(self): """A simple Properties object with `Outcome` as the only property""" @@ -126,5 +130,10 @@ def expr(self): } } + @property + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + return [self.step.name] + JsonGet = deprecated_class(JsonGet, "JsonGet") diff --git a/src/sagemaker/workflow/conditions.py b/src/sagemaker/workflow/conditions.py index df818f743b..67a3ea5396 100644 --- a/src/sagemaker/workflow/conditions.py +++ b/src/sagemaker/workflow/conditions.py @@ -17,6 +17,8 @@ """ from __future__ import absolute_import +import abc + from enum import Enum from typing import Dict, List, Union 
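As a concrete illustration of the Properties(step_name=..., path=...) form adopted in callback_step.py above, a callback step's outputs might be referenced like this; the queue URL, step name and output name are placeholders.

from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum

# Queue URL and names are illustrative placeholders.
output = CallbackOutput(output_name="result", output_type=CallbackOutputTypeEnum.String)
step = CallbackStep(
    name="MyCallback",
    sqs_queue_url="https://sqs.eu-west-1.amazonaws.com/111122223333/example-queue",
    inputs={"key": "value"},
    outputs=[output],
)
print(step.properties.Outputs["result"].expr)
# {'Get': "Steps.MyCallback.OutputParameters['result']"}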
@@ -33,6 +35,7 @@ from sagemaker.workflow.execution_variables import ExecutionVariable from sagemaker.workflow.parameters import Parameter from sagemaker.workflow.properties import Properties +from sagemaker.workflow.entities import PipelineVariable # TODO: consider base class for those with an expr method, rather than defining a type here ConditionValueType = Union[ExecutionVariable, Parameter, Properties] @@ -61,6 +64,11 @@ class Condition(Entity): condition_type: ConditionTypeEnum = attr.ib(factory=ConditionTypeEnum.factory) + @property + @abc.abstractmethod + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + @attr.s class ConditionComparison(Condition): @@ -84,6 +92,16 @@ def to_request(self) -> RequestType: "RightValue": primitive_or_expr(self.right), } + @property + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + steps = [] + if isinstance(self.left, PipelineVariable): + steps.extend(self.left._referenced_steps) + if isinstance(self.right, PipelineVariable): + steps.extend(self.right._referenced_steps) + return steps + class ConditionEquals(ConditionComparison): """A condition for equality comparisons.""" @@ -213,6 +231,17 @@ def to_request(self) -> RequestType: "Values": [primitive_or_expr(in_value) for in_value in self.in_values], } + @property + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + steps = [] + if isinstance(self.value, PipelineVariable): + steps.extend(self.value._referenced_steps) + for in_value in self.in_values: + if isinstance(in_value, PipelineVariable): + steps.extend(in_value._referenced_steps) + return steps + class ConditionNot(Condition): """A condition for negating another `Condition`.""" @@ -230,6 +259,11 @@ def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" return {"Type": self.condition_type.value, "Expression": self.expression.to_request()} + @property + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + return self.expression._referenced_steps + class ConditionOr(Condition): """A condition for taking the logical OR of a list of `Condition` instances.""" @@ -250,6 +284,14 @@ def to_request(self) -> RequestType: "Conditions": [condition.to_request() for condition in self.conditions], } + @property + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + steps = [] + for condition in self.conditions: + steps.extend(condition._referenced_steps) + return steps + def primitive_or_expr( value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties] diff --git a/src/sagemaker/workflow/emr_step.py b/src/sagemaker/workflow/emr_step.py index 6f30f92640..e7e154e8b1 100644 --- a/src/sagemaker/workflow/emr_step.py +++ b/src/sagemaker/workflow/emr_step.py @@ -94,7 +94,7 @@ def __init__( self.args = emr_step_args self.cache_config = cache_config - root_property = Properties(path=f"Steps.{name}", shape_name="Step", service_name="emr") + root_property = Properties(step_name=name, shape_name="Step", service_name="emr") root_property.__dict__["ClusterId"] = cluster_id self._properties = root_property diff --git a/src/sagemaker/workflow/entities.py b/src/sagemaker/workflow/entities.py index 7a984625bd..5272cfded6 100644 --- a/src/sagemaker/workflow/entities.py +++ b/src/sagemaker/workflow/entities.py @@ -102,3 +102,8 @@ def to_string(self): 
@abc.abstractmethod def expr(self) -> RequestType: """Get the expression structure for workflow service calls.""" + + @property + @abc.abstractmethod + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" diff --git a/src/sagemaker/workflow/execution_variables.py b/src/sagemaker/workflow/execution_variables.py index 22474c8856..516efb784e 100644 --- a/src/sagemaker/workflow/execution_variables.py +++ b/src/sagemaker/workflow/execution_variables.py @@ -13,6 +13,7 @@ """Pipeline parameters and conditions for workflow.""" from __future__ import absolute_import +from typing import List from sagemaker.workflow.entities import ( RequestType, PipelineVariable, @@ -42,6 +43,11 @@ def expr(self) -> RequestType: """The 'Get' expression dict for an `ExecutionVariable`.""" return {"Get": f"Execution.{self.name}"} + @property + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + return [] + class ExecutionVariables: """Provide access to all available execution variables: diff --git a/src/sagemaker/workflow/functions.py b/src/sagemaker/workflow/functions.py index 36bd69fbff..53bcc5cc78 100644 --- a/src/sagemaker/workflow/functions.py +++ b/src/sagemaker/workflow/functions.py @@ -64,6 +64,15 @@ def expr(self): }, } + @property + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + steps = [] + for value in self.values: + if isinstance(value, PipelineVariable): + steps.extend(value._referenced_steps) + return steps + @attr.s class JsonGet(PipelineVariable): @@ -96,3 +105,8 @@ def expr(self): "Path": self.json_path, } } + + @property + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + return [self.step_name] diff --git a/src/sagemaker/workflow/lambda_step.py b/src/sagemaker/workflow/lambda_step.py index e9a5e98dc1..8161827a06 100644 --- a/src/sagemaker/workflow/lambda_step.py +++ b/src/sagemaker/workflow/lambda_step.py @@ -115,13 +115,12 @@ def __init__( self.cache_config = cache_config self.inputs = inputs if inputs is not None else {} - root_path = f"Steps.{name}" - root_prop = Properties(path=root_path) + root_prop = Properties(step_name=name) property_dict = {} for output in self.outputs: property_dict[output.output_name] = Properties( - f"{root_path}.OutputParameters['{output.output_name}']" + step_name=name, path=f"OutputParameters['{output.output_name}']" ) root_prop.__dict__["Outputs"] = property_dict diff --git a/src/sagemaker/workflow/parameters.py b/src/sagemaker/workflow/parameters.py index 875e2e50ff..3125eeb7c9 100644 --- a/src/sagemaker/workflow/parameters.py +++ b/src/sagemaker/workflow/parameters.py @@ -90,6 +90,11 @@ def expr(self) -> Dict[str, str]: """The 'Get' expression dict for a `Parameter`.""" return Parameter._expr(self.name) + @property + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + return [] + @classmethod def _expr(cls, name): """An internal classmethod for the 'Get' expression dict for a `Parameter`. 
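The _referenced_steps hooks added to conditions, functions and the other pipeline variables above are what the adjacency-list builder in pipeline.py (below) uses to infer data dependencies. A hedged illustration with invented step and property-file names:

from sagemaker.workflow.functions import JsonGet
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo

# "EvalStep" and "EvaluationReport" are invented names.
auc = JsonGet(step_name="EvalStep", property_file="EvaluationReport", json_path="metrics.auc")
cond = ConditionGreaterThanOrEqualTo(left=auc, right=0.8)

# Both report the step they read from, so any step or condition that uses them
# is treated as depending on EvalStep when the pipeline DAG is built.
assert auc._referenced_steps == ["EvalStep"]
assert cond._referenced_steps == ["EvalStep"]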
diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 64538b0dcc..f560945752 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -39,6 +39,7 @@ from sagemaker.workflow.properties import Properties from sagemaker.workflow.steps import Step from sagemaker.workflow.step_collections import StepCollection +from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.utilities import list_to_request @@ -534,3 +535,113 @@ def wait(self, delay=30, max_attempts=60): waiter_id, model, self.sagemaker_session.sagemaker_client ) waiter.wait(PipelineExecutionArn=self.arn) + + +class PipelineGraph: + """Helper class representing the Pipeline Directed Acyclic Graph (DAG) + + Attributes: + steps (Sequence[Union[Step, StepCollection]]): Sequence of `Step`s and/or `StepCollection`s + that represent each node in the pipeline DAG + """ + + def __init__(self, steps: Sequence[Union[Step, StepCollection]]): + self.step_map = {} + self._generate_step_map(steps) + self.adjacency_list = self._initialize_adjacency_list() + if self.is_cyclic(): + raise ValueError("Cycle detected in pipeline step graph.") + + def _generate_step_map(self, steps: Sequence[Union[Step, StepCollection]]): + """Helper method to create a mapping from Step/Step Collection name to itself.""" + for step in steps: + if step.name in self.step_map: + raise ValueError("Pipeline steps cannot have duplicate names.") + self.step_map[step.name] = step + if isinstance(step, ConditionStep): + self._generate_step_map(step.if_steps + step.else_steps) + if isinstance(step, StepCollection): + self._generate_step_map(step.steps) + + @classmethod + def from_pipeline(cls, pipeline: Pipeline): + """Create a PipelineGraph object from the Pipeline object.""" + return cls(pipeline.steps) + + def _initialize_adjacency_list(self) -> Dict[str, List[str]]: + """Generate an adjacency list representing the step dependency DAG in this pipeline.""" + from collections import defaultdict + + dependency_list = defaultdict(set) + for step in self.step_map.values(): + if isinstance(step, Step): + dependency_list[step.name].update(step._find_step_dependencies(self.step_map)) + + if isinstance(step, ConditionStep): + for child_step in step.if_steps + step.else_steps: + if isinstance(child_step, Step): + dependency_list[child_step.name].add(step.name) + elif isinstance(child_step, StepCollection): + child_first_step = self.step_map[child_step.name].steps[0].name + dependency_list[child_first_step].add(step.name) + + adjacency_list = {} + for step in dependency_list: + for step_dependency in dependency_list[step]: + adjacency_list[step_dependency] = list( + set(adjacency_list.get(step_dependency, []) + [step]) + ) + for step in dependency_list: + if step not in adjacency_list: + adjacency_list[step] = [] + return adjacency_list + + def is_cyclic(self) -> bool: + """Check if this pipeline graph is cyclic. + + Returns true if it is cyclic, false otherwise. 
+ """ + + def is_cyclic_helper(current_step): + visited_steps.add(current_step) + recurse_steps.add(current_step) + for child_step in self.adjacency_list[current_step]: + if child_step in recurse_steps: + return True + if child_step not in visited_steps: + if is_cyclic_helper(child_step): + return True + recurse_steps.remove(current_step) + return False + + visited_steps = set() + recurse_steps = set() + for step in self.adjacency_list: + if step not in visited_steps: + if is_cyclic_helper(step): + return True + return False + + def __iter__(self): + """Perform topological sort traversal of the Pipeline Graph.""" + + def topological_sort(current_step): + visited_steps.add(current_step) + for child_step in self.adjacency_list[current_step]: + if child_step not in visited_steps: + topological_sort(child_step) + self.stack.append(current_step) + + visited_steps = set() + self.stack = [] # pylint: disable=W0201 + for step in self.adjacency_list: + if step not in visited_steps: + topological_sort(step) + return self + + def __next__(self) -> Step: + """Return the next Step node from the Topological sort order.""" + + while self.stack: + return self.step_map.get(self.stack.pop()) + raise StopIteration diff --git a/src/sagemaker/workflow/properties.py b/src/sagemaker/workflow/properties.py index 480fddada1..41f3c98c5b 100644 --- a/src/sagemaker/workflow/properties.py +++ b/src/sagemaker/workflow/properties.py @@ -50,7 +50,8 @@ class Properties(PipelineVariable, metaclass=PropertiesMeta): def __init__( self, - path: str, + step_name: str, + path: str = None, shape_name: str = None, shape_names: List[str] = None, service_name: str = "sagemaker", @@ -58,11 +59,14 @@ def __init__( """Create a Properties instance representing the given shape. Args: - path (str): The parent path of the Properties instance. + step_name (str): The name of the Step this Property belongs to. + path (str): The relative path of this Property value. shape_name (str): The botocore service model shape name. shape_names (str): A List of the botocore service model shape name. 
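A small sketch of how the PipelineGraph helper added above might be used once a pipeline's steps are defined; step_a and step_b stand in for real Step objects and are not part of this changeset.

from sagemaker.workflow.pipeline import PipelineGraph

# step_a and step_b are placeholders for previously defined pipeline steps;
# PipelineGraph.from_pipeline(pipeline) builds the same graph from a Pipeline object.
graph = PipelineGraph([step_a, step_b])

print(graph.adjacency_list)  # e.g. {"StepA": ["StepB"], "StepB": []}
for step in graph:           # iteration yields steps in topological (dependency) order
    print(step.name)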
""" - self._path = path + self.step_name = step_name + self.path = path + shape_names = [] if shape_names is None else shape_names self._shape_names = shape_names if shape_name is None else [shape_name] + shape_names @@ -78,35 +82,54 @@ def __init__( for key, info in members.items(): if shapes.get(info["shape"], {}).get("type") == "list": self.__dict__[key] = PropertiesList( - f"{path}.{key}", info["shape"], service_name + step_name=step_name, + path=".".join(filter(None, (path, key))), + shape_name=info["shape"], + service_name=service_name, ) elif shapes.get(info["shape"], {}).get("type") == "map": self.__dict__[key] = PropertiesMap( - f"{path}.{key}", info["shape"], service_name + step_name=step_name, + path=".".join(filter(None, (path, key))), + shape_name=info["shape"], + service_name=service_name, ) else: self.__dict__[key] = Properties( - f"{path}.{key}", info["shape"], service_name=service_name + step_name=step_name, + path=".".join(filter(None, (path, key))), + shape_name=info["shape"], + service_name=service_name, ) @property def expr(self): """The 'Get' expression dict for a `Properties`.""" - return {"Get": self._path} + prefix = f"Steps.{self.step_name}" + full_path = prefix if self.path is None else f"{prefix}.{self.path}" + return {"Get": full_path} + + @property + def _referenced_steps(self) -> List[str]: + """List of step names that this function depends on.""" + return [self.step_name] class PropertiesList(Properties): """PropertiesList for use in workflow expressions.""" - def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"): + def __init__( + self, step_name: str, path: str, shape_name: str = None, service_name: str = "sagemaker" + ): """Create a PropertiesList instance representing the given shape. Args: - path (str): The parent path of the PropertiesList instance. + step_name (str): The name of the Step this Property belongs to. + path (str): The relative path of this Property value. shape_name (str): The botocore service model shape name. service_name (str): The botocore service name. """ - super(PropertiesList, self).__init__(path, shape_name) + super(PropertiesList, self).__init__(step_name, path, shape_name) self.shape_name = shape_name self.service_name = service_name self._items: Dict[Union[int, str], Properties] = dict() @@ -121,9 +144,9 @@ def __getitem__(self, item: Union[int, str]): shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name) member = shape["member"]["shape"] if isinstance(item, str): - property_item = Properties(f"{self._path}['{item}']", member) + property_item = Properties(self.step_name, f"{self.path}['{item}']", member) else: - property_item = Properties(f"{self._path}[{item}]", member) + property_item = Properties(self.step_name, f"{self.path}[{item}]", member) self._items[item] = property_item return self._items.get(item) @@ -132,15 +155,18 @@ def __getitem__(self, item: Union[int, str]): class PropertiesMap(Properties): """PropertiesMap for use in workflow expressions.""" - def __init__(self, path: str, shape_name: str = None, service_name: str = "sagemaker"): + def __init__( + self, step_name: str, path: str, shape_name: str = None, service_name: str = "sagemaker" + ): """Create a PropertiesMap instance representing the given shape. Args: - path (str): The parent path of the PropertiesMap instance. - shape_name (str): The botocore sagemaker service model shape name. + step_name (str): The name of the Step this Property belongs to. 
+ path (str): The relative path of this Property value. + shape_name (str): The botocore service model shape name. service_name (str): The botocore service name. """ - super(PropertiesMap, self).__init__(path, shape_name) + super(PropertiesMap, self).__init__(step_name, path, shape_name) self.shape_name = shape_name self.service_name = service_name self._items: Dict[Union[int, str], Properties] = dict() @@ -155,9 +181,9 @@ def __getitem__(self, item: Union[int, str]): shape = Properties._shapes_map.get(self.service_name, {}).get(self.shape_name) member = shape["value"]["shape"] if isinstance(item, str): - property_item = Properties(f"{self._path}['{item}']", member) + property_item = Properties(self.step_name, f"{self.path}['{item}']", member) else: - property_item = Properties(f"{self._path}[{item}]", member) + property_item = Properties(self.step_name, f"{self.path}[{item}]", member) self._items[item] = property_item return self._items.get(item) @@ -168,9 +194,9 @@ class PropertyFile(Expression): """Provides a property file struct. Attributes: - name: The name of the property file for reference with `JsonGet` functions. - output_name: The name of the processing job output channel. - path: The path to the file at the output channel location. + name (str): The name of the property file for reference with `JsonGet` functions. + output_name (str): The name of the processing job output channel. + path (str): The path to the file at the output channel location. """ name: str = attr.ib() diff --git a/src/sagemaker/workflow/quality_check_step.py b/src/sagemaker/workflow/quality_check_step.py index 092d60434a..d9d3ea2bef 100644 --- a/src/sagemaker/workflow/quality_check_step.py +++ b/src/sagemaker/workflow/quality_check_step.py @@ -203,19 +203,18 @@ def __init__( ], ) - root_path = f"Steps.{name}" - root_prop = Properties(path=root_path) + root_prop = Properties(step_name=name) root_prop.__dict__["CalculatedBaselineConstraints"] = Properties( - f"{root_path}.CalculatedBaselineConstraints" + step_name=name, path="CalculatedBaselineConstraints" ) root_prop.__dict__["CalculatedBaselineStatistics"] = Properties( - f"{root_path}.CalculatedBaselineStatistics" + step_name=name, path="CalculatedBaselineStatistics" ) root_prop.__dict__["BaselineUsedForDriftCheckStatistics"] = Properties( - f"{root_path}.BaselineUsedForDriftCheckStatistics" + step_name=name, path="BaselineUsedForDriftCheckStatistics" ) root_prop.__dict__["BaselineUsedForDriftCheckConstraints"] = Properties( - f"{root_path}.BaselineUsedForDriftCheckConstraints" + step_name=name, path="BaselineUsedForDriftCheckConstraints" ) self._properties = root_prop diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index bc7deb4fa3..d52ddace87 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -50,8 +50,7 @@ def properties(self): """The properties of the particular `StepCollection`.""" if not self.steps: return None - size = len(self.steps) - return self.steps[size - 1].properties + return self.steps[-1].properties class RegisterModel(StepCollection): # pragma: no cover diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 01c5e6d18d..45d38fe26d 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -17,7 +17,7 @@ import warnings from enum import Enum -from typing import Dict, List, Union, Optional, TYPE_CHECKING +from typing import Dict, List, Set, Union, Optional, Any, TYPE_CHECKING 
from urllib.parse import urlparse import attr @@ -35,6 +35,7 @@ ) from sagemaker.transformer import Transformer, _TransformJob from sagemaker.tuner import HyperparameterTuner, _TuningJob +from sagemaker.workflow.conditions import Condition from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import ( DefaultEnumMeta, @@ -46,6 +47,7 @@ PropertyFile, Properties, ) +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.functions import Join from sagemaker.workflow.retry import RetryPolicy @@ -96,6 +98,15 @@ class Step(Entity): def arguments(self) -> RequestType: """The arguments to the particular `Step` service call.""" + @property + def step_only_arguments(self) -> RequestType: + """The arguments to this Step only. + + Compound Steps such as the ConditionStep will have to + override this method to return arguments pertaining to only that step. + """ + return self.arguments + @property @abc.abstractmethod def properties(self): @@ -148,9 +159,70 @@ def _resolve_depends_on( elif isinstance(step, str): depends_on.append(step) else: - raise ValueError(f"Invalid input step name: {step}") + raise ValueError(f"Invalid input step type: {type(step)}") return depends_on + def _find_step_dependencies( + self, step_map: Dict[str, Union["Step", "StepCollection"]] + ) -> List[str]: + """Find the all step names this step is dependent on.""" + step_dependencies = set() + if self.depends_on: + step_dependencies.update(self._find_dependencies_in_depends_on_list(step_map)) + step_dependencies.update( + self._find_dependencies_in_step_arguments(self.step_only_arguments, step_map) + ) + return list(step_dependencies) + + def _find_dependencies_in_depends_on_list( + self, step_map: Dict[str, Union["Step", "StepCollection"]] + ) -> Set[str]: + """Find dependency steps referenced in the depends-on field of this step.""" + # import here to prevent circular import + from sagemaker.workflow.step_collections import StepCollection + + dependencies = set() + for step in self.depends_on: + if isinstance(step, Step): + dependencies.add(step.name) + elif isinstance(step, StepCollection): + dependencies.add(step.steps[-1].name) + elif isinstance(step, str): + # step could be the name of a `Step` or a `StepCollection` + dependencies.add(self._get_step_name_from_str(step, step_map)) + return dependencies + + @staticmethod + def _find_dependencies_in_step_arguments( + obj: Any, step_map: Dict[str, Union["Step", "StepCollection"]] + ): + """Find the step dependencies referenced in the arguments of this step.""" + dependencies = set() + if isinstance(obj, dict): + for value in obj.values(): + if isinstance(value, (PipelineVariable, Condition)): + for referenced_step in value._referenced_steps: + dependencies.add(Step._get_step_name_from_str(referenced_step, step_map)) + dependencies.update(Step._find_dependencies_in_step_arguments(value, step_map)) + elif isinstance(obj, list): + for item in obj: + if isinstance(item, (PipelineVariable, Condition)): + for referenced_step in item._referenced_steps: + dependencies.add(Step._get_step_name_from_str(referenced_step, step_map)) + dependencies.update(Step._find_dependencies_in_step_arguments(item, step_map)) + return dependencies + + @staticmethod + def _get_step_name_from_str( + str_input: str, step_map: Dict[str, Union["Step", "StepCollection"]] + ) -> str: + """Convert a Step or StepCollection name input to step name.""" + from sagemaker.workflow.step_collections import StepCollection + + if isinstance(step_map[str_input], 
StepCollection): + return step_map[str_input].steps[-1].name + return str_input + @attr.s class CacheConfig: @@ -302,9 +374,7 @@ def __init__( self.estimator = estimator self.inputs = inputs - self._properties = Properties( - path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse" - ) + self._properties = Properties(step_name=name, shape_name="DescribeTrainingJobResponse") self.cache_config = cache_config if self.cache_config: @@ -442,7 +512,7 @@ def __init__( self.model = model self.inputs = inputs or CreateModelInput() - self._properties = Properties(path=f"Steps.{name}", shape_name="DescribeModelOutput") + self._properties = Properties(step_name=name, shape_name="DescribeModelOutput") warnings.warn( ( @@ -549,9 +619,7 @@ def __init__( self.transformer = transformer self.inputs = inputs self.cache_config = cache_config - self._properties = Properties( - path=f"Steps.{name}", shape_name="DescribeTransformJobResponse" - ) + self._properties = Properties(step_name=name, shape_name="DescribeTransformJobResponse") if not self.step_args: if inputs is None: @@ -684,9 +752,7 @@ def __init__( self.job_name = None self.kms_key = kms_key self.cache_config = cache_config - self._properties = Properties( - path=f"Steps.{name}", shape_name="DescribeProcessingJobResponse" - ) + self._properties = Properties(step_name=name, shape_name="DescribeProcessingJobResponse") if not self.step_args: # Examine why run method in `sagemaker.processing.Processor` @@ -852,7 +918,7 @@ def __init__( self.inputs = inputs self.job_arguments = job_arguments self._properties = Properties( - path=f"Steps.{name}", + step_name=name, shape_names=[ "DescribeHyperParameterTuningJobResponse", "ListTrainingJobsForHyperParameterTuningJobResponse", diff --git a/tests/unit/sagemaker/workflow/helpers.py b/tests/unit/sagemaker/workflow/helpers.py index 67405e9372..ebc3bbd959 100644 --- a/tests/unit/sagemaker/workflow/helpers.py +++ b/tests/unit/sagemaker/workflow/helpers.py @@ -15,6 +15,7 @@ from sagemaker.workflow.properties import Properties from sagemaker.workflow.steps import Step, StepTypeEnum +from sagemaker.workflow.step_collections import StepCollection def ordered(obj): @@ -37,19 +38,32 @@ def ordered(obj): class CustomStep(Step): - def __init__(self, name, display_name=None, description=None, depends_on=None): + def __init__(self, name, input_data=None, display_name=None, description=None, depends_on=None): + self.input_data = input_data super(CustomStep, self).__init__( name, display_name, description, StepTypeEnum.TRAINING, depends_on ) # for testing property reference, we just use DescribeTrainingJobResponse shape here. 
- self._properties = Properties( - path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse" - ) + self._properties = Properties(name, shape_name="DescribeTrainingJobResponse") @property def arguments(self): + if self.input_data: + return {"input_data": self.input_data} return dict() @property def properties(self): return self._properties + + +class CustomStepCollection(StepCollection): + def __init__(self, name, num_steps=2, depends_on=None): + steps = [] + previous_step = None + for i in range(num_steps): + step_depends_on = depends_on if not previous_step else [previous_step] + step = CustomStep(name=f"{name}-{i}", depends_on=step_depends_on) + steps.append(step) + previous_step = step + super(CustomStepCollection, self).__init__(name, steps) diff --git a/tests/unit/sagemaker/workflow/test_callback_step.py b/tests/unit/sagemaker/workflow/test_callback_step.py index fda814a786..abf335577b 100644 --- a/tests/unit/sagemaker/workflow/test_callback_step.py +++ b/tests/unit/sagemaker/workflow/test_callback_step.py @@ -19,9 +19,9 @@ from mock import Mock from sagemaker.workflow.parameters import ParameterInteger, ParameterString -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum -from tests.unit.sagemaker.workflow.helpers import CustomStep +from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered @pytest.fixture @@ -156,3 +156,11 @@ def test_pipeline_interpolates_callback_outputs(): }, ], } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "MyCallbackStep1": [], + "MyCallbackStep2": [], + "TestStep": ["MyCallbackStep1", "MyCallbackStep2"], + } + ) diff --git a/tests/unit/sagemaker/workflow/test_condition_step.py b/tests/unit/sagemaker/workflow/test_condition_step.py index 21bf28e1cb..f3d6209f23 100644 --- a/tests/unit/sagemaker/workflow/test_condition_step.py +++ b/tests/unit/sagemaker/workflow/test_condition_step.py @@ -12,10 +12,27 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import +import pytest +from mock import Mock, MagicMock from sagemaker.workflow.conditions import ConditionEquals from sagemaker.workflow.parameters import ParameterInteger from sagemaker.workflow.condition_step import ConditionStep -from tests.unit.sagemaker.workflow.helpers import CustomStep +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph +from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name="boto_session", region_name="us-west-2") + session_mock = MagicMock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name="us-west-2", + config=None, + local_mode=False, + account_id=Mock(), + ) + return session_mock def test_condition_step(): @@ -60,3 +77,26 @@ def test_condition_step(): }, } assert cond_step.properties.Outcome.expr == {"Get": "Steps.MyConditionStep.Outcome"} + + +def test_pipeline(sagemaker_session): + param = ParameterInteger(name="MyInt", default_value=2) + cond = ConditionEquals(left=param, right=1) + custom_step1 = CustomStep("IfStep") + custom_step2 = CustomStep("ElseStep") + step_cond = ConditionStep( + name="CondStep", + conditions=[cond], + if_steps=[custom_step1], + else_steps=[custom_step2], + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[step_cond], + sagemaker_session=sagemaker_session, + parameters=[param], + ) + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + {"CondStep": ["IfStep", "ElseStep"], "IfStep": [], "ElseStep": []} + ) diff --git a/tests/unit/sagemaker/workflow/test_conditions.py b/tests/unit/sagemaker/workflow/test_conditions.py index f4bea55b6e..f5afce9de0 100644 --- a/tests/unit/sagemaker/workflow/test_conditions.py +++ b/tests/unit/sagemaker/workflow/test_conditions.py @@ -112,7 +112,7 @@ def test_condition_in_mixed(): assert cond_in.to_request() == { "Type": "In", "QueryValue": {"Get": "Parameters.MyStr"}, - "Values": ["abc", {"Get": "foo"}, {"Get": "Execution.StartDateTime"}], + "Values": ["abc", {"Get": "Steps.foo"}, {"Get": "Execution.StartDateTime"}], } diff --git a/tests/unit/sagemaker/workflow/test_emr_step.py b/tests/unit/sagemaker/workflow/test_emr_step.py index b9c5335648..703af7ab4e 100644 --- a/tests/unit/sagemaker/workflow/test_emr_step.py +++ b/tests/unit/sagemaker/workflow/test_emr_step.py @@ -20,9 +20,9 @@ from sagemaker.workflow.emr_step import EMRStep, EMRStepConfig from sagemaker.workflow.steps import CacheConfig -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.parameters import ParameterString -from tests.unit.sagemaker.workflow.helpers import CustomStep +from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered @pytest.fixture() @@ -108,7 +108,7 @@ def test_pipeline_interpolates_emr_outputs(sagemaker_session): cluster_id="MyClusterID", display_name="emr_step_1", description="MyEMRStepDescription", - depends_on=["TestStep"], + depends_on=[custom_step], step_config=emr_step_config_1, ) @@ -119,7 +119,7 @@ def test_pipeline_interpolates_emr_outputs(sagemaker_session): cluster_id="MyClusterID", display_name="emr_step_2", description="MyEMRStepDescription", - depends_on=["TestStep"], + depends_on=[custom_step], step_config=emr_step_config_2, ) @@ -180,3 +180,7 @@ def test_pipeline_interpolates_emr_outputs(sagemaker_session): }, ], } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + 
assert ordered(adjacency_list) == ordered( + {"emr_step_1": [], "emr_step_2": [], "TestStep": ["emr_step_1", "emr_step_2"]} + ) diff --git a/tests/unit/sagemaker/workflow/test_entities.py b/tests/unit/sagemaker/workflow/test_entities.py index 03c4fd22a1..6f0be2ccca 100644 --- a/tests/unit/sagemaker/workflow/test_entities.py +++ b/tests/unit/sagemaker/workflow/test_entities.py @@ -121,7 +121,7 @@ def test_pipeline_variable_in_pipeline_definition(sagemaker_session): property_file=property_file, json_path="my-json-path", ) - prop = Properties("Steps.MyStep", "DescribeProcessingJobResponse") + prop = Properties(step_name="MyStep", shape_name="DescribeProcessingJobResponse") cond = ConditionGreaterThan(left=param_str, right=param_int.to_string()) step_fail = FailStep( diff --git a/tests/unit/sagemaker/workflow/test_fail_step.py b/tests/unit/sagemaker/workflow/test_fail_step.py index 04edaf0ac5..af3dff195b 100644 --- a/tests/unit/sagemaker/workflow/test_fail_step.py +++ b/tests/unit/sagemaker/workflow/test_fail_step.py @@ -21,7 +21,8 @@ from sagemaker.workflow.fail_step import FailStep from sagemaker.workflow.functions import Join from sagemaker.workflow.parameters import ParameterInteger -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph +from tests.unit.sagemaker.workflow.helpers import ordered def test_fail_step(): @@ -104,6 +105,8 @@ def test_fail_step_with_join_fn_in_error_message(): ] assert json.loads(pipeline.definition())["Steps"] == _expected_dsl + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered({"CondStep": ["FailStep"], "FailStep": []}) def test_fail_step_with_properties_ref(): diff --git a/tests/unit/sagemaker/workflow/test_functions.py b/tests/unit/sagemaker/workflow/test_functions.py index 5be74eab7e..29dbbef34f 100644 --- a/tests/unit/sagemaker/workflow/test_functions.py +++ b/tests/unit/sagemaker/workflow/test_functions.py @@ -50,7 +50,7 @@ def test_join_expressions(): ParameterFloat(name="MyFloat"), ParameterInteger(name="MyInt"), ParameterString(name="MyStr"), - Properties(path="Steps.foo.OutputPath.S3Uri"), + Properties(step_name="foo", path="OutputPath.S3Uri"), ExecutionVariables.PIPELINE_EXECUTION_ID, Join(on=",", values=[1, "a", False, 1.1]), ] diff --git a/tests/unit/sagemaker/workflow/test_lambda_step.py b/tests/unit/sagemaker/workflow/test_lambda_step.py index d18462d156..451fb53beb 100644 --- a/tests/unit/sagemaker/workflow/test_lambda_step.py +++ b/tests/unit/sagemaker/workflow/test_lambda_step.py @@ -19,11 +19,11 @@ from mock import Mock, MagicMock from sagemaker.workflow.parameters import ParameterInteger, ParameterString -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum from sagemaker.lambda_helper import Lambda from sagemaker.workflow.steps import CacheConfig -from tests.unit.sagemaker.workflow.helpers import CustomStep +from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered @pytest.fixture() @@ -63,7 +63,7 @@ def test_lambda_step(sagemaker_session): cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") lambda_step = LambdaStep( name="MyLambdaStep", - depends_on=["TestStep"], + depends_on=[custom_step1], lambda_func=Lambda( function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda", session=sagemaker_session, @@ -74,7 +74,7 @@ def 
test_lambda_step(sagemaker_session): outputs=[output_param1, output_param2], cache_config=cache_config, ) - lambda_step.add_depends_on(["SecondTestStep"]) + lambda_step.add_depends_on([custom_step2]) pipeline = Pipeline( name="MyPipeline", parameters=[param], @@ -95,6 +95,10 @@ def test_lambda_step(sagemaker_session): "Arguments": {"arg1": "foo", "arg2": 5, "arg3": {"Get": "Parameters.MyInt"}}, "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + {"MyLambdaStep": [], "TestStep": ["MyLambdaStep"], "SecondTestStep": ["MyLambdaStep"]} + ) def test_lambda_step_output_expr(sagemaker_session): @@ -127,7 +131,7 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session): output_param2 = LambdaOutput(output_name="output2", output_type=LambdaOutputTypeEnum.String) lambda_step1 = LambdaStep( name="MyLambdaStep1", - depends_on=["TestStep"], + depends_on=[custom_step], lambda_func=Lambda( function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda", session=sagemaker_session, @@ -137,7 +141,7 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session): ) lambda_step2 = LambdaStep( name="MyLambdaStep2", - depends_on=["TestStep"], + depends_on=[custom_step], lambda_func=Lambda( function_arn="arn:aws:lambda:us-west-2:123456789012:function:sagemaker_test_lambda", session=sagemaker_session, @@ -185,6 +189,10 @@ def test_pipeline_interpolates_lambda_outputs(sagemaker_session): }, ], } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + {"MyLambdaStep1": [], "MyLambdaStep2": [], "TestStep": ["MyLambdaStep1", "MyLambdaStep2"]} + ) def test_lambda_step_no_inputs_outputs(sagemaker_session): diff --git a/tests/unit/sagemaker/workflow/test_model_step.py b/tests/unit/sagemaker/workflow/test_model_step.py index c233bd0290..68961b355c 100644 --- a/tests/unit/sagemaker/workflow/test_model_step.py +++ b/tests/unit/sagemaker/workflow/test_model_step.py @@ -41,7 +41,7 @@ _REPACK_MODEL_NAME_BASE, ) from sagemaker.workflow.parameters import ParameterString, ParameterInteger -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.pipeline_context import PipelineSession from sagemaker.workflow.retry import ( StepRetryPolicy, @@ -53,7 +53,7 @@ from sagemaker.lambda_helper import Lambda from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum from tests.unit import DATA_DIR -from tests.unit.sagemaker.workflow.helpers import CustomStep +from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered _IMAGE_URI = "fakeimage" _REGION = "us-west-2" @@ -146,7 +146,7 @@ def test_register_model_with_runtime_repack(pipeline_session, model_data_param, transform_instances=["ml.m5.xlarge"], model_package_group_name="MyModelPackageGroup", ) - model_steps = ModelStep( + model_step = ModelStep( name="MyModelStep", step_args=step_args, retry_policies=dict( @@ -162,17 +162,18 @@ def test_register_model_with_runtime_repack(pipeline_session, model_data_param, depends_on=["TestStep"], description="my model step description", ) + custom_step2 = CustomStep("TestStep2", depends_on=[model_step]) pipeline = Pipeline( name="MyPipeline", parameters=[model_data_param], - steps=[model_steps, custom_step], + steps=[custom_step, model_step, custom_step2], sagemaker_session=pipeline_session, ) step_dsl_list = 
json.loads(pipeline.definition())["Steps"] - assert len(step_dsl_list) == 3 + assert len(step_dsl_list) == 4 expected_repack_step_name = f"MyModelStep-{_REPACK_MODEL_NAME_BASE}-MyModel" # Filter out the dummy custom step - step_dsl_list = list(filter(lambda s: s["Name"] != "TestStep", step_dsl_list)) + step_dsl_list = list(filter(lambda s: not s["Name"].startswith("TestStep"), step_dsl_list)) for step in step_dsl_list[0:2]: if step["Type"] == "Training": assert step["Name"] == expected_repack_step_name @@ -223,6 +224,16 @@ def test_register_model_with_runtime_repack(pipeline_session, model_data_param, else: raise Exception("A step exists in the collection of an invalid type.") + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "TestStep": ["MyModelStep-RepackModel-MyModel"], + "MyModelStep-RepackModel-MyModel": ["MyModelStep-RegisterModel"], + "MyModelStep-RegisterModel": ["TestStep2"], + "TestStep2": [], + } + ) + def test_create_model_with_runtime_repack(pipeline_session, model_data_param, model): step_args = model.create( @@ -289,6 +300,13 @@ def test_create_model_with_runtime_repack(pipeline_session, model_data_param, mo ] else: raise Exception("A step exists in the collection of an invalid type.") + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "MyModelStep-CreateModel": [], + "MyModelStep-RepackModel-MyModel": ["MyModelStep-CreateModel"], + } + ) def test_create_pipeline_model_with_runtime_repack(pipeline_session, model_data_param, model): @@ -376,6 +394,13 @@ def test_create_pipeline_model_with_runtime_repack(pipeline_session, model_data_ ] else: raise Exception("A step exists in the collection of an invalid type.") + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "MyModelStep-CreateModel": [], + "MyModelStep-RepackModel-MyModel": ["MyModelStep-CreateModel"], + } + ) def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_data_param): @@ -411,14 +436,17 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat name="MyModelStep", step_args=step_args, ) + custom_step = CustomStep("TestStep", input_data=model_steps.properties.ModelApprovalStatus) pipeline = Pipeline( name="MyPipeline", parameters=[model_data_param], - steps=[model_steps], + steps=[model_steps, custom_step], sagemaker_session=pipeline_session, ) step_dsl_list = json.loads(pipeline.definition())["Steps"] - assert len(step_dsl_list) == 2 + assert len(step_dsl_list) == 3 + # Filter out the dummy custom step + step_dsl_list = list(filter(lambda s: not s["Name"].startswith("TestStep"), step_dsl_list)) expected_repack_step_name = f"MyModelStep-{_REPACK_MODEL_NAME_BASE}-1" for step in step_dsl_list: if step["Type"] == "Training": @@ -453,6 +481,14 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat assert containers[1]["Environment"]["k"] == "v" else: raise Exception("A step exists in the collection of an invalid type.") + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "MyModelStep-RegisterModel": ["TestStep"], + "MyModelStep-RepackModel-1": ["MyModelStep-RegisterModel"], + "TestStep": [], + } + ) def test_register_model_without_repack(pipeline_session): @@ -501,6 +537,8 @@ def test_register_model_without_repack(pipeline_session): 
containers[0]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] == f"s3://{_BUCKET}/{model_name}/sourcedir.tar.gz" ) + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered({"MyModelStep-RegisterModel": []}) @patch("sagemaker.utils.repack_model") @@ -538,6 +576,10 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session): assert arguments["PrimaryContainer"]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] == _DIR_NAME assert len(step_dsl_list[0]["DependsOn"]) == 1 assert step_dsl_list[0]["DependsOn"][0] == "TestStep" + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + {"MyModelStep-CreateModel": [], "TestStep": ["MyModelStep-CreateModel"]} + ) def test_conditional_model_create_and_regis( @@ -555,7 +597,7 @@ def test_conditional_model_create_and_regis( model_package_group_name="MyModelPackageGroup", ) step_model_regis = ModelStep( - name="MyModelStep", + name="MyModelStepRegis", step_args=step_args, ) # create model without runtime repack @@ -566,7 +608,7 @@ def test_conditional_model_create_and_regis( accelerator_type="ml.eia1.medium", ) step_model_create = ModelStep( - name="MyModelStep", + name="MyModelStepCreate", step_args=step_args, ) step_cond = ConditionStep( @@ -586,7 +628,7 @@ def test_conditional_model_create_and_regis( cond_step_dsl = json.loads(pipeline.definition())["Steps"][0] step_dsl_list = cond_step_dsl["Arguments"]["IfSteps"] + cond_step_dsl["Arguments"]["ElseSteps"] assert len(step_dsl_list) == 3 - expected_repack_step_name = f"MyModelStep-{_REPACK_MODEL_NAME_BASE}-MyModel" + expected_repack_step_name = f"MyModelStepRegis-{_REPACK_MODEL_NAME_BASE}-MyModel" for step in step_dsl_list: if step["Type"] == "Training": assert step["Name"] == expected_repack_step_name @@ -601,7 +643,7 @@ def test_conditional_model_create_and_regis( assert "s3://" in arguments["HyperParameters"]["sagemaker_submit_directory"] assert arguments["HyperParameters"]["dependencies"] == "null" elif step["Type"] == "RegisterModel": - assert step["Name"] == f"MyModelStep-{_REGISTER_MODEL_NAME_BASE}" + assert step["Name"] == f"MyModelStepRegis-{_REGISTER_MODEL_NAME_BASE}" arguments = step["Arguments"] assert arguments["ModelApprovalStatus"] == "PendingManualApproval" assert len(arguments["InferenceSpecification"]["Containers"]) == 1 @@ -613,7 +655,7 @@ def test_conditional_model_create_and_regis( assert container["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME assert container["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] == _DIR_NAME elif step["Type"] == "Model": - assert step["Name"] == f"MyModelStep-{_CREATE_MODEL_NAME_BASE}" + assert step["Name"] == f"MyModelStepCreate-{_CREATE_MODEL_NAME_BASE}" arguments = step["Arguments"] container = arguments["PrimaryContainer"] assert container["Image"] == _IMAGE_URI @@ -621,6 +663,18 @@ def test_conditional_model_create_and_regis( assert not container.get("Environment", {}) else: raise Exception("A step exists in the collection of an invalid type.") + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "MyModelStepCreate-CreateModel": [], + "MyModelStepRegis-RegisterModel": [], + "MyModelStepRegis-RepackModel-MyModel": ["MyModelStepRegis-RegisterModel"], + "cond-good-enough": [ + "MyModelStepCreate-CreateModel", + "MyModelStepRegis-RepackModel-MyModel", + ], + } + ) @pytest.mark.parametrize( @@ -918,6 +972,14 @@ def 
test_model_step_with_lambda_property_reference(pipeline_session): assert register_step["Arguments"]["PrimaryContainer"]["Image"] == { "Get": "Steps.MyLambda.OutputParameters['model_image']" } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "MyLambda": ["mymodelstep-CreateModel", "mymodelstep-RepackModel-MyModel"], + "mymodelstep-CreateModel": [], + "mymodelstep-RepackModel-MyModel": ["mymodelstep-CreateModel"], + } + ) @pytest.mark.parametrize( diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index f39a012df8..a9e9474013 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -22,37 +22,13 @@ from sagemaker import s3 from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.parameters import ParameterString -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.parallelism_config import ParallelismConfiguration from sagemaker.workflow.pipeline_experiment_config import ( PipelineExperimentConfig, PipelineExperimentConfigProperties, ) -from sagemaker.workflow.properties import Properties -from sagemaker.workflow.steps import ( - Step, - StepTypeEnum, -) -from tests.unit.sagemaker.workflow.helpers import ordered - - -class CustomStep(Step): - def __init__(self, name, input_data, display_name=None, description=None): - self.input_data = input_data - super(CustomStep, self).__init__(name, display_name, description, StepTypeEnum.TRAINING) - - path = f"Steps.{name}" - prop = Properties(path=path) - prop.__dict__["S3Uri"] = Properties(f"{path}.S3Uri") - self._properties = prop - - @property - def arguments(self): - return {"input_data": self.input_data} - - @property - def properties(self): - return self._properties +from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep @pytest.fixture @@ -314,7 +290,9 @@ def test_pipeline_two_step(sagemaker_session_mock): PipelineExperimentConfigProperties.EXPERIMENT_NAME, # experiment config property ], ) - step2 = CustomStep(name="MyStep2", input_data=[step1.properties.S3Uri]) # step property + step2 = CustomStep( + name="MyStep2", input_data=[step1.properties.ModelArtifacts.S3ModelArtifacts] + ) # step property pipeline = Pipeline( name="MyPipeline", parameters=[parameter], @@ -344,7 +322,7 @@ def test_pipeline_two_step(sagemaker_session_mock): { "Name": "MyStep2", "Type": "Training", - "Arguments": {"input_data": [step1.properties.S3Uri]}, + "Arguments": {"input_data": [step1.properties.ModelArtifacts.S3ModelArtifacts]}, }, ], } @@ -372,12 +350,17 @@ def test_pipeline_two_step(sagemaker_session_mock): { "Name": "MyStep2", "Type": "Training", - "Arguments": {"input_data": [{"Get": "Steps.MyStep1.S3Uri"}]}, + "Arguments": { + "input_data": [{"Get": "Steps.MyStep1.ModelArtifacts.S3ModelArtifacts"}] + }, }, ], } ) + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered({"MyStep1": ["MyStep2"], "MyStep2": []}) + def test_pipeline_override_experiment_config(): pipeline = Pipeline( diff --git a/tests/unit/sagemaker/workflow/test_pipeline_graph.py b/tests/unit/sagemaker/workflow/test_pipeline_graph.py new file mode 100644 index 0000000000..b7d69e617a --- /dev/null +++ b/tests/unit/sagemaker/workflow/test_pipeline_graph.py @@ -0,0 +1,342 @@ +# Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +from mock import Mock + +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ( + ConditionEquals, + ConditionIn, + ConditionNot, + ConditionOr, + ConditionGreaterThan, +) +from sagemaker.workflow.execution_variables import ExecutionVariables +from sagemaker.workflow.parameters import ParameterInteger, ParameterString +from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep, CustomStepCollection + + +@pytest.fixture +def sagemaker_session_mock(): + session_mock = Mock() + session_mock.default_bucket = Mock(name="default_bucket", return_value="s3_bucket") + return session_mock + + +@pytest.fixture +def role_arn(): + return "arn:role" + + +def test_pipeline_duplicate_step_name(sagemaker_session_mock): + step1 = CustomStep(name="foo") + step2 = CustomStep(name="foo") + pipeline = Pipeline( + name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock + ) + with pytest.raises(ValueError) as error: + PipelineGraph.from_pipeline(pipeline) + assert "Pipeline steps cannot have duplicate names." in str(error.value) + + +def test_pipeline_duplicate_step_name_in_condition_step(sagemaker_session_mock): + param = ParameterInteger(name="MyInt", default_value=2) + cond = ConditionEquals(left=param, right=1) + custom_step = CustomStep(name="foo") + custom_step2 = CustomStep(name="foo") + condition_step = ConditionStep( + name="condStep", conditions=[cond], depends_on=[custom_step], if_steps=[custom_step2] + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step, condition_step], + sagemaker_session=sagemaker_session_mock, + ) + with pytest.raises(ValueError) as error: + PipelineGraph.from_pipeline(pipeline) + assert "Pipeline steps cannot have duplicate names." in str(error.value) + + +def test_pipeline_duplicate_step_name_in_step_collection(sagemaker_session_mock): + custom_step = CustomStep(name="foo-1") + custom_step_collection = CustomStepCollection(name="foo", depends_on=[custom_step]) + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step, custom_step_collection], + sagemaker_session=sagemaker_session_mock, + ) + with pytest.raises(ValueError) as error: + PipelineGraph.from_pipeline(pipeline) + assert "Pipeline steps cannot have duplicate names." 
in str(error.value) + + +def test_pipeline_graph_acyclic(sagemaker_session_mock): + step_a = CustomStep(name="stepA") + step_b = CustomStep(name="stepB") + step_c = CustomStep(name="stepC", depends_on=[step_a]) + step_d = CustomStep(name="stepD", depends_on=[step_c]) + step_e = CustomStep(name="stepE", depends_on=[step_a, step_b, step_d]) + + pipeline = Pipeline( + name="MyPipeline", + steps=[step_a, step_b, step_c, step_d, step_e], + sagemaker_session=sagemaker_session_mock, + ) + + pipeline_graph = PipelineGraph.from_pipeline(pipeline) + adjacency_list = pipeline_graph.adjacency_list + assert ordered(adjacency_list) == ordered( + { + "stepA": ["stepC", "stepE"], + "stepB": ["stepE"], + "stepC": ["stepD"], + "stepD": ["stepE"], + "stepE": [], + } + ) + _verify_pipeline_graph_traversal(pipeline_graph) + + +def test_pipeline_graph_acyclic_with_condition_step_explicit_dependency(sagemaker_session_mock): + custom_step = CustomStep(name="TestStep") + if_step = CustomStep(name="IfStep") + else_step = CustomStep(name="ElseStep") + param = ParameterInteger(name="MyInt", default_value=2) + cond = ConditionEquals(left=param, right=1) + condition_step = ConditionStep( + name="condStep", + conditions=[cond], + depends_on=[custom_step], + if_steps=[if_step], + else_steps=[else_step], + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step, condition_step], + sagemaker_session=sagemaker_session_mock, + ) + + pipeline_graph = PipelineGraph.from_pipeline(pipeline) + adjacency_list = pipeline_graph.adjacency_list + assert ordered(adjacency_list) == ordered( + {"condStep": ["ElseStep", "IfStep"], "ElseStep": [], "IfStep": [], "TestStep": ["condStep"]} + ) + _verify_pipeline_graph_traversal(pipeline_graph) + + +def test_pipeline_graph_acyclic_with_condition_step_property_reference_dependency( + sagemaker_session_mock, +): + custom_step = CustomStep(name="TestStep") + if_step = CustomStep(name="IfStep") + else_step = CustomStep(name="ElseStep") + cond = ConditionEquals(left=custom_step.properties.TrainingJobStatus, right="Succeeded") + condition_step = ConditionStep( + name="condStep", conditions=[cond], if_steps=[if_step], else_steps=[else_step] + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step, condition_step], + sagemaker_session=sagemaker_session_mock, + ) + + pipeline_graph = PipelineGraph.from_pipeline(pipeline) + adjacency_list = pipeline_graph.adjacency_list + assert ordered(adjacency_list) == ordered( + {"condStep": ["ElseStep", "IfStep"], "ElseStep": [], "IfStep": [], "TestStep": ["condStep"]} + ) + _verify_pipeline_graph_traversal(pipeline_graph) + + +def test_pipeline_graph_acyclic_with_step_collection_explicit_dependency(sagemaker_session_mock): + custom_step1 = CustomStep(name="TestStep") + custom_step_collection = CustomStepCollection( + name="TestStepCollection", depends_on=[custom_step1] + ) + custom_step2 = CustomStep(name="TestStep2", depends_on=[custom_step_collection]) + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step1, custom_step_collection, custom_step2], + sagemaker_session=sagemaker_session_mock, + ) + + pipeline_graph = PipelineGraph.from_pipeline(pipeline) + adjacency_list = pipeline_graph.adjacency_list + assert ordered(adjacency_list) == ordered( + { + "TestStep": ["TestStepCollection-0"], + "TestStep2": [], + "TestStepCollection-0": ["TestStepCollection-1"], + "TestStepCollection-1": ["TestStep2"], + } + ) + _verify_pipeline_graph_traversal(pipeline_graph) + + +def 
test_pipeline_graph_acyclic_with_step_collection_property_reference_dependency( + sagemaker_session_mock, +): + custom_step_collection = CustomStepCollection(name="TestStepCollection") + custom_step = CustomStep( + name="TestStep", + input_data=custom_step_collection.properties.AlgorithmSpecification.AlgorithmName, + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step_collection, custom_step], + sagemaker_session=sagemaker_session_mock, + ) + + pipeline_graph = PipelineGraph.from_pipeline(pipeline) + adjacency_list = pipeline_graph.adjacency_list + assert ordered(adjacency_list) == ordered( + { + "TestStep": [], + "TestStepCollection-0": ["TestStepCollection-1"], + "TestStepCollection-1": ["TestStep"], + } + ) + _verify_pipeline_graph_traversal(pipeline_graph) + + +def test_pipeline_graph_cyclic(sagemaker_session_mock): + step_a = CustomStep(name="stepA", depends_on=["stepC"]) + step_b = CustomStep(name="stepB", depends_on=["stepA"]) + step_c = CustomStep(name="stepC", depends_on=["stepB"]) + + pipeline = Pipeline( + name="MyPipeline", steps=[step_a, step_b, step_c], sagemaker_session=sagemaker_session_mock + ) + + with pytest.raises(ValueError) as error: + PipelineGraph.from_pipeline(pipeline) + assert "Cycle detected in pipeline step graph." in str(error.value) + + +def test_condition_comparison(sagemaker_session): + param = ParameterInteger(name="MyInt") + cond = ConditionEquals(left=param, right=1) + if_step = CustomStep(name="IfStep") + else_step = CustomStep(name="ElseStep") + cond_step = ConditionStep( + name="MyConditionStep", + conditions=[cond], + if_steps=[if_step], + else_steps=[else_step], + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[cond_step], + sagemaker_session=sagemaker_session, + parameters=[param], + ) + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + {"MyConditionStep": ["ElseStep", "IfStep"], "ElseStep": [], "IfStep": []} + ) + + +def test_condition_not(sagemaker_session): + param = ParameterInteger(name="MyInt") + cond = ConditionEquals(left=param, right=1) + cond_not = ConditionNot(expression=cond) + if_step = CustomStep(name="IfStep") + else_step = CustomStep(name="ElseStep") + cond_step = ConditionStep( + name="MyConditionStep", + conditions=[cond_not], + if_steps=[if_step], + else_steps=[else_step], + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[cond_step], + sagemaker_session=sagemaker_session, + parameters=[param], + ) + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + {"MyConditionStep": ["ElseStep", "IfStep"], "ElseStep": [], "IfStep": []} + ) + + +def test_condition_in(sagemaker_session): + param = ParameterString(name="MyStr") + cond_in = ConditionIn(value=param, in_values=["abc", "def"]) + if_step = CustomStep(name="IfStep") + else_step = CustomStep(name="ElseStep") + cond_step = ConditionStep( + name="MyConditionStep", + conditions=[cond_in], + if_steps=[if_step], + else_steps=[else_step], + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[cond_step], + sagemaker_session=sagemaker_session, + parameters=[param], + ) + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "MyConditionStep": ["ElseStep", "IfStep"], + "ElseStep": [], + "IfStep": [], + } + ) + + +def test_condition_or(sagemaker_session): + param = ParameterString(name="MyStr") + cond1 = ConditionGreaterThan(left=ExecutionVariables.START_DATETIME, 
right="2020-12-01") + cond2 = ConditionEquals(left=param, right="Success") + cond_or = ConditionOr(conditions=[cond1, cond2]) + if_step = CustomStep(name="IfStep") + else_step = CustomStep(name="ElseStep") + cond_step = ConditionStep( + name="MyConditionStep", + conditions=[cond_or], + if_steps=[if_step], + else_steps=[else_step], + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[cond_step], + sagemaker_session=sagemaker_session, + parameters=[param], + ) + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "MyConditionStep": ["ElseStep", "IfStep"], + "ElseStep": [], + "IfStep": [], + } + ) + + +def _verify_pipeline_graph_traversal(pipeline_graph): + adjacency_list = pipeline_graph.adjacency_list + traversed_steps = [] + for step in pipeline_graph: + # the traversal order of a PipelineGraph needs to be a topological sort traversal + # i.e. parent steps are always traversed before their children steps + assert step not in traversed_steps + for children_steps in adjacency_list[step.name]: + assert children_steps not in traversed_steps + traversed_steps.append(step) diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py index 8d6ee80389..e1b02c17d4 100644 --- a/tests/unit/sagemaker/workflow/test_processing_step.py +++ b/tests/unit/sagemaker/workflow/test_processing_step.py @@ -42,7 +42,7 @@ from sagemaker.workflow.steps import CacheConfig, ProcessingStep -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.properties import PropertyFile from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.functions import Join @@ -59,7 +59,7 @@ ModelPredictedLabelConfig, SHAPConfig, ) -from tests.unit.sagemaker.workflow.helpers import CustomStep +from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered REGION = "us-west-2" BUCKET = "my-bucket" @@ -326,6 +326,14 @@ def test_processing_step_with_processor(pipeline_session, processing_input): assert step.properties.ProcessingJobName.expr == { "Get": "Steps.MyProcessingStep.ProcessingJobName" } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "TestStep": ["MyProcessingStep"], + "SecondTestStep": ["MyProcessingStep"], + "MyProcessingStep": [], + } + ) def test_processing_step_with_processor_and_step_args(pipeline_session, processing_input): @@ -491,7 +499,7 @@ def shap_config(): seed=123, ) - def verfiy(step_args): + def verify(step_args): step = ProcessingStep( name="MyProcessingStep", step_args=step_args, @@ -531,13 +539,13 @@ def verfiy(step_args): bias_config=data_bias_config(), model_config=model_config("1st-model-rpyndy0uyo"), ) - verfiy(run_bias_args) + verify(run_bias_args) run_pre_training_bias_args = clarify_processor.run_pre_training_bias( data_config=data_config, data_bias_config=data_bias_config(), ) - verfiy(run_pre_training_bias_args) + verify(run_pre_training_bias_args) run_post_training_bias_args = clarify_processor.run_post_training_bias( data_config=data_config, @@ -545,14 +553,14 @@ def verfiy(step_args): model_config=model_config("1st-model-rpyndy0uyo"), model_predicted_label_config=ModelPredictedLabelConfig(probability_threshold=0.9), ) - verfiy(run_post_training_bias_args) + verify(run_post_training_bias_args) run_explainability_args = clarify_processor.run_explainability( data_config=data_config, 
model_config=model_config("1st-model-rpyndy0uyo"), explainability_config=shap_config(), ) - verfiy(run_explainability_args) + verify(run_explainability_args) @pytest.mark.parametrize( diff --git a/tests/unit/sagemaker/workflow/test_properties.py b/tests/unit/sagemaker/workflow/test_properties.py index f36bc7577d..67fdd919a0 100644 --- a/tests/unit/sagemaker/workflow/test_properties.py +++ b/tests/unit/sagemaker/workflow/test_properties.py @@ -18,7 +18,7 @@ def test_properties_describe_training_job_response(): - prop = Properties("Steps.MyStep", "DescribeTrainingJobResponse") + prop = Properties(step_name="MyStep", shape_name="DescribeTrainingJobResponse") some_prop_names = ["TrainingJobName", "TrainingJobArn", "HyperParameters", "OutputDataConfig"] for name in some_prop_names: assert name in prop.__dict__.keys() @@ -30,7 +30,7 @@ def test_properties_describe_training_job_response(): def test_properties_describe_processing_job_response(): - prop = Properties("Steps.MyStep", "DescribeProcessingJobResponse") + prop = Properties(step_name="MyStep", shape_name="DescribeProcessingJobResponse") some_prop_names = ["ProcessingInputs", "ProcessingOutputConfig", "ProcessingEndTime"] for name in some_prop_names: assert name in prop.__dict__.keys() @@ -42,7 +42,7 @@ def test_properties_describe_processing_job_response(): def test_properties_tuning_job(): prop = Properties( - "Steps.MyStep", + step_name="MyStep", shape_names=[ "DescribeHyperParameterTuningJobResponse", "ListTrainingJobsForHyperParameterTuningJobResponse", @@ -72,7 +72,7 @@ def test_properties_tuning_job(): def test_properties_emr_step(): - prop = Properties("Steps.MyStep", "Step", service_name="emr") + prop = Properties("MyStep", shape_name="Step", service_name="emr") some_prop_names = ["Id", "Name", "Config", "ActionOnFailure", "Status"] for name in some_prop_names: assert name in prop.__dict__.keys() @@ -85,7 +85,7 @@ def test_properties_emr_step(): def test_properties_describe_model_package_output(): - prop = Properties("Steps.MyStep", "DescribeModelPackageOutput") + prop = Properties(step_name="MyStep", shape_name="DescribeModelPackageOutput") some_prop_names = ["ModelPackageName", "ModelPackageGroupName", "ModelPackageArn"] for name in some_prop_names: assert name in prop.__dict__.keys() @@ -96,7 +96,7 @@ def test_properties_describe_model_package_output(): def test_to_string(): - prop = Properties("Steps.MyStep", "DescribeTrainingJobResponse") + prop = Properties("MyStep", shape_name="DescribeTrainingJobResponse") assert prop.CreationTime.to_string().expr == { "Std:Join": { @@ -107,7 +107,7 @@ def test_to_string(): def test_implicit_value(): - prop = Properties("Steps.MyStep", "DescribeTrainingJobResponse") + prop = Properties("MyStep", shape_name="DescribeTrainingJobResponse") with pytest.raises(TypeError) as error: str(prop.CreationTime) @@ -123,8 +123,8 @@ def test_implicit_value(): def test_add_func(): - prop_train = Properties("Steps.MyStepTrain", "DescribeTrainingJobResponse") - prop_model = Properties("Steps.MyStepModel", "DescribeModelPackageOutput") + prop_train = Properties("MyStepTrain", shape_name="DescribeTrainingJobResponse") + prop_model = Properties("MyStepModel", shape_name="DescribeModelPackageOutput") with pytest.raises(TypeError) as error: prop_train + prop_model diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 751817fa44..1a61d2088b 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -46,7 
+46,7 @@ from sagemaker.network import NetworkConfig from sagemaker.transformer import Transformer from sagemaker.workflow.functions import Join -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.properties import Properties, PropertyFile from sagemaker.workflow.parameters import ParameterString, ParameterInteger, ParameterBoolean from sagemaker.workflow.retry import ( @@ -70,6 +70,7 @@ from sagemaker.predictor import Predictor from sagemaker.model import FrameworkModel from tests.unit import DATA_DIR +from tests.unit.sagemaker.workflow.helpers import ordered DUMMY_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") @@ -85,7 +86,7 @@ def __init__(self, name, display_name=None, description=None, retry_policies=Non super(CustomStep, self).__init__( name, StepTypeEnum.TRAINING, display_name, description, None, retry_policies ) - self._properties = Properties(path=f"Steps.{name}") + self._properties = Properties(name) @property def arguments(self): @@ -395,6 +396,14 @@ def test_training_step_base_estimator(sagemaker_session): } assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} assert step.properties.HyperParameters.expr == {"Get": "Steps.MyTrainingStep.HyperParameters"} + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "AnotherTestStep": ["MyTrainingStep"], + "MyTrainingStep": [], + "TestStep": ["MyTrainingStep"], + } + ) def test_training_step_tensorflow(sagemaker_session): @@ -668,6 +677,15 @@ def test_processing_step(sagemaker_session): assert step.properties.ProcessingJobName.expr == { "Get": "Steps.MyProcessingStep.ProcessingJobName" } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "SecondTestStep": ["MyProcessingStep"], + "TestStep": ["MyProcessingStep"], + "ThirdTestStep": ["MyProcessingStep"], + "MyProcessingStep": [], + } + ) @patch("sagemaker.processing.ScriptProcessor._normalize_args") @@ -979,7 +997,7 @@ def test_transform_step(sagemaker_session): def test_properties_describe_training_job_response(): - prop = Properties("Steps.MyStep", "DescribeTrainingJobResponse") + prop = Properties(step_name="MyStep", shape_name="DescribeTrainingJobResponse") some_prop_names = ["TrainingJobName", "TrainingJobArn", "HyperParameters", "OutputDataConfig"] for name in some_prop_names: assert name in prop.__dict__.keys() @@ -990,7 +1008,7 @@ def test_properties_describe_training_job_response(): def test_properties_describe_processing_job_response(): - prop = Properties("Steps.MyStep", "DescribeProcessingJobResponse") + prop = Properties(step_name="MyStep", shape_name="DescribeProcessingJobResponse") some_prop_names = ["ProcessingInputs", "ProcessingOutputConfig", "ProcessingEndTime"] for name in some_prop_names: assert name in prop.__dict__.keys() diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index 14df41c876..0c6a6e34df 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -27,7 +27,7 @@ from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.steps import TrainingStep -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.functions import Join from sagemaker.estimator 
import Estimator @@ -55,7 +55,7 @@ from tests.unit import DATA_DIR from sagemaker.inputs import TrainingInput -from tests.unit.sagemaker.workflow.helpers import CustomStep +from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered REGION = "us-west-2" BUCKET = "my-bucket" @@ -242,6 +242,10 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar "Arguments": step_args.args, } assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + {"MyTrainingStep": [], "SecondTestStep": ["MyTrainingStep"], "TestStep": ["MyTrainingStep"]} + ) @pytest.mark.parametrize("estimator", ESTIMATOR_LISTS) diff --git a/tests/unit/sagemaker/workflow/test_transform_step.py b/tests/unit/sagemaker/workflow/test_transform_step.py index 2a75349d56..3d0e25a2ee 100644 --- a/tests/unit/sagemaker/workflow/test_transform_step.py +++ b/tests/unit/sagemaker/workflow/test_transform_step.py @@ -26,7 +26,7 @@ from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.steps import TransformStep, TransformInput -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.functions import Join from sagemaker.workflow import is_pipeline_variable @@ -174,6 +174,8 @@ def test_transform_step_with_transformer(model_name, data, output_path, pipeline step_def["Arguments"]["TransformOutput"]["S3OutputPath"], ) assert step_def == {"Name": "MyTransformStep", "Type": "Transform", "Arguments": step_args} + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert adjacency_list == {"MyTransformStep": []} @pytest.mark.parametrize( diff --git a/tests/unit/sagemaker/workflow/test_tuning_step.py b/tests/unit/sagemaker/workflow/test_tuning_step.py index af9ce57d0c..a39512d006 100644 --- a/tests/unit/sagemaker/workflow/test_tuning_step.py +++ b/tests/unit/sagemaker/workflow/test_tuning_step.py @@ -25,7 +25,7 @@ from sagemaker.workflow.steps import TuningStep from sagemaker.inputs import TrainingInput -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.functions import Join @@ -34,7 +34,6 @@ from tests.unit import DATA_DIR - REGION = "us-west-2" BUCKET = "my-bucket" ROLE = "DummyRole" @@ -166,6 +165,8 @@ def test_tuning_step_with_single_algo_tuner(pipeline_session, training_input, en "Type": "Tuning", "Arguments": step_args, } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert adjacency_list == {"MyTuningStep": []} def test_tuning_step_with_multi_algo_tuner(pipeline_session, entry_point): @@ -229,6 +230,8 @@ def test_tuning_step_with_multi_algo_tuner(pipeline_session, entry_point): "Type": "Tuning", "Arguments": step_args.args, } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert adjacency_list == {"MyTuningStep": []} @pytest.mark.parametrize( diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index 8ad3546bcc..e4eb05110c 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -184,7 +184,7 @@ def test_repack_model_step_with_invalid_input(): def 
test_repack_model_step_with_source_dir(estimator, source_dir): - model_data = Properties(path="Steps.MyStep", shape_name="DescribeModelOutput") + model_data = Properties(step_name="MyStep", shape_name="DescribeModelOutput") entry_point = "inference.py" step = _RepackModelStep( name="MyRepackModelStep", @@ -259,7 +259,7 @@ def test_inject_repack_script_s3(estimator, tmp, fake_s3): ], ) - model_data = Properties(path="Steps.MyStep", shape_name="DescribeModelOutput") + model_data = Properties(step_name="MyStep", shape_name="DescribeModelOutput") entry_point = "inference.py" source_dir_path = "s3://fake/location" step = _RepackModelStep( From 85738c4302886dbfa6e196f0d847bbd1f0dab392 Mon Sep 17 00:00:00 2001 From: Loki Date: Sun, 19 Jun 2022 20:25:09 -0700 Subject: [PATCH 089/526] fix: changing trcomp integ tests to be able to run in all regions (#3175) * fix: changing trcomp integ tests to be able to run in all regions * fix: broken fixture in trcomp integ test --- tests/integ/test_training_compiler.py | 61 +++++++++++++++++++++------ 1 file changed, 49 insertions(+), 12 deletions(-) diff --git a/tests/integ/test_training_compiler.py b/tests/integ/test_training_compiler.py index 92d4dbbffb..f76894b6a6 100644 --- a/tests/integ/test_training_compiler.py +++ b/tests/integ/test_training_compiler.py @@ -31,8 +31,42 @@ def gpu_instance_type(request): return "ml.p3.2xlarge" +@pytest.fixture(scope="module") +def imagenet_val_set(request, sagemaker_session, tmpdir_factory): + """ + Copies the dataset from the bucket it's hosted in to the local bucket in the test region + """ + local_path = tmpdir_factory.mktemp("trcomp_imagenet_val_set") + sagemaker_session.download_data( + path=local_path, + bucket="collection-of-ml-datasets", + key_prefix="Imagenet/TFRecords/validation", + ) + train_input = sagemaker_session.upload_data( + path=local_path, + key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val", + ) + return train_input + + +@pytest.fixture(scope="module") +def huggingface_dummy_dataset(request, sagemaker_session): + """ + Copies the dataset from the local disk to the local bucket in the test region + """ + data_path = os.path.join(DATA_DIR, "huggingface") + train_input = sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), + key_prefix="integ-test-data/trcomp/huggingface/dummy/train", + ) + return train_input + + @pytest.fixture(scope="module", autouse=True) def skip_if_incompatible(request): + """ + These tests are for training compiler enabled images/estimators only. 
+ """ if integ.test_region() not in integ.TRAINING_COMPILER_SUPPORTED_REGIONS: pytest.skip("SageMaker Training Compiler is not supported in this region") if integ.test_region() in integ.TRAINING_NO_P3_REGIONS: @@ -45,7 +79,11 @@ def test_huggingface_pytorch( gpu_instance_type, huggingface_training_compiler_latest_version, huggingface_training_compiler_pytorch_latest_version, + huggingface_dummy_dataset, ): + """ + Test the HuggingFace estimator with PyTorch + """ with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): data_path = os.path.join(DATA_DIR, "huggingface") @@ -73,12 +111,7 @@ def test_huggingface_pytorch( compiler_config=HFTrainingCompilerConfig(), ) - train_input = hf.sagemaker_session.upload_data( - path=os.path.join(data_path, "train"), - key_prefix="integ-test-data/huggingface/train", - ) - - hf.fit(train_input) + hf.fit(huggingface_dummy_dataset) @pytest.mark.release @@ -87,7 +120,11 @@ def test_huggingface_tensorflow( gpu_instance_type, huggingface_training_compiler_latest_version, huggingface_training_compiler_tensorflow_latest_version, + huggingface_dummy_dataset, ): + """ + Test the HuggingFace estimator with TensorFlow + """ with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): data_path = os.path.join(DATA_DIR, "huggingface") @@ -112,11 +149,7 @@ def test_huggingface_tensorflow( compiler_config=HFTrainingCompilerConfig(), ) - train_input = hf.sagemaker_session.upload_data( - path=os.path.join(data_path, "train"), key_prefix="integ-test-data/huggingface/train" - ) - - hf.fit(train_input) + hf.fit(huggingface_dummy_dataset) @pytest.mark.release @@ -124,7 +157,11 @@ def test_tensorflow( sagemaker_session, gpu_instance_type, tensorflow_training_latest_version, + imagenet_val_set, ): + """ + Test the TensorFlow estimator + """ if version.parse(tensorflow_training_latest_version) < version.parse("2.9"): pytest.skip("Training Compiler only supports TF >= 2.9") with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): @@ -173,7 +210,7 @@ def test_tensorflow( ) tf.fit( - inputs="s3://collection-of-ml-datasets/Imagenet/TFRecords/validation", + inputs=imagenet_val_set, logs=True, wait=True, ) From 0b616d81ce89278d9a3e6c9f1a12e35965f383b6 Mon Sep 17 00:00:00 2001 From: ci Date: Mon, 20 Jun 2022 23:09:04 +0000 Subject: [PATCH 090/526] prepare release v2.96.0 --- CHANGELOG.md | 10 ++++++++++ VERSION | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fc2356c3f2..9c4ae95cf6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## v2.96.0 (2022-06-20) + +### Features + + * Add helper method to generate pipeline adjacency list + +### Bug Fixes and Other Changes + + * changing trcomp integ tests to be able to run in all regions + ## v2.95.0 (2022-06-16) ### Features diff --git a/VERSION b/VERSION index b5d39344d2..6bc2cdbbcc 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.95.1.dev0 +2.96.0 From a29be60a765e9ddb0e033f737845cd27471672e3 Mon Sep 17 00:00:00 2001 From: ci Date: Mon, 20 Jun 2022 23:09:05 +0000 Subject: [PATCH 091/526] update development version to v2.96.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 6bc2cdbbcc..0304e09ecb 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.96.0 +2.96.1.dev0 From 8807eb7320e1802a3178da7f4e7194f26bd28499 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Tue, 21 Jun 2022 15:51:38 -0700 Subject: [PATCH 092/526] deprecation: remove support for python 3.6 
(#3170) * feature: support python 3.10 in tox * fix: update scipy * deprecation: remove support for python 3.6 --- .githooks/pre-push | 4 ++-- README.rst | 1 - requirements/extras/scipy_requirements.txt | 2 +- setup.py | 1 - tox.ini | 4 ++-- 5 files changed, 5 insertions(+), 7 deletions(-) diff --git a/.githooks/pre-push b/.githooks/pre-push index 719e35e88f..9235297add 100755 --- a/.githooks/pre-push +++ b/.githooks/pre-push @@ -12,5 +12,5 @@ start_time=`date +%s` tox -e sphinx,doc8 --parallel all ./ci-scripts/displaytime.sh 'sphinx,doc8' $start_time start_time=`date +%s` -tox -e py36,py37,py38,py39 --parallel all -- tests/unit -./ci-scripts/displaytime.sh 'py36,py37,py38,py39 unit' $start_time +tox -e py37,py38,py39 --parallel all -- tests/unit +./ci-scripts/displaytime.sh 'py37,py38,py39 unit' $start_time diff --git a/README.rst b/README.rst index ab62eddad0..44f724a781 100644 --- a/README.rst +++ b/README.rst @@ -87,7 +87,6 @@ Supported Python Versions SageMaker Python SDK is tested on: -- Python 3.6 - Python 3.7 - Python 3.8 - Python 3.9 diff --git a/requirements/extras/scipy_requirements.txt b/requirements/extras/scipy_requirements.txt index 9136ba9f03..1cf073e9f5 100644 --- a/requirements/extras/scipy_requirements.txt +++ b/requirements/extras/scipy_requirements.txt @@ -1 +1 @@ -scipy==1.5.4 +scipy==1.7.2 diff --git a/setup.py b/setup.py index 780cbaed04..75a113f019 100644 --- a/setup.py +++ b/setup.py @@ -91,7 +91,6 @@ def read_requirements(filename): "Natural Language :: English", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", diff --git a/tox.ini b/tox.ini index 8822f897a2..47d23b0b89 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. 
[tox] -envlist = black-format,flake8,pylint,docstyle,sphinx,doc8,twine,py36,py37,py38,py39 +envlist = black-format,flake8,pylint,docstyle,sphinx,doc8,twine,py37,py38,py39,py310 skip_missing_interpreters = False @@ -74,7 +74,7 @@ commands = {env:IGNORE_COVERAGE:} coverage report -i --fail-under=86 deps = .[test] depends = - {py36,py37,py38,py39}: clean + {py37,py38,py39,py310}: clean [testenv:flake8] skipdist = true From 41df39cf3110b665d3c8bf71bdfd62c484ad8ae9 Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Wed, 22 Jun 2022 12:16:05 -0700 Subject: [PATCH 093/526] change: Add override_pipeline_parameter_var decorator to give grace period to update invalid pipeline var args (#3180) Co-authored-by: Dewen Qi --- src/sagemaker/image_uris.py | 3 + src/sagemaker/processing.py | 4 +- src/sagemaker/workflow/utilities.py | 52 ++++++++++++++++- .../sagemaker/image_uris/test_retrieve.py | 57 +++++++++++++++---- 4 files changed, 101 insertions(+), 15 deletions(-) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index c0d616969c..ec1fec2d20 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -24,6 +24,7 @@ from sagemaker.spark import defaults from sagemaker.jumpstart import artifacts from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.utilities import override_pipeline_parameter_var logger = logging.getLogger(__name__) @@ -31,6 +32,8 @@ HUGGING_FACE_FRAMEWORK = "huggingface" +# TODO: we should remove this decorator later +@override_pipeline_parameter_var def retrieve( framework, region, diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 30acbddc55..cebe25dbab 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -1195,7 +1195,9 @@ def __init__( source (str): The source for the output. destination (str): The destination of the output. If a destination is not provided, one will be generated: - "s3:////output/". + "s3:////output/" + (Note: this does not apply when used with + :class:`~sagemaker.workflow.steps.ProcessingStep`). output_name (str): The name of the output. If a name is not provided, one will be generated (eg. "output-1"). s3_upload_mode (str): Valid options are "EndOfJob" or "Continuous". diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index 93bf9343e2..afe1e4eae1 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -13,23 +13,29 @@ """Utilities to support workflow.""" from __future__ import absolute_import +import inspect +import logging +from functools import wraps from pathlib import Path -from typing import List, Sequence, Union, Set +from typing import List, Sequence, Union, Set, TYPE_CHECKING import hashlib from urllib.parse import unquote, urlparse from _hashlib import HASH as Hash +from sagemaker.workflow.parameters import Parameter from sagemaker.workflow.pipeline_context import _StepArguments -from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.entities import ( Entity, RequestType, ) +if TYPE_CHECKING: + from sagemaker.workflow.step_collections import StepCollection + BUF_SIZE = 65536 # 64KiB -def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[RequestType]: +def list_to_request(entities: Sequence[Union[Entity, "StepCollection"]]) -> List[RequestType]: """Get the request structure for list of entities. 
Args: @@ -37,6 +43,8 @@ def list_to_request(entities: Sequence[Union[Entity, StepCollection]]) -> List[R Returns: list: A request structure for a workflow service call. """ + from sagemaker.workflow.step_collections import StepCollection + request_dicts = [] for entity in entities: if isinstance(entity, Entity): @@ -151,3 +159,41 @@ def validate_step_args_input( raise TypeError(error_message) if step_args.caller_name not in expected_caller: raise ValueError(error_message) + + +def override_pipeline_parameter_var(func): + """A decorator to override pipeline Parameters passed into a function + + This is a temporary decorator to override pipeline Parameter objects with their default value + and display warning information to instruct users to update their code. + + This decorator can help to give a grace period for users to update their code when + we make changes to explicitly prevent passing any pipeline variables to a function. + + We should remove this decorator after the grace period. + """ + warning_msg_template = ( + "%s should not be a pipeline variable (%s). " + "The default_value of this Parameter object will be used to override it. " + "Please remove this pipeline variable and use python primitives instead." + ) + + @wraps(func) + def wrapper(*args, **kwargs): + params = inspect.signature(func).parameters + args = list(args) + for i, (arg_name, _) in enumerate(params.items()): + if i >= len(args): + break + if isinstance(args[i], Parameter): + logging.warning(warning_msg_template, arg_name, type(args[i])) + args[i] = args[i].default_value + args = tuple(args) + + for arg_name, value in kwargs.items(): + if isinstance(value, Parameter): + logging.warning(warning_msg_template, arg_name, type(value)) + kwargs[arg_name] = value.default_value + return func(*args, **kwargs) + + return wrapper diff --git a/tests/unit/sagemaker/image_uris/test_retrieve.py b/tests/unit/sagemaker/image_uris/test_retrieve.py index ddf4049448..c167da6f47 100644 --- a/tests/unit/sagemaker/image_uris/test_retrieve.py +++ b/tests/unit/sagemaker/image_uris/test_retrieve.py @@ -19,6 +19,7 @@ from mock import patch from sagemaker import image_uris +from sagemaker.workflow.functions import Join from sagemaker.workflow.parameters import ParameterString BASE_CONFIG = { @@ -721,16 +722,50 @@ def test_retrieve_huggingface(config_for_framework): def test_retrieve_with_pipeline_variable(): + kwargs = dict( + framework="tensorflow", + version="1.15", + py_version="py3", + instance_type="ml.m5.xlarge", + region="us-east-1", + image_scope="training", + ) + # instance_type is plain string which should not break anything + image_uris.retrieve(**kwargs) + + # instance_type is parameter string with not None default value + # which should not break anything + kwargs["instance_type"] = ParameterString( + name="TrainingInstanceType", + default_value="ml.m5.xlarge", + ) + image_uris.retrieve(**kwargs) + + # instance_type is parameter string without default value + # (equivalent to pass in None to instance_type field) + # which should fail due to empty instance type check + kwargs["instance_type"] = ParameterString(name="TrainingInstanceType") with pytest.raises(Exception) as error: - image_uris.retrieve( - framework="tensorflow", - version="1.15", - py_version="py3", - instance_type=ParameterString( - name="TrainingInstanceType", - default_value="ml.m5.xlarge", - ), - region="us-east-1", - image_scope="training", - ) + image_uris.retrieve(**kwargs) + assert "Empty SageMaker instance type" in str(error.value) + + # instance_type is 
other types of pipeline variable + # which should break loudly + kwargs["instance_type"] = Join(on="", values=["a", "b"]) + with pytest.raises(Exception) as error: + image_uris.retrieve(**kwargs) assert "instance_type should not be a pipeline variable" in str(error.value) + + # instance_type (ParameterString) is given as args rather than kwargs + # which should not break anything + image_uris.retrieve( + "tensorflow", + "us-east-1", + "1.15", + "py3", + ParameterString( + name="TrainingInstanceType", + default_value="ml.m5.xlarge", + ), + image_scope="training", + ) From 19d824684be84e3d4364d81ccc9ff6273bda01d0 Mon Sep 17 00:00:00 2001 From: Ben Crabtree Date: Wed, 22 Jun 2022 15:36:33 -0400 Subject: [PATCH 094/526] feat: update prebuilt models documentation (#3186) --- doc/doc_utils/jumpstart_doc_utils.py | 21 ++- .../{jumpstart.rst => pretrainedmodels.rst} | 0 doc/overview.rst | 130 ++++++++---------- 3 files changed, 75 insertions(+), 76 deletions(-) rename doc/doc_utils/{jumpstart.rst => pretrainedmodels.rst} (100%) diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index d2658dca30..94096fbf1d 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -143,20 +143,26 @@ def create_jumpstart_model_table(): file_content.append(".. |external-link| raw:: html\n\n") file_content.append(' \n\n') - file_content.append("==================================\n") - file_content.append("JumpStart Available Model Table\n") - file_content.append("==================================\n") + file_content.append("================================================\n") + file_content.append("Built-in Algorithms with pre-trained Model Table\n") + file_content.append("================================================\n") file_content.append( """ - JumpStart for the SageMaker Python SDK uses model IDs and model versions to access the necessary - utilities. This table serves to provide the core material plus some extra information that can be useful - in selecting the correct model ID and corresponding parameters.\n""" + The SageMaker Python SDK uses model IDs and model versions to access the necessary + utilities for pre-trained models. This table serves to provide the core material plus + some extra information that can be useful in selecting the correct model ID and + corresponding parameters.\n""" ) file_content.append( """ If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute. We highly suggest pinning an exact model version however.\n""" ) + file_content.append( + """ + These models are also available through the + `JumpStart UI in SageMaker Studio `__\n""" + ) file_content.append("\n") file_content.append(".. 
list-table:: Available Models\n") file_content.append(" :widths: 50 20 20 20 30 20\n") @@ -183,5 +189,6 @@ def create_jumpstart_model_table(): " - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"]) ) - f = open("doc_utils/jumpstart.rst", "w") + f = open("doc_utils/pretrainedmodels.rst", "w") f.writelines(file_content) + f.close() diff --git a/doc/doc_utils/jumpstart.rst b/doc/doc_utils/pretrainedmodels.rst similarity index 100% rename from doc/doc_utils/jumpstart.rst rename to doc/doc_utils/pretrainedmodels.rst diff --git a/doc/overview.rst b/doc/overview.rst index 52c942b47b..14b7d47cda 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -573,24 +573,31 @@ Here is an example: # When you are done using your endpoint model.sagemaker_session.delete_endpoint('my-endpoint') -********************************************************* -Use SageMaker JumpStart Algorithms with Pretrained Models -********************************************************* +*********************************************************************** +Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK +*********************************************************************** + +SageMaker Python SDK provides built-in algorithms with pre-trained models from popular open source model +hubs, such as TensorFlow Hub, Pytorch Hub, and HuggingFace. Customer can deploy these pre-trained models +as-is or first fine-tune them on a custom dataset and then deploy to a SageMaker endpoint for inference. + + +SageMaker SDK built-in algorithms allow customers access pre-trained models using model ids and model +versions. The ‘pre-trained model’ table below provides list of models with information useful in +selecting the correct model id and corresponding parameters. These models are also available through +the `JumpStart UI in SageMaker Studio `__. -JumpStart for the SageMaker Python SDK uses model ids and model versions to access the necessary -utilities. This table serves to provide the core material plus some extra information that can be useful -in selecting the correct model id and corresponding parameters. .. toctree:: :maxdepth: 2 - doc_utils/jumpstart + doc_utils/pretrainedmodels Example notebooks ================= -JumpStart supports 15 different machine learning problem types. Below is a list of all the supported -problem types with a link to a Jupyter notebook that provides example usage. +SageMaker built-in algorithms with pre-trained models support 15 different machine learning problem types. +Below is a list of all the supported problem types with a link to a Jupyter notebook that provides example usage. Vision - `Image Classification `__ @@ -610,25 +617,15 @@ Text - `Text Embedding `__ Tabular - - `Tabular Classification (LightGBM & Catboost) `__ - - `Tabular Classification (XGBoost & Linear Learner) `__ - - `Tabular Regression (LightGBM & Catboost) `__ - - `Tabular Regression (XGBoost & Linear Learner) `__ - - -`Amazon SageMaker JumpStart `__ is a -SageMaker feature that helps users bring machine learning (ML) -applications to market using prebuilt solutions for common use cases, -example notebooks, open source models from model zoos, and built-in -algorithms. - -A JumpStart model enables you to quickly start a machine learning -workflow. JumpStart takes models from popular open source model hubs, -such as TensorFlow and HuggingFace, and pre-trains them on an open -source dataset. 
Using the SageMaker Python SDK, you can select a -prebuilt model from the model zoo to train on custom data or deploy -to a SageMaker endpoint for inference without signing up for -SageMaker Studio. + - `Tabular Classification (LightGBM & Catboost) `__ + - `Tabular Classification (XGBoost & Scikit-learn Linear Learner) `__ + - `Tabular Classification (AutoGluon) `__ + - `Tabular Classification (TabTransformer) `__ + - `Tabular Regression (LightGBM & Catboost) `__ + - `Tabular Regression (XGBoost & Scikit-learn Linear Learner) `__ + - `Tabular Regression (AutoGluon) `__ + - `Tabular Regression (TabTransformer) `__ + The following topic give you information about JumpStart components, as well as how to use the SageMaker Python SDK for these workflows. @@ -644,24 +641,22 @@ Prerequisites Amazon S3. For more information about IAM role permissions, see `Policies and permissions in IAM `__. -JumpStart Components -==================== +Built-in Components +=================== -The following sections give information about the main JumpStart +The following sections give information about the main built-in components and their function. -JumpStart models ----------------- +Pre-trained models +------------------ -JumpStart maintains a model zoo of over 300 models pre-trained on -open source datasets. You can use the SageMaker Python SDK -to fine-tune a model on your own dataset or deploy it directly to a -SageMaker endpoint for inference. +SageMaker maintains a model zoo of over 300 models from popular open source model hubs, such as +TensorFlow Hub, Pytorch Hub, and HuggingFace. You can use the SageMaker Python SDK to fine-tune +a model on your own dataset or deploy it directly to a SageMaker endpoint for inference. -JumpStart model artifacts are stored as tarballs in the JumpStart S3 -bucket. Each model is versioned and contains a unique ID which can be -used to retrieve the model URI. The following information describes -the ``model_id`` and ``model_version`` needed to retrieve the URI. +Model artifacts are stored as tarballs in a S3 bucket. Each model is versioned and contains a +unique ID which can be used to retrieve the model URI. The following information describes the +``model_id`` and ``model_version`` needed to retrieve the URI. .. container:: @@ -671,7 +666,7 @@ the ``model_id`` and ``model_version`` needed to retrieve the URI. required parameter. To retrieve a model, first select a ``model ID`` and ``version`` from -the :doc:`available models <./doc_utils/jumpstart>`. +the :doc:`available models <./doc_utils/pretrainedmodels>`. .. code:: python @@ -688,15 +683,13 @@ Then use those values to retrieve the model as follows.     model_id=model_id, model_version=model_version, model_scope=scope ) -JumpStart scripts ------------------ +Model scripts +------------- -To adapt JumpStart models for SageMaker, a custom -script is needed to perform training or inference. JumpStart -maintains a suite of scripts used for each of the models in the -JumpStart S3 bucket, which can be accessed using the SageMaker Python -SDK. Use the ``model_id`` and ``version`` of the corresponding model -to retrieve the related script as follows. +To adapt pre-trained models for SageMaker, a custom script is needed to perform training +or inference. SageMaker maintains a suite of scripts used for each of the models in the +S3 bucket, which can be accessed using the SageMaker Python SDK Use the ``model_id`` and +``version`` of the corresponding model to retrieve the related script as follows. .. 
code:: python @@ -706,11 +699,11 @@ to retrieve the related script as follows.     model_id=model_id, model_version=model_version, script_scope=scope ) -JumpStart images ----------------- +Model images +------------- A Docker image is required to perform training or inference on all -SageMaker models. JumpStart relies on Docker images from the +SageMaker models. SageMaker relies on Docker images from the following repos https://github.com/aws/deep-learning-containers, https://github.com/aws/sagemaker-xgboost-container, and https://github.com/aws/sagemaker-scikit-learn-container. Use @@ -733,16 +726,16 @@ retrieve the related image as follows. Deploy a  Pre-Trained Model Directly to a SageMaker Endpoint ============================================================ -In this section, you learn how to take a pre-trained JumpStart model -and deploy it directly to a SageMaker Endpoint. This is the fastest -way to start machine learning with a JumpStart model. The following +In this section, you learn how to take a pre-trained model and deploy +it directly to a SageMaker Endpoint. This is the fastest way to start +machine learning with a pre-trained model. The following assumes familiarity with `SageMaker models `__ and their deploy functions. -To begin, select a ``model_id`` and ``version`` from the JumpStart +To begin, select a ``model_id`` and ``version`` from the pre-trained models table, as well as a model scope of either “inference” or -“training”. For this example, you use a pre-trained JumpStart model, +“training”. For this example, you use a pre-trained model, so select “inference”  for your model scope. Use the utility functions to retrieve the URI of each of the three components you need to continue. @@ -772,7 +765,7 @@ need to continue. Next, pass the URIs and other key parameters as part of a new SageMaker Model class. The ``entry_point`` is a JumpStart script -named ``inference.py``. JumpStart handles the implementation of this +named ``inference.py``. SageMaker handles the implementation of this script. You must use this value for model inference to be successful. For more information about the Model class and its parameters, see `Model `__. @@ -811,7 +804,7 @@ Deployment may take about 5 minutes. Because the model and script URIs are distributed by SageMaker JumpStart, the endpoint, endpoint config and model resources will be prefixed with ``sagemaker-jumpstart``. Refer to the model ``Tags`` to inspect the -JumpStart artifacts involved in the model creation. +model artifacts involved in the model creation. Perform Inference ----------------- @@ -829,17 +822,16 @@ the Fine-tune a Model and Deploy to a SageMaker Endpoint ==================================================== -In this section, you initiate a training job to further train one of -the pretrained JumpStart models for your use case, then deploy it to -a SageMaker Endpoint for inference. This lets you fine tune the model -for your use case with your custom dataset. The following assumes +In this section, you initiate a training job to further train one of the pre-trained models +for your use case, then deploy it to a SageMaker Endpoint for inference. This lets you fine +tune the model for your use case with your custom dataset. The following assumes familiarity with `SageMaker training jobs and their architecture `__. 
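For orientation, the end-to-end flow described in the rest of this section can be
sketched as follows. This is only an illustrative outline; the model ID, role, and
S3 path below are placeholders to replace with your own values, and each call is
explained step by step afterwards.

.. code:: python

    from sagemaker import image_uris, model_uris, script_uris
    from sagemaker.estimator import Estimator

    model_id, model_version = "tensorflow-ic-imagenet-inception-v3-classification-4", "*"
    training_instance_type = "ml.p3.2xlarge"

    # Retrieve the training image, script, and pre-trained artifact for the model
    train_image_uri = image_uris.retrieve(
        region=None,
        framework=None,
        image_scope="training",
        model_id=model_id,
        model_version=model_version,
        instance_type=training_instance_type,
    )
    train_source_uri = script_uris.retrieve(
        model_id=model_id, model_version=model_version, script_scope="training"
    )
    train_model_uri = model_uris.retrieve(
        model_id=model_id, model_version=model_version, model_scope="training"
    )

    # Fine-tune on a custom dataset; deployment for inference is covered later on
    estimator = Estimator(
        image_uri=train_image_uri,
        source_dir=train_source_uri,
        model_uri=train_model_uri,
        entry_point="transfer_learning.py",
        role="MySageMakerRole",  # placeholder execution role
        instance_count=1,
        instance_type=training_instance_type,
    )
    estimator.fit({"training": "s3://my-bucket/my-training-data"})  # placeholder S3 path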
-Fine-tune a JumpStart Model on a Custom Dataset ------------------------------------------------ +Fine-tune a Pre-trained Model on a Custom Dataset +------------------------------------------------- -To begin, select a ``model_id`` and ``version`` from the JumpStart +To begin, select a ``model_id`` and ``version`` from the pre-trained models table, as well as a model scope. In this case, you begin by using “training” as the model scope. Use the utility functions to retrieve the URI of each of the three components you need to @@ -875,10 +867,10 @@ Table `__ and selec     instance_type=training_instance_type, ) -Next, use the JumpStart resource URIs to create an ``Estimator`` and +Next, use the model resource URIs to create an ``Estimator`` and train it on a custom training dataset. You must specify the S3 path of your custom training dataset. The Estimator class requires -an ``entry_point`` parameter. In this case, JumpStart uses +an ``entry_point`` parameter. In this case, SageMaker uses “transfer_learning.py”. The training job fails to execute if this value is not set. From 8f0873bd456a633d5496664e77eb361f2616e3aa Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Thu, 23 Jun 2022 11:02:10 -0700 Subject: [PATCH 095/526] fix: update pytest, skip hf integ temp (#3193) * fix: update pytest * change: temporarily skip stanford test --- requirements/extras/test_requirements.txt | 2 +- tests/integ/test_huggingface.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index b002686404..2247394441 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -1,6 +1,6 @@ tox==3.24.5 flake8==4.0.1 -pytest==6.0.2 +pytest==6.2.5 pytest-cov==3.0.0 pytest-rerunfailures==10.2 pytest-timeout==2.1.0 diff --git a/tests/integ/test_huggingface.py b/tests/integ/test_huggingface.py index 5e9151fe86..3478ecea46 100644 --- a/tests/integ/test_huggingface.py +++ b/tests/integ/test_huggingface.py @@ -116,6 +116,9 @@ def test_huggingface_training( and integ.test_region() in integ.TRAINING_NO_P3_REGIONS, reason="no ml.p2 or ml.p3 instances in this region", ) +@pytest.mark.skip( + reason="need to re enable it later t.corp:V609860141", +) def test_huggingface_training_tf( sagemaker_session, gpu_instance_type, From 03ebd45dd30916838a0ddf5a80f1d517e675d2da Mon Sep 17 00:00:00 2001 From: Navin Soni Date: Mon, 27 Jun 2022 16:21:24 -0700 Subject: [PATCH 096/526] change: Update model name from 'compiled.pt' to 'model.pth' for neo (#3198) Co-authored-by: Navin Soni --- tests/data/pytorch_neo/code/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/data/pytorch_neo/code/inference.py b/tests/data/pytorch_neo/code/inference.py index 984005039f..5b89c2bebc 100644 --- a/tests/data/pytorch_neo/code/inference.py +++ b/tests/data/pytorch_neo/code/inference.py @@ -71,8 +71,8 @@ def model_fn(model_dir): logger.info("model_fn") neopytorch.config(model_dir=model_dir, neo_runtime=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # The compiled model is saved as "compiled.pt" - model = torch.jit.load(os.path.join(model_dir, "compiled.pt"), map_location=device) + # The compiled model is saved as "model.pth" + model = torch.jit.load(os.path.join(model_dir, "model.pth"), map_location=device) # It is recommended to run warm-up inference during model load sample_input_path = 
os.path.join(model_dir, "sample_input.pkl") From 0eac3337d4e6c518342f1935e9eb96666156a683 Mon Sep 17 00:00:00 2001 From: Navin Soni Date: Mon, 27 Jun 2022 22:27:28 -0700 Subject: [PATCH 097/526] change: Skipping test_candidate_estimator_default_rerun_and_deploy (#3199) Co-authored-by: Navin Soni --- tests/integ/test_auto_ml.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/integ/test_auto_ml.py b/tests/integ/test_auto_ml.py index 132b392160..3b4f46c9ba 100644 --- a/tests/integ/test_auto_ml.py +++ b/tests/integ/test_auto_ml.py @@ -293,6 +293,9 @@ def test_deploy_best_candidate(sagemaker_session, cpu_instance_type): tests.integ.test_region() in tests.integ.NO_AUTO_ML_REGIONS, reason="AutoML is not supported in the region yet.", ) +@pytest.mark.skip( + reason="", +) def test_candidate_estimator_default_rerun_and_deploy(sagemaker_session, cpu_instance_type): auto_ml_utils.create_auto_ml_job_if_not_exist(sagemaker_session) From 5e650aaedeb00649e7b0675c008546d9fd1e07d8 Mon Sep 17 00:00:00 2001 From: ci Date: Tue, 28 Jun 2022 20:59:58 +0000 Subject: [PATCH 098/526] prepare release v2.97.0 --- CHANGELOG.md | 17 +++++++++++++++++ VERSION | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c4ae95cf6..e825ecd918 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ # Changelog +## v2.97.0 (2022-06-28) + +### Deprecations and Removals + + * remove support for python 3.6 + +### Features + + * update prebuilt models documentation + +### Bug Fixes and Other Changes + + * Skipping test_candidate_estimator_default_rerun_and_deploy + * Update model name from 'compiled.pt' to 'model.pth' for neo + * update pytest, skip hf integ temp + * Add override_pipeline_parameter_var decorator to give grace period to update invalid pipeline var args + ## v2.96.0 (2022-06-20) ### Features diff --git a/VERSION b/VERSION index 0304e09ecb..e362c0412b 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.96.1.dev0 +2.97.0 From a359ad2eb4add2c3493993235319785823ca37e5 Mon Sep 17 00:00:00 2001 From: ci Date: Tue, 28 Jun 2022 20:59:59 +0000 Subject: [PATCH 099/526] update development version to v2.97.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index e362c0412b..3f042a492c 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.97.0 +2.97.1.dev0 From 5c604bf47e3a4b20c76b8b36caefaad63a62ad25 Mon Sep 17 00:00:00 2001 From: Basil Beirouti Date: Tue, 28 Jun 2022 15:17:39 -0700 Subject: [PATCH 100/526] documentation: edit to clarify how to use inference.py (#3194) Co-authored-by: Basil Beirouti --- doc/frameworks/tensorflow/using_tf.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/frameworks/tensorflow/using_tf.rst b/doc/frameworks/tensorflow/using_tf.rst index bd6cd36dcf..1e51b5f43a 100644 --- a/doc/frameworks/tensorflow/using_tf.rst +++ b/doc/frameworks/tensorflow/using_tf.rst @@ -759,7 +759,7 @@ Create Python Scripts for Custom Input and Output Formats --------------------------------------------------------- You can add your customized Python code to process your input and output data. -This customized Python code must be named ``inference.py`` and specified through the ``entry_point`` parameter: +This customized Python code must be named ``inference.py`` and is specified through the ``entry_point`` parameter: .. 
code:: @@ -769,6 +769,8 @@ This customized Python code must be named ``inference.py`` and specified through model_data='s3://mybucket/model.tar.gz', role='MySageMakerRole') +In the example above, ``inference.py`` is assumed to be a file inside ``model.tar.gz``. If you want to use a local file instead, you must add the ``source_dir`` argument. See the documentation on `TensorFlowModel `_. + How to implement the pre- and/or post-processing handler(s) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From 134b3fd9021e996735eb83d47ebcacc1fce62b64 Mon Sep 17 00:00:00 2001 From: Roald Bradley Severtson Date: Thu, 30 Jun 2022 11:50:07 -0700 Subject: [PATCH 101/526] feature: Adding deepar image (#3203) --- src/sagemaker/image_uri_config/forecasting-deepar.json | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sagemaker/image_uri_config/forecasting-deepar.json b/src/sagemaker/image_uri_config/forecasting-deepar.json index 4524887dee..b63cb1a99f 100644 --- a/src/sagemaker/image_uri_config/forecasting-deepar.json +++ b/src/sagemaker/image_uri_config/forecasting-deepar.json @@ -7,6 +7,7 @@ "ap-east-1": "286214385809", "ap-northeast-1": "633353088612", "ap-northeast-2": "204372634319", + "ap-northeast-3": "867004704886", "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "514117268639", From 88d4b20a8edf7da8196b6ffc52fd2ea26f0ee55b Mon Sep 17 00:00:00 2001 From: ci Date: Tue, 5 Jul 2022 18:31:46 +0000 Subject: [PATCH 102/526] prepare release v2.98.0 --- CHANGELOG.md | 10 ++++++++++ VERSION | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e825ecd918..d2e8a82d85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## v2.98.0 (2022-07-05) + +### Features + + * Adding deepar image + +### Documentation Changes + + * edit to clarify how to use inference.py + ## v2.97.0 (2022-06-28) ### Deprecations and Removals diff --git a/VERSION b/VERSION index 3f042a492c..7cdfdfa411 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.97.1.dev0 +2.98.0 From 4d6df62d37d2c1d38d3e95fc91936eb4eab94b4f Mon Sep 17 00:00:00 2001 From: ci Date: Tue, 5 Jul 2022 18:31:47 +0000 Subject: [PATCH 103/526] update development version to v2.98.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 7cdfdfa411..06d7460288 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.98.0 +2.98.1.dev0 From 7406766ec852be39c116b42dfad6f73d39da0fca Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Wed, 6 Jul 2022 12:57:37 -0400 Subject: [PATCH 104/526] fix: model table link (#3211) --- src/sagemaker/jumpstart/cache.py | 2 +- tests/unit/sagemaker/jumpstart/test_cache.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index a53eefc9e6..ac1ed5a17f 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -215,7 +215,7 @@ def _get_manifest_key_from_model_id_semantic_version( error_msg = f"Unable to find model manifest for '{model_id}' with version '{version}'. " error_msg += ( - "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html" + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html" " for updated list of models. 
" ) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index 93e8114185..f87820114d 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -163,7 +163,7 @@ def test_jumpstart_cache_get_header(): ) assert ( "Unable to find model manifest for 'pytorch-ic-imagenet-inception-v3-classification-4' with " - "version '3.*'. Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html " + "version '3.*'. Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " "for updated list of models. Consider using model ID 'pytorch-ic-imagenet-inception-v3-" "classification-4' with version '2.0.0'." ) in str(e.value) @@ -172,7 +172,7 @@ def test_jumpstart_cache_get_header(): cache.get_header(model_id="pytorch-ic-", semantic_version_str="*") assert ( "Unable to find model manifest for 'pytorch-ic-' with version '*'. " - "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html " + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " "for updated list of models. " "Did you mean to use model ID 'pytorch-ic-imagenet-inception-v3-classification-4'?" ) in str(e.value) @@ -181,7 +181,7 @@ def test_jumpstart_cache_get_header(): cache.get_header(model_id="tensorflow-ic-", semantic_version_str="*") assert ( "Unable to find model manifest for 'tensorflow-ic-' with version '*'. " - "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html " + "Visit https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html " "for updated list of models. " "Did you mean to use model ID 'tensorflow-ic-imagenet-inception-" "v3-classification-4'?" From ae7b3728658798d34fd981d6fd64c43670d0a1b3 Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Wed, 6 Jul 2022 10:25:03 -0700 Subject: [PATCH 105/526] change: Add PipelineVariable annotation in estimatory, processing, tuner, transformer base classes (#3182) Co-authored-by: Dewen Qi --- src/sagemaker/estimator.py | 203 ++++++++++-------- src/sagemaker/network.py | 12 +- src/sagemaker/parameter.py | 9 +- src/sagemaker/processing.py | 188 ++++++++-------- src/sagemaker/transformer.py | 59 ++--- src/sagemaker/tuner.py | 50 +++-- .../sagemaker/workflow/test_training_step.py | 11 +- 7 files changed, 291 insertions(+), 241 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 208dc208b4..f31cfd938d 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -18,7 +18,7 @@ import os import uuid from abc import ABCMeta, abstractmethod -from typing import Any, Dict +from typing import Any, Dict, Union, Optional, List from six import string_types, with_metaclass from six.moves.urllib.parse import urlparse @@ -36,6 +36,7 @@ TensorBoardOutputConfig, get_default_profiler_rule, get_rule_container_image_uri, + RuleBase, ) from sagemaker.deprecations import removed_function, removed_kwargs, renamed_kwargs from sagemaker.fw_utils import ( @@ -46,7 +47,7 @@ tar_and_upload_dir, validate_source_dir, ) -from sagemaker.inputs import TrainingInput +from sagemaker.inputs import TrainingInput, FileSystemInput from sagemaker.job import _Job from sagemaker.jumpstart.utils import ( add_jumpstart_tags, @@ -75,6 +76,7 @@ name_from_base, ) from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import ( 
PipelineSession, runnable_by_pipeline, @@ -105,44 +107,44 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man def __init__( self, - role, - instance_count=None, - instance_type=None, - volume_size=30, - volume_kms_key=None, - max_run=24 * 60 * 60, - input_mode="File", - output_path=None, - output_kms_key=None, - base_job_name=None, - sagemaker_session=None, - tags=None, - subnets=None, - security_group_ids=None, - model_uri=None, - model_channel_name="model", - metric_definitions=None, - encrypt_inter_container_traffic=False, - use_spot_instances=False, - max_wait=None, - checkpoint_s3_uri=None, - checkpoint_local_path=None, - rules=None, - debugger_hook_config=None, - tensorboard_output_config=None, - enable_sagemaker_metrics=None, - enable_network_isolation=False, - profiler_config=None, - disable_profiler=False, - environment=None, - max_retry_attempts=None, - source_dir=None, - git_config=None, - hyperparameters=None, - container_log_level=logging.INFO, - code_location=None, - entry_point=None, - dependencies=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + volume_size: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_run: Union[int, PipelineVariable] = 24 * 60 * 60, + input_mode: Union[str, PipelineVariable] = "File", + output_path: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + subnets: Optional[List[Union[str, PipelineVariable]]] = None, + security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, + model_uri: Optional[str] = None, + model_channel_name: Union[str, PipelineVariable] = "model", + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False, + use_spot_instances: Union[bool, PipelineVariable] = False, + max_wait: Optional[Union[int, PipelineVariable]] = None, + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, + rules: Optional[List[RuleBase]] = None, + debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None, + tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + profiler_config: Optional[ProfilerConfig] = None, + disable_profiler: bool = False, + environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, + source_dir: Optional[str] = None, + git_config: Optional[Dict[str, str]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + code_location: Optional[str] = None, + entry_point: Optional[str] = None, + dependencies: Optional[List[Union[str]]] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. 
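# Illustrative sketch (not part of this patch): with the PipelineVariable annotations
# above, estimator arguments such as instance_count and instance_type can be supplied
# as pipeline parameters when the estimator is used with a PipelineSession. The image
# URI, role, and S3 paths below are placeholders.
from sagemaker.estimator import Estimator
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.steps import TrainingStep

pipeline_session = PipelineSession()
train_instance_count = ParameterInteger(name="TrainingInstanceCount", default_value=1)
train_instance_type = ParameterString(name="TrainingInstanceType", default_value="ml.m5.xlarge")

estimator = Estimator(
    image_uri="<training-image-uri>",      # placeholder
    role="<execution-role-arn>",           # placeholder
    instance_count=train_instance_count,   # pipeline variable
    instance_type=train_instance_type,     # pipeline variable
    output_path="s3://<bucket>/output",    # placeholder
    sagemaker_session=pipeline_session,
)
# Inside a PipelineSession, fit() returns step arguments instead of starting a job
step = TrainingStep(
    name="MyTrainingStep",
    step_args=estimator.fit(inputs="s3://<bucket>/train"),  # placeholder input
)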
@@ -922,7 +924,14 @@ def latest_job_profiler_artifacts_path(self): return None @runnable_by_pipeline - def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None): + def fit( + self, + inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, + wait: bool = True, + logs: str = "All", + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + ): """Train a model using the input training dataset. The API calls the Amazon SageMaker CreateTrainingJob API to start @@ -1870,16 +1879,22 @@ def _get_train_args(cls, estimator, inputs, experiment_config): ) train_args["input_mode"] = inputs.config["InputMode"] + # enable_network_isolation may be a pipeline variable place holder object + # which is parsed in execution time if estimator.enable_network_isolation(): - train_args["enable_network_isolation"] = True + train_args["enable_network_isolation"] = estimator.enable_network_isolation() if estimator.max_retry_attempts is not None: train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts} else: train_args["retry_strategy"] = None + # encrypt_inter_container_traffic may be a pipeline variable place holder object + # which is parsed in execution time if estimator.encrypt_inter_container_traffic: - train_args["encrypt_inter_container_traffic"] = True + train_args[ + "encrypt_inter_container_traffic" + ] = estimator.encrypt_inter_container_traffic if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator): train_args["algorithm_arn"] = estimator.algorithm_arn @@ -2025,45 +2040,45 @@ class Estimator(EstimatorBase): def __init__( self, - image_uri, - role, - instance_count=None, - instance_type=None, - volume_size=30, - volume_kms_key=None, - max_run=24 * 60 * 60, - input_mode="File", - output_path=None, - output_kms_key=None, - base_job_name=None, - sagemaker_session=None, - hyperparameters=None, - tags=None, - subnets=None, - security_group_ids=None, - model_uri=None, - model_channel_name="model", - metric_definitions=None, - encrypt_inter_container_traffic=False, - use_spot_instances=False, - max_wait=None, - checkpoint_s3_uri=None, - checkpoint_local_path=None, - enable_network_isolation=False, - rules=None, - debugger_hook_config=None, - tensorboard_output_config=None, - enable_sagemaker_metrics=None, - profiler_config=None, - disable_profiler=False, - environment=None, - max_retry_attempts=None, - source_dir=None, - git_config=None, - container_log_level=logging.INFO, - code_location=None, - entry_point=None, - dependencies=None, + image_uri: Union[str, PipelineVariable], + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + volume_size: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_run: Union[int, PipelineVariable] = 24 * 60 * 60, + input_mode: Union[str, PipelineVariable] = "File", + output_path: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + subnets: Optional[List[Union[str, PipelineVariable]]] = None, + security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, + model_uri: Optional[str] = None, + model_channel_name: Union[str, 
PipelineVariable] = "model", + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False, + use_spot_instances: Union[bool, PipelineVariable] = False, + max_wait: Optional[Union[int, PipelineVariable]] = None, + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + rules: Optional[List[RuleBase]] = None, + debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None, + tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, + profiler_config: Optional[ProfilerConfig] = None, + disable_profiler: bool = False, + environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, + source_dir: Optional[str] = None, + git_config: Optional[Dict[str, str]] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + code_location: Optional[str] = None, + entry_point: Optional[str] = None, + dependencies: Optional[List[str]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -2488,18 +2503,18 @@ class Framework(EstimatorBase): def __init__( self, - entry_point, - source_dir=None, - hyperparameters=None, - container_log_level=logging.INFO, - code_location=None, - image_uri=None, - dependencies=None, - enable_network_isolation=False, - git_config=None, - checkpoint_s3_uri=None, - checkpoint_local_path=None, - enable_sagemaker_metrics=None, + entry_point: str, + source_dir: Optional[str] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + code_location: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + dependencies: Optional[List[str]] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + git_config: Optional[Dict[str, str]] = None, + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, **kwargs, ): """Base class initializer. diff --git a/src/sagemaker/network.py b/src/sagemaker/network.py index 1d2ae8c6ca..b3bf72a95a 100644 --- a/src/sagemaker/network.py +++ b/src/sagemaker/network.py @@ -16,6 +16,10 @@ """ from __future__ import absolute_import +from typing import Union, Optional, List + +from sagemaker.workflow.entities import PipelineVariable + class NetworkConfig(object): """Accepts network configuration parameters for conversion to request dict. @@ -25,10 +29,10 @@ class NetworkConfig(object): def __init__( self, - enable_network_isolation=False, - security_group_ids=None, - subnets=None, - encrypt_inter_container_traffic=None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, + subnets: Optional[List[Union[str, PipelineVariable]]] = None, + encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None, ): """Initialize a ``NetworkConfig`` instance. 
diff --git a/src/sagemaker/parameter.py b/src/sagemaker/parameter.py index 52efdeb7c6..79bbc62da2 100644 --- a/src/sagemaker/parameter.py +++ b/src/sagemaker/parameter.py @@ -14,8 +14,10 @@ from __future__ import absolute_import import json +from typing import Union from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable class ParameterRange(object): @@ -27,7 +29,12 @@ class ParameterRange(object): __all_types__ = ("Continuous", "Categorical", "Integer") - def __init__(self, min_value, max_value, scaling_type="Auto"): + def __init__( + self, + min_value: Union[int, float, PipelineVariable], + max_value: Union[int, float, PipelineVariable], + scaling_type: Union[str, PipelineVariable] = "Auto", + ): """Initialize a parameter range. Args: diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index cebe25dbab..1e4cfae4ff 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -22,7 +22,7 @@ import pathlib import logging from textwrap import dedent -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import attr @@ -31,10 +31,12 @@ from sagemaker import s3 from sagemaker.job import _Job from sagemaker.local import LocalSession +from sagemaker.network import NetworkConfig from sagemaker.utils import base_name_from_image, get_config_value, name_from_base from sagemaker.session import Session from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.pipeline_context import runnable_by_pipeline +from sagemaker.workflow.entities import PipelineVariable from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition from sagemaker.apiutils._base_types import ApiObject from sagemaker.s3 import S3Uploader @@ -47,20 +49,20 @@ class Processor(object): def __init__( self, - role, - image_uri, - instance_count, - instance_type, - entrypoint=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + role: str, + image_uri: Union[str, PipelineVariable], + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + entrypoint: Optional[List[Union[str, PipelineVariable]]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """Initializes a ``Processor`` instance. @@ -133,14 +135,14 @@ def __init__( @runnable_by_pipeline def run( self, - inputs=None, - outputs=None, - arguments=None, - wait=True, - logs=True, - job_name=None, - experiment_config=None, - kms_key=None, + inputs: Optional[List["ProcessingInput"]] = None, + outputs: Optional[List["ProcessingOutput"]] = None, + arguments: Optional[List[Union[str, PipelineVariable]]] = None, + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + kms_key: Optional[str] = None, ): """Runs a processing job. 
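# Illustrative sketch (not part of this patch), tied to the ParameterRange annotation
# change earlier in this commit: a tuning range bound may now be a pipeline variable.
# The names and default values below are arbitrary examples.
from sagemaker.tuner import ContinuousParameter
from sagemaker.workflow.parameters import ParameterFloat

max_learning_rate = ParameterFloat(name="MaxLearningRate", default_value=0.1)
learning_rate_range = ContinuousParameter(
    min_value=1e-5,
    max_value=max_learning_rate,  # pipeline variable as the upper bound
    scaling_type="Logarithmic",
)
# learning_rate_range would then be passed to a HyperparameterTuner via
# hyperparameter_ranges={"learning_rate": learning_rate_range}.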
@@ -388,20 +390,20 @@ class ScriptProcessor(Processor): def __init__( self, - role, - image_uri, - command, - instance_count, - instance_type, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + role: str, + image_uri: Union[str, PipelineVariable], + command: List[str], + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """Initializes a ``ScriptProcessor`` instance. @@ -498,15 +500,15 @@ def get_run_args( @runnable_by_pipeline def run( self, - code, - inputs=None, - outputs=None, - arguments=None, - wait=True, - logs=True, - job_name=None, - experiment_config=None, - kms_key=None, + code: str, + inputs: Optional[List["ProcessingInput"]] = None, + outputs: Optional[List["ProcessingOutput"]] = None, + arguments: Optional[List[Union[str, PipelineVariable]]] = None, + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + kms_key: Optional[str] = None, ): """Runs a processing job. @@ -537,6 +539,8 @@ def run( * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + kms_key (str): The ARN of the KMS key that is used to encrypt the + user code file (default: None). """ normalized_inputs, normalized_outputs = self._normalize_args( job_name=job_name, @@ -1072,16 +1076,16 @@ class ProcessingInput(object): def __init__( self, - source=None, - destination=None, - input_name=None, - s3_data_type="S3Prefix", - s3_input_mode="File", - s3_data_distribution_type="FullyReplicated", - s3_compression_type="None", - s3_input=None, - dataset_definition=None, - app_managed=False, + source: Optional[Union[str, PipelineVariable]] = None, + destination: Optional[Union[str, PipelineVariable]] = None, + input_name: Optional[Union[str, PipelineVariable]] = None, + s3_data_type: Union[str, PipelineVariable] = "S3Prefix", + s3_input_mode: Union[str, PipelineVariable] = "File", + s3_data_distribution_type: Union[str, PipelineVariable] = "FullyReplicated", + s3_compression_type: Union[str, PipelineVariable] = "None", + s3_input: Optional[S3Input] = None, + dataset_definition: Optional[DatasetDefinition] = None, + app_managed: Union[bool, PipelineVariable] = False, ): """Initializes a ``ProcessingInput`` instance. 
@@ -1179,12 +1183,12 @@ class ProcessingOutput(object): def __init__( self, - source=None, - destination=None, - output_name=None, - s3_upload_mode="EndOfJob", - app_managed=False, - feature_store_output=None, + source: Optional[Union[str, PipelineVariable]] = None, + destination: Optional[Union[str, PipelineVariable]] = None, + output_name: Optional[Union[str, PipelineVariable]] = None, + s3_upload_mode: Union[str, PipelineVariable] = "EndOfJob", + app_managed: Union[bool, PipelineVariable] = False, + feature_store_output: Optional["FeatureStoreOutput"] = None, ): """Initializes a ``ProcessingOutput`` instance. @@ -1277,24 +1281,24 @@ class FrameworkProcessor(ScriptProcessor): # Added new (kw)args for estimator. The rest are from ScriptProcessor with same defaults. def __init__( self, - estimator_cls, - framework_version, - role, - instance_count, - instance_type, - py_version="py3", - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + estimator_cls: type, + framework_version: str, + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + py_version: str = "py3", + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + code_location: Optional[str] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """Initializes a ``FrameworkProcessor`` instance. @@ -1486,18 +1490,18 @@ def get_run_args( def run( # type: ignore[override] self, - code, - source_dir=None, - dependencies=None, - git_config=None, - inputs=None, - outputs=None, - arguments=None, - wait=True, - logs=True, - job_name=None, - experiment_config=None, - kms_key=None, + code: str, + source_dir: Optional[str] = None, + dependencies: Optional[List[str]] = None, + git_config: Optional[Dict[str, str]] = None, + inputs: Optional[List[ProcessingInput]] = None, + outputs: Optional[List[ProcessingOutput]] = None, + arguments: Optional[List[Union[str, PipelineVariable]]] = None, + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + kms_key: Optional[str] = None, ): """Runs a processing job. 
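For context on the FrameworkProcessor signatures annotated above, a sketch of a typical invocation; the scikit-learn container version, role, and file names are assumptions rather than values taken from this change.

from sagemaker.processing import FrameworkProcessor
from sagemaker.sklearn.estimator import SKLearn

framework_processor = FrameworkProcessor(
    estimator_cls=SKLearn,
    framework_version="0.23-1",  # assumed available scikit-learn container version
    role="arn:aws:iam::111122223333:role/MySageMakerRole",  # hypothetical role ARN
    instance_count=1,
    instance_type="ml.m5.xlarge",
)
framework_processor.run(
    code="evaluate.py",  # hypothetical entry point inside source_dir
    source_dir="src",  # hypothetical directory, packaged and uploaded with its dependencies
)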
diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 36fb86a90b..7bd2f09063 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -13,10 +13,13 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional, List, Dict + from botocore import exceptions from sagemaker.job import _Job from sagemaker.session import Session +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline from sagemaker.workflow import is_pipeline_variable from sagemaker.utils import base_name_from_image, name_from_base @@ -27,21 +30,21 @@ class Transformer(object): def __init__( self, - model_name, - instance_count, - instance_type, - strategy=None, - assemble_with=None, - output_path=None, - output_kms_key=None, - accept=None, - max_concurrent_transforms=None, - max_payload=None, - tags=None, - env=None, - base_transform_job_name=None, - sagemaker_session=None, - volume_kms_key=None, + model_name: Union[str, PipelineVariable], + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + strategy: Optional[Union[str, PipelineVariable]] = None, + assemble_with: Optional[Union[str, PipelineVariable]] = None, + output_path: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + accept: Optional[Union[str, PipelineVariable]] = None, + max_concurrent_transforms: Optional[Union[int, PipelineVariable]] = None, + max_payload: Optional[Union[int, PipelineVariable]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + base_transform_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``Transformer``. @@ -111,19 +114,19 @@ def __init__( @runnable_by_pipeline def transform( self, - data, - data_type="S3Prefix", - content_type=None, - compression_type=None, - split_type=None, - job_name=None, - input_filter=None, - output_filter=None, - join_source=None, - experiment_config=None, - model_client_config=None, - wait=True, - logs=True, + data: Union[str, PipelineVariable], + data_type: Union[str, PipelineVariable] = "S3Prefix", + content_type: Optional[Union[str, PipelineVariable]] = None, + compression_type: Optional[Union[str, PipelineVariable]] = None, + split_type: Optional[Union[str, PipelineVariable]] = None, + job_name: Optional[str] = None, + input_filter: Optional[Union[str, PipelineVariable]] = None, + output_filter: Optional[Union[str, PipelineVariable]] = None, + join_source: Optional[Union[str, PipelineVariable]] = None, + experiment_config: Optional[Dict[str, str]] = None, + model_client_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + wait: bool = True, + logs: bool = True, ): """Start a new transform job. 
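Because transform() is decorated with runnable_by_pipeline, the PipelineVariable annotations above let a ParameterString stand in for the model name until execution time. A hedged sketch of that flow; the bucket names are hypothetical.

from sagemaker.transformer import Transformer
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline_context import PipelineSession

pipeline_session = PipelineSession()
model_name = ParameterString(name="ModelName")  # resolved when the pipeline executes

transformer = Transformer(
    model_name=model_name,
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path="s3://my-bucket/transform-output",  # hypothetical bucket
    sagemaker_session=pipeline_session,
)
# Under a PipelineSession this captures the request for a TransformStep instead of
# starting a transform job immediately.
step_args = transformer.transform(data="s3://my-bucket/batch-input", content_type="text/csv")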
diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index f6229172c8..76337b8b4f 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -19,6 +19,7 @@ import logging from enum import Enum +from typing import Union, Dict, Optional, List, Set import sagemaker from sagemaker.amazon.amazon_estimator import ( @@ -29,8 +30,8 @@ from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.analytics import HyperparameterTuningJobAnalytics from sagemaker.deprecations import removed_function -from sagemaker.estimator import Framework -from sagemaker.inputs import TrainingInput +from sagemaker.estimator import Framework, EstimatorBase +from sagemaker.inputs import TrainingInput, FileSystemInput from sagemaker.job import _Job from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model from sagemaker.parameter import ( @@ -39,6 +40,7 @@ IntegerParameter, ParameterRange, ) +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline from sagemaker.session import Session @@ -95,7 +97,11 @@ class WarmStartConfig(object): {"p1","p2"} """ - def __init__(self, warm_start_type, parents): + def __init__( + self, + warm_start_type: WarmStartTypes, + parents: Set[Union[str, PipelineVariable]], + ): """Creates a ``WarmStartConfig`` with provided ``WarmStartTypes`` and parents. Args: @@ -208,19 +214,19 @@ class HyperparameterTuner(object): def __init__( self, - estimator, - objective_metric_name, - hyperparameter_ranges, - metric_definitions=None, - strategy="Bayesian", - objective_type="Maximize", - max_jobs=1, - max_parallel_jobs=1, - tags=None, - base_tuning_job_name=None, - warm_start_config=None, - early_stopping_type="Off", - estimator_name=None, + estimator: EstimatorBase, + objective_metric_name: Union[str, PipelineVariable], + hyperparameter_ranges: Dict[str, ParameterRange], + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + strategy: Union[str, PipelineVariable] = "Bayesian", + objective_type: Union[str, PipelineVariable] = "Maximize", + max_jobs: Union[int, PipelineVariable] = 1, + max_parallel_jobs: Union[int, PipelineVariable] = 1, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + base_tuning_job_name: Optional[str] = None, + warm_start_config: Optional[WarmStartConfig] = None, + early_stopping_type: Union[str, PipelineVariable] = "Off", + estimator_name: Optional[str] = None, ): """Creates a ``HyperparameterTuner`` instance. @@ -427,11 +433,13 @@ def _prepare_static_hyperparameters( @runnable_by_pipeline def fit( self, - inputs=None, - job_name=None, - include_cls_metadata=False, - estimator_kwargs=None, - wait=True, + inputs: Optional[ + Union[str, Dict, List, TrainingInput, FileSystemInput, RecordSet, FileSystemRecordSet] + ] = None, + job_name: Optional[str] = None, + include_cls_metadata: Union[bool, Dict[str, bool]] = False, + estimator_kwargs: Optional[Dict[str, dict]] = None, + wait: bool = True, **kwargs ): """Start a hyperparameter tuning job. 
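To illustrate the HyperparameterTuner signature annotated above, a minimal sketch; the training image, role, metric name, and buckets are assumed values, not part of this patch.

from sagemaker.estimator import Estimator
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner, IntegerParameter

estimator = Estimator(
    image_uri="111122223333.dkr.ecr.us-east-1.amazonaws.com/xgboost:latest",  # hypothetical image
    role="arn:aws:iam::111122223333:role/MySageMakerRole",  # hypothetical role ARN
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path="s3://my-bucket/output",
)
estimator.set_hyperparameters(objective="binary:logistic", num_round=100)

tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name="validation:auc",  # assumed metric emitted by the algorithm
    hyperparameter_ranges={
        "eta": ContinuousParameter(0.01, 0.3),
        "max_depth": IntegerParameter(3, 10),
    },
    max_jobs=10,
    max_parallel_jobs=2,
)
tuner.fit({"train": "s3://my-bucket/train", "validation": "s3://my-bucket/validation"})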
diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index 0c6a6e34df..397e65f867 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -24,7 +24,7 @@ from sagemaker.transformer import Transformer from sagemaker.tuner import HyperparameterTuner from sagemaker.workflow.pipeline_context import PipelineSession -from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.parameters import ParameterString, ParameterBoolean from sagemaker.workflow.steps import TrainingStep from sagemaker.workflow.pipeline import Pipeline, PipelineGraph @@ -203,6 +203,8 @@ def hyperparameters(): def test_training_step_with_estimator(pipeline_session, training_input, hyperparameters): custom_step1 = CustomStep("TestStep") custom_step2 = CustomStep("SecondTestStep") + enable_network_isolation = ParameterBoolean(name="enable_network_isolation") + encrypt_container_traffic = ParameterBoolean(name="encrypt_container_traffic") estimator = Estimator( role=ROLE, instance_count=1, @@ -210,6 +212,8 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar sagemaker_session=pipeline_session, image_uri=IMAGE_URI, hyperparameters=hyperparameters, + enable_network_isolation=enable_network_isolation, + encrypt_inter_container_traffic=encrypt_container_traffic, ) with warnings.catch_warnings(record=True) as w: @@ -231,8 +235,13 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar pipeline = Pipeline( name="MyPipeline", steps=[step, custom_step1, custom_step2], + parameters=[enable_network_isolation, encrypt_container_traffic], sagemaker_session=pipeline_session, ) + step_args.args["EnableInterContainerTrafficEncryption"] = { + "Get": "Parameters.encrypt_container_traffic" + } + step_args.args["EnableNetworkIsolation"] = {"Get": "Parameters.encrypt_container_traffic"} assert json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyTrainingStep", "Description": "TrainingStep description", From 9c1952460b070e7f7eb5c99255a64d6bbc16689b Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh <105655261+rahven14@users.noreply.github.com> Date: Thu, 7 Jul 2022 22:36:10 +0530 Subject: [PATCH 106/526] feature: include fields to work with inference recommender (#3174) --- src/sagemaker/estimator.py | 24 ++++++++ src/sagemaker/huggingface/model.py | 24 ++++++++ src/sagemaker/model.py | 39 ++++++++++++- src/sagemaker/mxnet/model.py | 24 ++++++++ src/sagemaker/pipeline.py | 32 ++++++++++- src/sagemaker/pytorch/model.py | 24 ++++++++ src/sagemaker/session.py | 30 ++++++++++ src/sagemaker/sklearn/model.py | 24 ++++++++ src/sagemaker/tensorflow/model.py | 24 ++++++++ src/sagemaker/utils.py | 56 +++++++++++++++++++ src/sagemaker/workflow/_utils.py | 11 ++++ src/sagemaker/workflow/step_collections.py | 29 ++++++++++ .../test_model_create_and_registration.py | 54 ++++++++++++++++++ .../workflow/test_pipeline_session.py | 21 ++++++- .../workflow/test_step_collections.py | 44 +++++++++++++++ tests/unit/test_estimator.py | 42 ++++++++++++++ tests/unit/test_session.py | 12 ++++ 17 files changed, 510 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index f31cfd938d..c867f0b199 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1310,6 +1310,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + 
task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, **kwargs, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -1343,6 +1349,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). **kwargs: Passed to invocation of ``create_model()``. Implementations may customize ``create_model()`` to accept ``**kwargs`` to customize model creation during deploy. For more, see the implementation docs. @@ -1380,6 +1398,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) @property diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index b72f1b1af2..8814b72175 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -306,6 +306,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -337,6 +343,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
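A usage sketch of register() with the new Inference Recommender fields, mirroring the values exercised in the tests later in this patch; the image URI, role, group name, and S3 locations are hypothetical.

from sagemaker.estimator import Estimator

estimator = Estimator(
    image_uri="111122223333.dkr.ecr.us-east-1.amazonaws.com/training:latest",  # hypothetical image
    role="arn:aws:iam::111122223333:role/MySageMakerRole",  # hypothetical role ARN
    instance_count=1,
    instance_type="ml.m5.xlarge",
)
estimator.fit("s3://my-bucket/train")  # register() needs the model data produced by training

model_package = estimator.register(
    content_types=["application/json"],
    response_types=["application/json"],
    inference_instances=["ml.m5.xlarge"],
    transform_instances=["ml.m5.xlarge"],
    model_package_group_name="my-model-group",  # hypothetical model package group
    sample_payload_url="s3://my-bucket/payload.tar.gz",  # hypothetical sample payload archive
    task="IMAGE_CLASSIFICATION",
    framework="TENSORFLOW",
    framework_version="2.9",
    nearest_model_name="resnet50",
    data_input_configuration='{"input_1":[1,224,224,3]}',
)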
@@ -367,6 +385,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index bfa4caa6e0..60c766379b 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -35,7 +35,10 @@ from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.transformer import Transformer from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model -from sagemaker.utils import unique_name_from_base +from sagemaker.utils import ( + unique_name_from_base, + update_container_with_inference_params, +) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor from sagemaker.workflow import is_pipeline_variable @@ -310,6 +313,12 @@ def register( customer_metadata_properties=None, validation_specification=None, domain=None, + task=None, + sample_payload_url=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -339,6 +348,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). 
Returns: A `sagemaker.model.ModelPackage` instance or pipeline step arguments @@ -349,10 +370,22 @@ def register( raise ValueError("SageMaker Model Package cannot be created without model data.") if image_uri is not None: self.image_uri = image_uri + if model_package_group_name is not None: container_def = self.prepare_container_def() + update_container_with_inference_params( + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + container_obj=container_def, + ) else: - container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data} + container_def = { + "Image": self.image_uri, + "ModelDataUrl": self.model_data, + } + model_pkg_args = sagemaker.get_model_package_args( content_types, response_types, @@ -370,6 +403,8 @@ def register( customer_metadata_properties=customer_metadata_properties, validation_specification=validation_specification, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) model_package = self.sagemaker_session.create_model_package_from_containers( **model_pkg_args diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index fa2773bebb..60fc1d60d2 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -159,6 +159,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -188,6 +194,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
@@ -218,6 +236,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 75fae3bfc4..8cdb82ffe7 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -20,7 +20,10 @@ from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties from sagemaker.session import Session -from sagemaker.utils import name_from_image +from sagemaker.utils import ( + name_from_image, + update_container_with_inference_params, +) from sagemaker.transformer import Transformer from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -279,6 +282,12 @@ def register( drift_check_baselines: Optional[DriftCheckBaselines] = None, customer_metadata_properties: Optional[Dict[str, str]] = None, domain: Optional[str] = None, + sample_payload_url: Optional[str] = None, + task: Optional[str] = None, + framework: Optional[str] = None, + framework_version: Optional[str] = None, + nearest_model_name: Optional[str] = None, + data_input_configuration: Optional[str] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -308,6 +317,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
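For the multi-container path documented above, a hedged sketch of PipelineModel.register; with the new arguments each container in the pipeline model gets the Inference Recommender fields. The images, model artifacts, and group name are hypothetical.

from sagemaker import Model
from sagemaker.pipeline import PipelineModel

role = "arn:aws:iam::111122223333:role/MySageMakerRole"  # hypothetical role ARN
preprocess_model = Model(
    image_uri="111122223333.dkr.ecr.us-east-1.amazonaws.com/preprocess:latest",  # hypothetical
    model_data="s3://my-bucket/preprocess.tar.gz",  # hypothetical
    role=role,
)
inference_model = Model(
    image_uri="111122223333.dkr.ecr.us-east-1.amazonaws.com/inference:latest",  # hypothetical
    model_data="s3://my-bucket/model.tar.gz",  # hypothetical
    role=role,
)
pipeline_model = PipelineModel(models=[preprocess_model, inference_model], role=role)
pipeline_model.register(
    content_types=["application/json"],
    response_types=["application/json"],
    inference_instances=["ml.m5.xlarge"],
    transform_instances=["ml.m5.xlarge"],
    model_package_group_name="my-model-group",
    framework="TENSORFLOW",
    framework_version="2.9",
    nearest_model_name="resnet50",
    data_input_configuration='{"input_1":[1,224,224,3]}',
)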
@@ -319,6 +340,13 @@ def register( container_def = self.pipeline_container_def( inference_instances[0] if inference_instances else None ) + update_container_with_inference_params( + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + container_list=container_def, + ) else: container_def = [ { @@ -344,6 +372,8 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) self.sagemaker_session.create_model_package_from_containers(**model_pkg_args) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 6e5d63c14d..b5e019f492 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -160,6 +160,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -189,6 +195,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -219,6 +237,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 461dfd8bab..eb158eab3d 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2822,6 +2822,8 @@ def create_model_package_from_containers( customer_metadata_properties=None, validation_specification=None, domain=None, + sample_payload_url=None, + task=None, ): """Get request dictionary for CreateModelPackage API. @@ -2851,6 +2853,11 @@ def create_model_package_from_containers( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). 
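For the Session plumbing above, a sketch of calling create_model_package_from_containers directly with the two new arguments; they are forwarded into the CreateModelPackage request built further down. The session credentials, image, and S3 locations are hypothetical.

from sagemaker import Session

sagemaker_session = Session()  # assumes default AWS credentials and region
sagemaker_session.create_model_package_from_containers(
    containers=[
        {
            "Image": "111122223333.dkr.ecr.us-east-1.amazonaws.com/inference:latest",  # hypothetical
            "ModelDataUrl": "s3://my-bucket/model.tar.gz",  # hypothetical
        }
    ],
    content_types=["application/json"],
    response_types=["application/json"],
    model_package_group_name="my-model-group",  # hypothetical group
    sample_payload_url="s3://my-bucket/payload.tar.gz",  # becomes SamplePayloadUrl in the request
    task="IMAGE_CLASSIFICATION",  # becomes Task in the request
)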
""" model_pkg_request = get_create_model_package_request( @@ -2870,6 +2877,8 @@ def create_model_package_from_containers( customer_metadata_properties=customer_metadata_properties, validation_specification=validation_specification, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) def submit(request): @@ -4241,6 +4250,8 @@ def get_model_package_args( customer_metadata_properties=None, validation_specification=None, domain=None, + sample_payload_url=None, + task=None, ): """Get arguments for create_model_package method. @@ -4273,6 +4284,11 @@ def get_model_package_args( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + Returns: dict: A dictionary of method argument names and values. """ @@ -4316,6 +4332,10 @@ def get_model_package_args( model_package_args["validation_specification"] = validation_specification if domain is not None: model_package_args["domain"] = domain + if sample_payload_url is not None: + model_package_args["sample_payload_url"] = sample_payload_url + if task is not None: + model_package_args["task"] = task return model_package_args @@ -4337,6 +4357,8 @@ def get_create_model_package_request( customer_metadata_properties=None, validation_specification=None, domain=None, + sample_payload_url=None, + task=None, ): """Get request dictionary for CreateModelPackage API. @@ -4367,6 +4389,10 @@ def get_create_model_package_request( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). """ if all([model_package_name, model_package_group_name]): @@ -4394,6 +4420,10 @@ def get_create_model_package_request( request_dict["ValidationSpecification"] = validation_specification if domain is not None: request_dict["Domain"] = domain + if sample_payload_url is not None: + request_dict["SamplePayloadUrl"] = sample_payload_url + if task is not None: + request_dict["Task"] = task if containers is not None: if not all([content_types, response_types]): raise ValueError( diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 71fe048bf1..67f9d60175 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -154,6 +154,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -183,6 +189,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). 
+ sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. @@ -213,6 +231,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def prepare_container_def( diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 3c4bf3343a..e5e6798a63 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -206,6 +206,12 @@ def register( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -235,6 +241,18 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). Returns: A `sagemaker.model.ModelPackage` instance. 
@@ -265,6 +283,12 @@ def register( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) def deploy( diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 1d2e9fe5cb..ed5b3c5e75 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -722,3 +722,59 @@ def get_data_bucket(self, region_requested=None): get_ecr_image_uri_prefix = deprecations.removed_function("get_ecr_image_uri_prefix") + + +def update_container_with_inference_params( + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, + container_obj=None, + container_list=None, +): + """Function to check if inference recommender parameters exist and update container. + + Args: + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). + container_obj (dict): object to be updated. + container_list (list): list to be updated. + + Returns: + dict: dict with inference recommender params + """ + + if ( + framework is not None + and framework_version is not None + and nearest_model_name is not None + and data_input_configuration is not None + ): + if container_list is not None: + for obj in container_list: + obj.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) + if container_obj is not None: + container_obj.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 7b8a3cdc25..7a0a399299 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -285,6 +285,8 @@ def __init__( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, **kwargs, ): """Constructor of a register model step. @@ -329,6 +331,11 @@ def __init__( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). **kwargs: additional arguments to `create_model`. 
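A short sketch of the update_container_with_inference_params helper added in utils.py above; note it only mutates the container definition when all four values are supplied. The image URI and bucket are hypothetical.

from sagemaker.utils import update_container_with_inference_params

container = {
    "Image": "111122223333.dkr.ecr.us-east-1.amazonaws.com/inference:latest",  # hypothetical image
    "ModelDataUrl": "s3://my-bucket/model.tar.gz",  # hypothetical model artifact
}
update_container_with_inference_params(
    framework="TENSORFLOW",
    framework_version="2.9",
    nearest_model_name="resnet50",
    data_input_configuration='{"input_1":[1,224,224,3]}',
    container_obj=container,
)
# container now also carries Framework, FrameworkVersion, NearestModelName and
# ModelInput.DataInputConfig; had any of the four values been None, it would have
# been left unchanged.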
""" super(_RegisterModelStep, self).__init__( @@ -360,6 +367,8 @@ def __init__( self.drift_check_baselines = drift_check_baselines self.customer_metadata_properties = customer_metadata_properties self.domain = domain + self.sample_payload_url = sample_payload_url + self.task = task self.metadata_properties = metadata_properties self.approval_status = approval_status self.image_uri = image_uri @@ -438,6 +447,8 @@ def arguments(self) -> RequestType: container_def_list=self.container_def_list, customer_metadata_properties=self.customer_metadata_properties, domain=self.domain, + sample_payload_url=self.sample_payload_url, + task=self.task, ) request_dict = get_create_model_package_request(**model_package_args) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index d52ddace87..dd9529916e 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -27,6 +27,7 @@ from sagemaker.workflow.steps import Step, CreateModelStep, TransformStep from sagemaker.workflow._utils import _RegisterModelStep, _RepackModelStep from sagemaker.workflow.retry import RetryPolicy +from sagemaker.utils import update_container_with_inference_params @attr.s @@ -80,6 +81,12 @@ def __init__( drift_check_baselines=None, customer_metadata_properties=None, domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, **kwargs, ): """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator. @@ -123,6 +130,18 @@ def __init__( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). **kwargs: additional arguments to `create_model`. 
""" @@ -228,6 +247,14 @@ def __init__( ) ] + update_container_with_inference_params( + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + container_list=self.container_def_list, + ) + register_model_step = _RegisterModelStep( name=name, estimator=estimator, @@ -249,6 +276,8 @@ def __init__( retry_policies=register_model_step_retry_policies, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, **kwargs, ) if not repack_model: diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index d8f1d9ab6c..d0f617a266 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -94,6 +94,13 @@ def test_conditional_pytorch_training_model_registration( good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1) in_condition_input = ParameterString(name="Foo", default_value="Foo") + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + # If image_uri is not provided, the instance_type should not be a pipeline variable # since instance_type is used to retrieve image_uri in compile time (PySDK) pytorch_estimator = PyTorch( @@ -120,6 +127,12 @@ def test_conditional_pytorch_training_model_registration( inference_instances=["*"], transform_instances=["*"], description="test-description", + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) model = Model( @@ -201,6 +214,13 @@ def test_mxnet_model_registration( instance_count = ParameterInteger(name="InstanceCount", default_value=1) instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + model = MXNetModel( entry_point=entry_point, source_dir=source_dir, @@ -219,6 +239,12 @@ def test_mxnet_model_registration( inference_instances=["ml.m5.xlarge"], transform_instances=["*"], description="test-description", + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) pipeline = Pipeline( @@ -262,6 +288,13 @@ def test_sklearn_xgboost_sip_model_registration( instance_count = ParameterInteger(name="InstanceCount", default_value=1) instance_type = "ml.m5.xlarge" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' + # The instance_type should not be a pipeline variable # since it is used to retrieve image_uri in compile time (PySDK) sklearn_processor = SKLearnProcessor( @@ -412,6 +445,12 @@ def test_sklearn_xgboost_sip_model_registration( inference_instances=["ml.t2.medium", 
"ml.m5.xlarge"], transform_instances=["ml.m5.xlarge"], model_package_group_name="windturbine", + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) pipeline = Pipeline( @@ -555,6 +594,12 @@ def test_model_registration_with_drift_check_baselines( ) customer_metadata_properties = {"key1": "value1"} domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_configuration = '{"input_1":[1,224,224,3]}' # If image_uri is not provided, the instance_type should not be a pipeline variable # since instance_type is used to retrieve image_uri in compile time (PySDK) @@ -568,6 +613,7 @@ def test_model_registration_with_drift_check_baselines( py_version="py3", role=role, ) + step_register = RegisterModel( name="MyRegisterModelStep", estimator=estimator, @@ -581,6 +627,12 @@ def test_model_registration_with_drift_check_baselines( drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, ) pipeline = Pipeline( @@ -652,6 +704,8 @@ def test_model_registration_with_drift_check_baselines( ) assert response["CustomerMetadataProperties"] == customer_metadata_properties assert response["Domain"] == domain + assert response["Task"] == task + assert response["SamplePayloadUrl"] == sample_payload_url break finally: try: diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index d2954ede7b..90a9116c07 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -116,6 +116,12 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock): inference_instances=["ml.t2.medium", "ml.m5.xlarge"], transform_instances=["ml.m5.xlarge"], model_package_group_name="MyModelPackageGroup", + task="IMAGE_CLASSIFICATION", + sample_payload_url="s3://test-bucket/model", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) # The context should be cleaned up before return assert not pipeline_session_mock.context @@ -136,11 +142,16 @@ def test_pipeline_session_context_for_model_step_without_instance_types( source_dir=f"{DATA_DIR}", role=_ROLE, ) - register_step_args = model.register( content_types=["text/csv"], response_types=["text/csv"], model_package_group_name="MyModelPackageGroup", + task="IMAGE_CLASSIFICATION", + sample_payload_url="s3://test-bucket/model", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) expected_output = { @@ -159,6 +170,12 @@ def test_pipeline_session_context_for_model_step_without_instance_types( name="ModelData", default_value="s3://my-bucket/file", ), + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, } ], "SupportedContentTypes": ["text/csv"], @@ -168,6 +185,8 @@ def 
test_pipeline_session_context_for_model_step_without_instance_types( }, "CertifyForMarketplace": False, "ModelApprovalStatus": "PendingManualApproval", + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", } assert register_step_args.create_model_package_request == expected_output diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 9d41e70aca..4aa55fd068 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -368,6 +368,12 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): display_name="RegisterModelStep", depends_on=["TestStep"], tags=[{"Key": "myKey", "Value": "myValue"}], + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -412,6 +418,8 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", "Tags": [{"Key": "myKey", "Value": "myValue"}], + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", }, }, ] @@ -433,6 +441,12 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): drift_check_baselines=drift_check_baselines, approval_status="Approved", description="description", + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -474,6 +488,8 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): }, "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", }, }, ] @@ -502,6 +518,12 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): description="description", model=pipeline_model, depends_on=["TestStep"], + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -517,11 +539,23 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): "Image": "fakeimage1", "ModelDataUrl": "Url1", "Environment": [{"k1": "v1"}, {"k2": "v2"}], + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, }, { "Image": "fakeimage2", "ModelDataUrl": "Url2", "Environment": [{"k3": "v3"}, {"k4": "v4"}], + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, }, ], "SupportedContentTypes": ["content_type"], @@ -550,6 +584,8 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): }, "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", }, }, ] @@ -578,6 +614,12 @@ def 
test_register_model_with_model_repack_with_estimator( dependencies=[dummy_requirements], depends_on=["TestStep"], tags=[{"Key": "myKey", "Value": "myValue"}], + sample_payload_url="s3://test-bucket/model", + task="IMAGE_CLASSIFICATION", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', ) request_dicts = register_model.request_dicts() @@ -680,6 +722,8 @@ def test_register_model_with_model_repack_with_estimator( "ModelPackageDescription": "description", "ModelPackageGroupName": "mpg", "Tags": [{"Key": "myKey", "Value": "myValue"}], + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", } ) else: diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 7e2e985845..78298025ea 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -3109,6 +3109,12 @@ def test_register_default_image(sagemaker_session): response_types = ["application/json"] inference_instances = ["ml.m4.xlarge"] transform_instances = ["ml.m4.xlarget"] + sample_payload_url = "s3://test-bucket/model" + task = "IMAGE_CLASSIFICATION" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, @@ -3116,6 +3122,12 @@ def test_register_default_image(sagemaker_session): inference_instances=inference_instances, transform_instances=transform_instances, model_package_name=model_package_name, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() @@ -3132,6 +3144,8 @@ def test_register_default_image(sagemaker_session): "transform_instances": transform_instances, "model_package_name": model_package_name, "marketplace_cert": False, + "sample_payload_url": sample_payload_url, + "task": task, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request @@ -3153,11 +3167,23 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): model_package_name = "test-estimator-register-model" content_types = ["application/json"] response_types = ["application/json"] + sample_payload_url = "s3://test-bucket/model" + task = "IMAGE_CLASSIFICATION" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, response_types=response_types, model_package_name=model_package_name, + sample_payload_url=sample_payload_url, + task=task, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() @@ -3174,6 +3200,8 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): "transform_instances": None, "model_package_name": model_package_name, "marketplace_cert": False, + "sample_payload_url": sample_payload_url, + "task": task, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request @@ -3198,6 +3226,12 @@ def test_register_inference_image(sagemaker_session): inference_instances = ["ml.m4.xlarge"] transform_instances = ["ml.m4.xlarget"] inference_image = 
"fake-inference-image" + sample_payload_url = "s3://test-bucket/model" + task = "IMAGE_CLASSIFICATION" + framework = "TENSORFLOW" + framework_version = "2.9" + nearest_model_name = "resnet50" + data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, @@ -3205,7 +3239,13 @@ def test_register_inference_image(sagemaker_session): inference_instances=inference_instances, transform_instances=transform_instances, model_package_name=model_package_name, + sample_payload_url=sample_payload_url, + task=task, image_uri=inference_image, + framework=framework, + framework_version=framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() @@ -3222,6 +3262,8 @@ def test_register_inference_image(sagemaker_session): "transform_instances": transform_instances, "model_package_name": model_package_name, "marketplace_cert": False, + "sample_payload_url": sample_payload_url, + "task": task, } sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 2040ed5d80..a02ea6eeca 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1705,6 +1705,12 @@ def test_create_model_with_both(expand_container_def, sagemaker_session): "Environment": {"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "application/json"}, "Image": "mi-1", "ModelDataUrl": "s3://bucket/model_1.tar.gz", + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, }, {"Environment": {}, "Image": "mi-2", "ModelDataUrl": "s3://bucket/model_2.tar.gz"}, ] @@ -2387,6 +2393,8 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): description = "description" customer_metadata_properties = {"key1": "value1"} domain = "COMPUTER_VISION" + task = "IMAGE_CLASSIFICATION" + sample_payload_url = "s3://test-bucket/model" sagemaker_session.create_model_package_from_containers( containers=containers, content_types=content_types, @@ -2402,6 +2410,8 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): drift_check_baselines=drift_check_baselines, customer_metadata_properties=customer_metadata_properties, domain=domain, + sample_payload_url=sample_payload_url, + task=task, ) expected_args = { "ModelPackageName": model_package_name, @@ -2420,6 +2430,8 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): "DriftCheckBaselines": drift_check_baselines, "CustomerMetadataProperties": customer_metadata_properties, "Domain": domain, + "SamplePayloadUrl": sample_payload_url, + "Task": task, } sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) From 587fae1951d63f8af5a363df44881910a5fb5a09 Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Wed, 6 Jul 2022 10:25:03 -0700 Subject: [PATCH 107/526] feature: support heterogeneous cluster for training Co-authored-by: Navin Soni --- src/sagemaker/algorithm.py | 26 ++++++++--------- src/sagemaker/estimator.py | 26 ++++++++++++----- src/sagemaker/inputs.py | 5 ++++ src/sagemaker/instance_group.py | 50 +++++++++++++++++++++++++++++++ src/sagemaker/job.py | 23 +++++++++++++-- tests/unit/test_estimator.py | 26 +++++++++++++++++ tests/unit/test_inputs.py | 38 ++++++++++++++++++++++++ tests/unit/test_job.py | 52 
+++++++++++++++++++++++++++++++-- 8 files changed, 220 insertions(+), 26 deletions(-) create mode 100644 src/sagemaker/instance_group.py diff --git a/src/sagemaker/algorithm.py b/src/sagemaker/algorithm.py index a55635b1c3..1ab5ee3bcf 100644 --- a/src/sagemaker/algorithm.py +++ b/src/sagemaker/algorithm.py @@ -147,19 +147,19 @@ def __init__( self.algorithm_arn = algorithm_arn super(AlgorithmEstimator, self).__init__( role, - instance_count, - instance_type, - volume_size, - volume_kms_key, - max_run, - input_mode, - output_path, - output_kms_key, - base_job_name, - sagemaker_session, - tags, - subnets, - security_group_ids, + instance_count=instance_count, + instance_type=instance_type, + volume_size=volume_size, + volume_kms_key=volume_kms_key, + max_run=max_run, + input_mode=input_mode, + output_path=output_path, + output_kms_key=output_kms_key, + base_job_name=base_job_name, + sagemaker_session=sagemaker_session, + tags=tags, + subnets=subnets, + security_group_ids=security_group_ids, model_uri=model_uri, model_channel_name=model_channel_name, metric_definitions=metric_definitions, diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index c867f0b199..bd274b01b0 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -145,6 +145,7 @@ def __init__( code_location: Optional[str] = None, entry_point: Optional[str] = None, dependencies: Optional[List[Union[str]]] = None, + instance_groups=None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -156,9 +157,10 @@ def __init__( artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. instance_count (int): Number of Amazon EC2 instances to use - for training. + for training. Required if instance_groups is not set. instance_type (str): Type of EC2 instance to use for training, - for example, 'ml.c4.xlarge'. + for example, 'ml.c4.xlarge'. Required if instance_groups is + not set. volume_size (int): Size in GB of the EBS volume to use for storing input data during training (default: 30). Must be large enough to store training data if File Mode is used (which is the @@ -424,7 +426,10 @@ def __init__( >>> |------ virtual-env This is not supported with "local code" in Local Mode. - + instance_groups (list[InstanceGroup]): Optional. List of InstanceGroup + for specifying different instance groups for heterogeneous cluster. + For example: [sagemaker.InstanceGroup('worker','ml.p3dn.24xlarge',64), + sagemaker.InstanceGroup('server','ml.c5n.18xlarge',64)] """ instance_count = renamed_kwargs( "train_instance_count", "instance_count", instance_count, kwargs @@ -442,12 +447,10 @@ def __init__( "train_volume_kms_key", "volume_kms_key", volume_kms_key, kwargs ) - if instance_count is None or instance_type is None: - raise ValueError("Both instance_count and instance_type are required.") - self.role = role self.instance_count = instance_count self.instance_type = instance_type + self.instance_groups = instance_groups self.volume_size = volume_size self.volume_kms_key = volume_kms_key self.max_run = max_run @@ -2103,6 +2106,7 @@ def __init__( code_location: Optional[str] = None, entry_point: Optional[str] = None, dependencies: Optional[List[str]] = None, + instance_groups=None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -2115,9 +2119,10 @@ def __init__( artifacts. After the endpoint is created, the inference code might use the IAM role, if it needs to access an AWS resource. 
instance_count (int): Number of Amazon EC2 instances to use - for training. + for training. Required if instance_groups is not set. instance_type (str): Type of EC2 instance to use for training, - for example, 'ml.c4.xlarge'. + for example, 'ml.c4.xlarge'. Required if instance_groups is + not set. volume_size (int): Size in GB of the EBS volume to use for storing input data during training (default: 30). Must be large enough to store training data if File Mode is used (which is the @@ -2379,6 +2384,10 @@ def __init__( >>> |------ virtual-env This is not supported with "local code" in Local Mode. + instance_groups (list[InstanceGroup]): Optional. List of InstanceGroup + for specifying different instance groups for heterogeneous cluster. + For example: [sagemaker.InstanceGroup('worker','ml.p3dn.24xlarge',64), + sagemaker.InstanceGroup('server','ml.c5n.18xlarge',64)] """ self.image_uri = image_uri self._hyperparameters = hyperparameters.copy() if hyperparameters else {} @@ -2386,6 +2395,7 @@ def __init__( role, instance_count, instance_type, + instance_groups, volume_size, volume_kms_key, max_run, diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 6704dee29a..d2bd94d232 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -35,6 +35,7 @@ def __init__( content_type=None, record_wrapping=None, s3_data_type="S3Prefix", + instance_groups=None, input_mode=None, attribute_names=None, target_attribute_name=None, @@ -60,6 +61,8 @@ def __init__( listing the S3 data to train on. Both the ManifestFile and AugmentedManifestFile formats are described in the SageMaker API documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html + instance_groups (list[str]): Optional. List of InstanceGroupNames to send data to + (default: None). By default, data will be sent to all groups. input_mode (str): Optional override for this channel's input mode (default: None). By default, channels will use the input mode defined on ``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore @@ -97,6 +100,8 @@ def __init__( self.config["ContentType"] = content_type if record_wrapping is not None: self.config["RecordWrapperType"] = record_wrapping + if instance_groups is not None: + self.config["DataSource"]["S3DataSource"]["InstanceGroupNames"] = instance_groups if input_mode is not None: self.config["InputMode"] = input_mode if attribute_names is not None: diff --git a/src/sagemaker/instance_group.py b/src/sagemaker/instance_group.py new file mode 100644 index 0000000000..005a39433b --- /dev/null +++ b/src/sagemaker/instance_group.py @@ -0,0 +1,50 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This file defines instance group for heterogeneous cluster.""" +from __future__ import absolute_import + + +class InstanceGroup(object): + """Accepts instance group parameters for conversion to request dict. + + The `_to_request_dict` provides a method to turn the parameters into a dict. 
+ """ + + def __init__( + self, + instance_group_name=None, + instance_type=None, + instance_count=None, + ): + """Initialize a ``InstanceGroup`` instance. + + InstanceGroup accepts instance group parameters and provides a method to turn + these parameters into a dictionary. + + Args: + instance_group_name (str): Name of the instance group. + instance_type (str): Type of EC2 instance to use in the instance group, + for example, 'ml.c4.xlarge'. + instance_count (int): Number of EC2 instances to use in the instance group. + """ + self.instance_group_name = instance_group_name + self.instance_type = instance_type + self.instance_count = instance_count + + def _to_request_dict(self): + """Generates a request dictionary using the parameters provided to the class.""" + return { + "InstanceGroupName": self.instance_group_name, + "InstanceType": self.instance_type, + "InstanceCount": self.instance_count, + } diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 2c4ece3a25..1b9d46cd15 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -74,6 +74,7 @@ def _load_config(inputs, estimator, expand_role=True, validate_uri=True): resource_config = _Job._prepare_resource_config( estimator.instance_count, estimator.instance_type, + estimator.instance_groups, estimator.volume_size, estimator.volume_kms_key, ) @@ -283,15 +284,31 @@ def _prepare_output_config(s3_path, kms_key_id): return config @staticmethod - def _prepare_resource_config(instance_count, instance_type, volume_size, volume_kms_key): + def _prepare_resource_config( + instance_count, instance_type, instance_groups, volume_size, volume_kms_key + ): """Placeholder docstring""" resource_config = { - "InstanceCount": instance_count, - "InstanceType": instance_type, "VolumeSizeInGB": volume_size, } if volume_kms_key is not None: resource_config["VolumeKmsKeyId"] = volume_kms_key + if instance_groups is not None: + if instance_count is not None or instance_type is not None: + raise ValueError( + "instance_count and instance_type cannot be set when instance_groups is set" + ) + + resource_config["InstanceGroups"] = [ + group._to_request_dict() for group in instance_groups + ] + else: + if instance_count is None or instance_type is None: + raise ValueError( + "instance_count and instance_type must be set if instance_groups is not set" + ) + resource_config["InstanceCount"] = instance_count + resource_config["InstanceType"] = instance_type return resource_config diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 78298025ea..30b80bd58b 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -43,6 +43,7 @@ from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob from sagemaker.fw_utils import PROFILER_UNSUPPORTED_REGIONS from sagemaker.inputs import ShuffleConfig +from sagemaker.instance_group import InstanceGroup from sagemaker.model import FrameworkModel from sagemaker.mxnet.estimator import MXNet from sagemaker.predictor import Predictor @@ -323,6 +324,31 @@ def test_framework_all_init_args(sagemaker_session): } +def test_framework_with_heterogeneous_cluster(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert 
args["resource_config"]["InstanceGroups"][0] == { + "InstanceGroupName": "group1", + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + } + assert args["resource_config"]["InstanceGroups"][1] == { + "InstanceGroupName": "group2", + "InstanceCount": 2, + "InstanceType": "ml.m4.xlarge", + } + + def test_framework_with_debugger_and_built_in_rule(sagemaker_session): debugger_built_in_rule_with_custom_args = Rule.sagemaker( base_config=rule_configs.stalled_training_rule(), diff --git a/tests/unit/test_inputs.py b/tests/unit/test_inputs.py index fc361ff52b..7d9c2b2c2f 100644 --- a/tests/unit/test_inputs.py +++ b/tests/unit/test_inputs.py @@ -67,6 +67,44 @@ def test_training_input_all_arguments(): assert result.config == expected +def test_training_input_all_arguments_heterogeneous_cluster(): + prefix = "pre" + distribution = "FullyReplicated" + compression = "Gzip" + content_type = "text/csv" + record_wrapping = "RecordIO" + s3_data_type = "Manifestfile" + instance_groups = ["data-server"] + input_mode = "Pipe" + result = TrainingInput( + s3_data=prefix, + distribution=distribution, + compression=compression, + input_mode=input_mode, + content_type=content_type, + record_wrapping=record_wrapping, + s3_data_type=s3_data_type, + instance_groups=instance_groups, + ) + + expected = { + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": distribution, + "S3DataType": s3_data_type, + "S3Uri": prefix, + "InstanceGroupNames": instance_groups, + } + }, + "CompressionType": compression, + "ContentType": content_type, + "RecordWrapperType": record_wrapping, + "InputMode": input_mode, + } + + assert result.config == expected + + def test_file_system_input_default_access_mode(): file_system_id = "fs-0a48d2a1" file_system_type = "EFS" diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 7d4fd371a9..3c79b407d8 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -20,6 +20,7 @@ from sagemaker.amazon.amazon_estimator import RecordSet, FileSystemRecordSet from sagemaker.estimator import Estimator, Framework from sagemaker.inputs import FileSystemInput +from sagemaker.instance_group import InstanceGroup from sagemaker.job import _Job from sagemaker.model import FrameworkModel @@ -28,6 +29,7 @@ LOCAL_FILE_NAME = "file://local/file" INSTANCE_COUNT = 1 INSTANCE_TYPE = "c4.4xlarge" +INSTANCE_GROUP = InstanceGroup("group", "ml.c4.xlarge", 1) VOLUME_SIZE = 1 MAX_RUNTIME = 1 ROLE = "DummyRole" @@ -597,7 +599,7 @@ def test_prepare_output_config_kms_key_none(): def test_prepare_resource_config(): resource_config = _Job._prepare_resource_config( - INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE, None + INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, None ) assert resource_config == { @@ -609,7 +611,7 @@ def test_prepare_resource_config(): def test_prepare_resource_config_with_volume_kms(): resource_config = _Job._prepare_resource_config( - INSTANCE_COUNT, INSTANCE_TYPE, VOLUME_SIZE, VOLUME_KMS_KEY + INSTANCE_COUNT, INSTANCE_TYPE, None, VOLUME_SIZE, VOLUME_KMS_KEY ) assert resource_config == { @@ -620,6 +622,52 @@ def test_prepare_resource_config_with_volume_kms(): } +def test_prepare_resource_config_with_heterogeneous_cluster(): + resource_config = _Job._prepare_resource_config( + None, + None, + [InstanceGroup("group1", "ml.c4.xlarge", 1), InstanceGroup("group2", "ml.m4.xlarge", 2)], + VOLUME_SIZE, + None, + ) + + assert resource_config == { + "InstanceGroups": [ + {"InstanceGroupName": "group1", "InstanceCount": 1, "InstanceType": "ml.c4.xlarge"}, + {"InstanceGroupName": 
"group2", "InstanceCount": 2, "InstanceType": "ml.m4.xlarge"}, + ], + "VolumeSizeInGB": VOLUME_SIZE, + } + + +def test_prepare_resource_config_with_instance_groups_instance_type_instance_count_set(): + with pytest.raises(ValueError) as error: + _Job._prepare_resource_config( + INSTANCE_COUNT, + INSTANCE_TYPE, + [INSTANCE_GROUP], + VOLUME_SIZE, + None, + ) + assert "instance_count and instance_type cannot be set when instance_groups is set" in str( + error + ) + + +def test_prepare_resource_config_with_instance_groups_instance_type_instance_count_not_set(): + with pytest.raises(ValueError) as error: + _Job._prepare_resource_config( + None, + None, + None, + VOLUME_SIZE, + None, + ) + assert "instance_count and instance_type must be set if instance_groups is not set" in str( + error + ) + + def test_prepare_stop_condition(): max_run = 1 max_wait = 2 From 26bf3299f811ef58692419ae3d04bc26efaf5730 Mon Sep 17 00:00:00 2001 From: Allen Liu Date: Tue, 28 Jun 2022 15:39:08 -0700 Subject: [PATCH 108/526] feature: heterogeneous cluster set up in distribution config Co-authored-by: Jessica Zhu <106775307+jessicazhu3@users.noreply.github.com> --- src/sagemaker/estimator.py | 14 ++ src/sagemaker/fw_utils.py | 102 ++++++++++++++- src/sagemaker/pytorch/estimator.py | 33 ++--- src/sagemaker/tensorflow/estimator.py | 23 ++-- .../sagemaker/tensorflow/test_estimator.py | 22 ++++ tests/unit/test_fw_utils.py | 120 ++++++++++++++++++ tests/unit/test_pytorch.py | 20 ++- 7 files changed, 298 insertions(+), 36 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index bd274b01b0..937fd132b4 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -3111,6 +3111,13 @@ def _distribution_configuration(self, distribution): """ distribution_config = {} + mpi_enabled = False + smdataparallel_enabled = False + if "instance_groups" in distribution: + distribution_config["sagemaker_distribution_instance_groups"] = distribution[ + "instance_groups" + ] + if "parameter_server" in distribution: ps_enabled = distribution.get("parameter_server").get("enabled", False) distribution_config[self.LAUNCH_PS_ENV_NAME] = ps_enabled @@ -3146,6 +3153,13 @@ def _distribution_configuration(self, distribution): "dataparallel" ].get("custom_mpi_options", "") + if not (mpi_enabled or smdataparallel_enabled) and distribution_config.get( + "sagemaker_distribution_instance_groups" + ) not in [None, []]: + raise ValueError( + "Don't set training instance groups while no distribution strategies enabled!" + ) + return distribution_config diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index f3a44fbac0..2fcb5d19f7 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -27,7 +27,7 @@ import sagemaker.utils from sagemaker.workflow import is_pipeline_variable -from sagemaker.deprecations import renamed_warning +from sagemaker.deprecations import renamed_warning, renamed_kwargs logger = logging.getLogger(__name__) @@ -600,6 +600,106 @@ def _validate_smdataparallel_args( raise ValueError(err_msg) +def validate_distribution( + distribution, instance_groups, framework_name, framework_version, py_version, image_uri, kwargs +): + """Check if distribution strategy is correctly invoked by the user. + + Currently, check for `dataparallel`, `modelparallel` and heterogeneous cluster set up. + Validate if the user requested strategy is supported. + + Args: + distribution (dict): A dictionary with information to enable distributed training. 
+ (Defaults to None if distributed training is not enabled.) For example: + + .. code:: python + + { + "smdistributed": { + "dataparallel": { + "enabled": True + } + } + } + instance_groups ([InstanceGroup]): A list contains instance groups used for training. + framework_name (str): A string representing the name of framework selected. + framework_version (str): A string representing the framework version selected. + py_version (str): A string representing the python version selected. + image_uri (str): A string representing a Docker image URI. + kwargs(dict): Additional kwargs passed to this function + + Returns: + distribution(dict): updated dictionary with validated information + to enable distributed training. + + Raises: + ValueError: if distribution dictionary isn't correctly formatted or + multiple strategies are requested simultaneously or + an unsupported strategy is requested or + strategy-specific inputs are incorrect/unsupported or + heterogeneous cluster set up is incorrect + """ + train_instance_groups = distribution.get("instance_groups", []) + if instance_groups is None: + if len(train_instance_groups) >= 1: + # if estimator's instance_groups is not defined but + # train_instance_groups are specified in distribution + raise ValueError("Instance groups not specified in the estimator !") + else: + if len(train_instance_groups) > len(instance_groups): + # if train_instance_groups in distribution are more than estimator's instance_groups + raise ValueError("Train instance groups oversubscribed !") + if len(instance_groups) == 1 and len(train_instance_groups) == 0: + # if just one instance_group but it is not specified in distribution, we set it for user + train_instance_groups = instance_groups + elif len(instance_groups) > 1 and len(train_instance_groups) != 1: + # currently we just support one train instance group + raise ValueError("Distribution should only contain one instance group name !") + + if len(train_instance_groups) != 0: + # in this case, we are handling a heterogeneous cluster training job + instance_group_names = [] + for train_instance_group in train_instance_groups: + # in future version we will support multiple train_instance_groups, so use loop here + if train_instance_group not in instance_groups: + # check if train instance groups belongs to what user defined in estimator set up + raise ValueError( + f"Invalid training instance group {train_instance_group.instance_group_name} !" 
+ ) + instance_type = train_instance_group.instance_type + validate_smdistributed( + instance_type=instance_type, + framework_name=framework_name, + framework_version=framework_version, + py_version=py_version, + distribution=distribution, + image_uri=image_uri, + ) + warn_if_parameter_server_with_multi_gpu( + training_instance_type=instance_type, distribution=distribution + ) + # get instance group names + instance_group_names.append(train_instance_group.instance_group_name) + distribution["instance_groups"] = instance_group_names + else: + # in this case, we are handling a normal training job (without heterogeneous cluster) + instance_type = renamed_kwargs( + "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs + ) + validate_smdistributed( + instance_type=instance_type, + framework_name=framework_name, + framework_version=framework_version, + py_version=py_version, + distribution=distribution, + image_uri=image_uri, + ) + warn_if_parameter_server_with_multi_gpu( + training_instance_type=instance_type, distribution=distribution + ) + return distribution + + def python_deprecation_warning(framework, latest_supported_version): """Placeholder docstring""" return PYTHON_2_DEPRECATION_WARNING.format( diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index ca1ef3a447..2cd5a0c798 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -17,15 +17,13 @@ from packaging.version import Version -from sagemaker.deprecations import renamed_kwargs from sagemaker.estimator import Framework, EstimatorBase from sagemaker.fw_utils import ( framework_name_from_image, framework_version_from_tag, python_deprecation_warning, validate_version_or_image_args, - warn_if_parameter_server_with_multi_gpu, - validate_smdistributed, + validate_distribution, ) from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel @@ -196,24 +194,6 @@ def __init__( self.framework_version = framework_version self.py_version = py_version - if distribution is not None: - instance_type = renamed_kwargs( - "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs - ) - - validate_smdistributed( - instance_type=instance_type, - framework_name=self._framework_name, - framework_version=framework_version, - py_version=py_version, - distribution=distribution, - image_uri=image_uri, - ) - - warn_if_parameter_server_with_multi_gpu( - training_instance_type=instance_type, distribution=distribution - ) - if "enable_sagemaker_metrics" not in kwargs: # enable sagemaker metrics for PT v1.3 or greater: if self.framework_version and Version(self.framework_version) >= Version("1.3"): @@ -222,6 +202,17 @@ def __init__( super(PyTorch, self).__init__( entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs ) + if distribution is not None: + distribution = validate_distribution( + distribution, + self.instance_groups, + self._framework_name, + framework_version, + py_version, + image_uri, + kwargs, + ) + self.distribution = distribution or {} def hyperparameters(self): diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index d97e39d313..4db647e140 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -183,25 +183,22 @@ def __init__( self.py_version = py_version self.instance_type = instance_type - if distribution is not None: - fw.warn_if_parameter_server_with_multi_gpu( - training_instance_type=instance_type, 
distribution=distribution - ) - fw.validate_smdistributed( - instance_type=instance_type, - framework_name=self._framework_name, - framework_version=framework_version, - py_version=py_version, - distribution=distribution, - image_uri=image_uri, - ) - if "enable_sagemaker_metrics" not in kwargs: # enable sagemaker metrics for TF v1.15 or greater: if framework_version and version.Version(framework_version) >= version.Version("1.15"): kwargs["enable_sagemaker_metrics"] = True super(TensorFlow, self).__init__(image_uri=image_uri, **kwargs) + if distribution is not None: + distribution = fw.validate_distribution( + distribution, + self.instance_groups, + self._framework_name, + framework_version, + py_version, + image_uri, + kwargs, + ) self.model_dir = model_dir self.distribution = distribution or {} diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 7b40420fb6..d27359f010 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -22,6 +22,7 @@ from sagemaker.estimator import _TrainingJob from sagemaker.tensorflow import TensorFlow +from sagemaker.instance_group import InstanceGroup from tests.unit import DATA_DIR SCRIPT_FILE = "dummy_script.py" @@ -538,3 +539,24 @@ def test_custom_image(sagemaker_session): custom_image = "tensorflow:latest" tf = _build_tf(sagemaker_session, image_uri=custom_image) assert custom_image == tf.training_image_uri() + + +def test_tf_heterogeneous_cluster_distribution_config( + sagemaker_session, tensorflow_training_version, tensorflow_training_py_version +): + if version.Version(tensorflow_training_version) < version.Version("2.0"): + pytest.skip("This test is for TF 2.0 and higher.") + + training_group = InstanceGroup("train_group", "ml.c4.xlarge", 1) + expected_return = {"mpi": {"enabled": True}, "instance_groups": ["train_group"]} + tf = _build_tf( + sagemaker_session, + framework_version=tensorflow_training_version, + py_version=tensorflow_training_py_version, + instance_groups=[training_group], + distribution={ + "mpi": {"enabled": True}, + "instance_groups": [training_group], + }, + ) + assert tf.distribution == expected_return diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 17dd7de12b..24bb7368a4 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -25,6 +25,7 @@ from sagemaker import fw_utils from sagemaker.utils import name_from_image from sagemaker.session_settings import SessionSettings +from sagemaker.instance_group import InstanceGroup TIMESTAMP = "2017-10-10-14-14-15" @@ -586,6 +587,125 @@ def test_validate_version_or_image_args_raises(): fw_utils.validate_version_or_image_args(framework_version, py_version, image_uri) +def test_validate_distribution_not_raises(): + train_group = InstanceGroup("train_group", "ml.p3.16xlarge", 1) + other_group = InstanceGroup("other_group", "ml.p3.16xlarge", 1) + instance_groups = [train_group, other_group] + + smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}} + smdataparallel_enabled_custom_mpi = { + "smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}} + } + smdataparallel_disabled = {"smdistributed": {"dataparallel": {"enabled": False}}} + mpi_enabled = {"mpi": {"enabled": True, "processes_per_host": 2}} + mpi_disabled = {"mpi": {"enabled": False}} + + instance_types = list(fw_utils.SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES) + + good_args_normal = [ + 
smdataparallel_enabled, + smdataparallel_enabled_custom_mpi, + smdataparallel_disabled, + mpi_enabled, + mpi_disabled, + ] + + frameworks = ["tensorflow", "pytorch"] + + for framework, instance_type in product(frameworks, instance_types): + for distribution in good_args_normal: + fw_utils.validate_distribution( + distribution, + None, # instance_groups + framework, + None, # framework_version + None, # py_version + "custom-container", + {"instance_type": instance_type}, # kwargs + ) + + for framework in frameworks: + good_args_hc = [ + { + "smdistributed": {"dataparallel": {"enabled": True}}, + "instance_groups": [train_group], + }, # smdataparallel_enabled_hc + { + "mpi": {"enabled": True, "processes_per_host": 2}, + "instance_groups": [train_group], + }, # mpi_enabled_hc + { + "smdistributed": { + "dataparallel": {"enabled": True, "custom_mpi_options": "--verbose"}, + }, + "instance_groups": [train_group], + }, # smdataparallel_enabled_custom_mpi_hc + ] + for distribution in good_args_hc: + fw_utils.validate_distribution( + distribution, + instance_groups, # instance_groups + framework, + None, # framework_version + None, # py_version + "custom-container", + {}, # kwargs + ) + + +def test_validate_distribution_raises(): + train_group = InstanceGroup("train_group", "ml.p3.16xlarge", 1) + other_group = InstanceGroup("other_group", "ml.p3.16xlarge", 1) + dummy_group = InstanceGroup("dummy_group", "ml.p3.16xlarge", 1) + instance_groups = [train_group, other_group, dummy_group] + + mpi_enabled_hc = { + "mpi": {"enabled": True, "processes_per_host": 2}, + "instance_groups": [train_group, other_group], + } + smdataparallel_enabled_hc = { + "smdistributed": {"dataparallel": {"enabled": True}}, + "instance_groups": [], + } + + instance_types = list(fw_utils.SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES) + + bad_args_normal = [ + {"smdistributed": "dummy"}, + {"smdistributed": {"dummy"}}, + {"smdistributed": {"dummy": "val"}}, + {"smdistributed": {"dummy": {"enabled": True}}}, + ] + bad_args_hc = [mpi_enabled_hc, smdataparallel_enabled_hc] + frameworks = ["tensorflow", "pytorch"] + + for framework, instance_type in product(frameworks, instance_types): + for distribution in bad_args_normal: + with pytest.raises(ValueError): + fw_utils.validate_distribution( + distribution, + None, # instance_groups + framework, + None, # framework_version + None, # py_version + "custom-container", + {"instance_type": instance_type}, # kwargs + ) + + for framework in frameworks: + for distribution in bad_args_hc: + with pytest.raises(ValueError): + fw_utils.validate_distribution( + distribution, + instance_groups, # instance_groups + framework, + None, # framework_version + None, # py_version + "custom-container", + {}, # kwargs + ) + + def test_validate_smdistributed_not_raises(): smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}} smdataparallel_enabled_custom_mpi = { diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 5e5046fd6f..e39abf01fd 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -22,7 +22,7 @@ from sagemaker import image_uris from sagemaker.pytorch import defaults from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel - +from sagemaker.instance_group import InstanceGroup DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") @@ -691,3 +691,21 @@ def test_custom_image_estimator_deploy( pytorch.fit(inputs="s3://mybucket/train", job_name="new_name") model 
= pytorch.create_model(image_uri=custom_image) assert model.image_uri == custom_image + + +def test_pt_heterogeneous_cluster_distribution_config( + sagemaker_session, pytorch_training_version, pytorch_training_py_version +): + training_group = InstanceGroup("train_group", "ml.c4.xlarge", 1) + expected_return = {"mpi": {"enabled": True}, "instance_groups": ["train_group"]} + pytorch = _pytorch_estimator( + sagemaker_session, + framework_version=pytorch_training_version, + py_version=pytorch_training_py_version, + instance_groups=[training_group], + distribution={ + "mpi": {"enabled": True}, + "instance_groups": [training_group], + }, + ) + assert pytorch.distribution == expected_return From f7a94b89ef8cc77ffcff7d17dc35f1ba62d2bb0e Mon Sep 17 00:00:00 2001 From: Satish Pasumarthi <35979860+satishpasumarthi@users.noreply.github.com> Date: Fri, 1 Jul 2022 14:24:40 -0700 Subject: [PATCH 109/526] fix: Loosen version of attrs dependency --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 75a113f019..e4c6e2e358 100644 --- a/setup.py +++ b/setup.py @@ -47,7 +47,7 @@ def read_requirements(filename): # Declare minimal set for installation required_packages = [ - "attrs==20.3.0", + "attrs>=20.3.0,<22", "boto3>=1.20.21,<2.0", "google-pasta", "numpy>=1.9.0,<2.0", From 5326b034931d109b19f5b6dd169b7f6f394c4049 Mon Sep 17 00:00:00 2001 From: jessicazhu3 <106775307+jessicazhu3@users.noreply.github.com> Date: Tue, 5 Jul 2022 19:11:10 -0700 Subject: [PATCH 110/526] fix: image_uri does not need to be specified with instance_groups --- src/sagemaker/estimator.py | 36 ++++++++++++++++++++- tests/unit/test_estimator.py | 62 ++++++++++++++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 937fd132b4..7818a6b753 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -16,6 +16,7 @@ import json import logging import os +import re import uuid from abc import ABCMeta, abstractmethod from typing import Any, Dict, Union, Optional, List @@ -1520,6 +1521,39 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na init_params["max_wait"] = max_wait return init_params + def _get_instance_type(self): + """Determine the instance type to be used in the training_image_uri function. + + Returns: + instance_type: The instance_type to be used. + """ + if self.instance_type is not None: + return self.instance_type + + if not isinstance(self.instance_groups, list) or len(self.instance_groups) == 0: + raise ValueError( + "instance_groups must be set if instance_type is not set and instance_groups " + "must be a list." + ) + + for instance_group in self.instance_groups: + instance_type = instance_group.instance_type + match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) + + if match: + family = match[1] + if family[0] in ("g", "p"): + return instance_type + else: + raise ValueError( + "Invalid SageMaker instance type for training with heterogeneous clusters: {}. 
" + "For options see: https://aws.amazon.com/sagemaker/pricing/instance-types".format( + instance_type + ) + ) + + return self.instance_groups[0].instance_type + def transformer( self, instance_count, @@ -2903,7 +2937,7 @@ def training_image_uri(self, region=None): compiler_config=getattr(self, "compiler_config", None), tensorflow_version=getattr(self, "tensorflow_version", None), pytorch_version=getattr(self, "pytorch_version", None), - instance_type=self.instance_type, + instance_type=self._get_instance_type(), ) @classmethod diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 30b80bd58b..d402a509fc 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -1334,6 +1334,68 @@ def test_invalid_custom_code_bucket(sagemaker_session): assert "Expecting 's3' scheme" in str(error) +def test_get_instance_type_gpu(sagemaker_session): + estimator = Estimator( + image_uri="some-image", + role="some_image", + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.p3.16xlarge", 2), + ], + sagemaker_session=sagemaker_session, + base_job_name="base_job_name", + ) + + assert "ml.p3.16xlarge" == estimator._get_instance_type() + + +def test_get_instance_type_cpu(sagemaker_session): + estimator = Estimator( + image_uri="some-image", + role="some_image", + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.c5.xlarge", 2), + ], + sagemaker_session=sagemaker_session, + base_job_name="base_job_name", + ) + + assert "ml.c4.xlarge" == estimator._get_instance_type() + + +def test_get_instance_type_no_instance_groups(sagemaker_session): + estimator = Estimator( + image_uri="some-image", + role="some_image", + instance_type="ml.c4.xlarge", + instance_count=1, + sagemaker_session=sagemaker_session, + base_job_name="base_job_name", + ) + + assert "ml.c4.xlarge" == estimator._get_instance_type() + + +def test_get_instance_type_no_instance_groups_or_instance_type(sagemaker_session): + estimator = Estimator( + image_uri="some-image", + role="some_image", + instance_type=None, + instance_count=None, + instance_groups=None, + sagemaker_session=sagemaker_session, + base_job_name="base_job_name", + ) + with pytest.raises(ValueError) as error: + estimator._get_instance_type() + + assert ( + "instance_groups must be set if instance_type is not set and instance_groups must be a list." + in str(error) + ) + + def test_augmented_manifest(sagemaker_session): fw = DummyFramework( entry_point=SCRIPT_PATH, From 74d68c0a3be04223e55569b30afc3d97e057da4e Mon Sep 17 00:00:00 2001 From: Sumit Awasthi Date: Thu, 7 Jul 2022 13:24:26 -0700 Subject: [PATCH 111/526] fix: Moving the newly added field instance_group to the end of method Moving the newly added field instance_group to the end of method --- src/sagemaker/estimator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 7818a6b753..f6efb60cd3 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -146,7 +146,7 @@ def __init__( code_location: Optional[str] = None, entry_point: Optional[str] = None, dependencies: Optional[List[Union[str]]] = None, - instance_groups=None, + instance_groups: Optional[Dict[str, Union[str, int]]] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. 
@@ -2140,7 +2140,7 @@ def __init__( code_location: Optional[str] = None, entry_point: Optional[str] = None, dependencies: Optional[List[str]] = None, - instance_groups=None, + instance_groups: Optional[Dict[str, Union[str, int]]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -2429,7 +2429,6 @@ def __init__( role, instance_count, instance_type, - instance_groups, volume_size, volume_kms_key, max_run, @@ -2465,6 +2464,7 @@ def __init__( entry_point=entry_point, dependencies=dependencies, hyperparameters=hyperparameters, + instance_groups=instance_groups, **kwargs, ) From b2682269e04d3d1e01098568ffa8146e992e6c6e Mon Sep 17 00:00:00 2001 From: Miyoung Choi Date: Thu, 30 Jun 2022 15:03:27 -0700 Subject: [PATCH 112/526] documentation: documentation for heterogeneous cluster --- doc/api/utility/instance_group.rst | 8 +++ src/sagemaker/estimator.py | 89 +++++++++++++++++++----------- src/sagemaker/inputs.py | 37 ++++++++----- src/sagemaker/instance_group.py | 35 ++++++++---- 4 files changed, 112 insertions(+), 57 deletions(-) create mode 100644 doc/api/utility/instance_group.rst diff --git a/doc/api/utility/instance_group.rst b/doc/api/utility/instance_group.rst new file mode 100644 index 0000000000..141c756b0e --- /dev/null +++ b/doc/api/utility/instance_group.rst @@ -0,0 +1,8 @@ +Instance Group +-------------- + +.. automodule:: sagemaker.instance_group + :members: + :undoc-members: + :show-inheritance: + :private-members: diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index f6efb60cd3..9620ab7408 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -160,7 +160,7 @@ def __init__( instance_count (int): Number of Amazon EC2 instances to use for training. Required if instance_groups is not set. instance_type (str): Type of EC2 instance to use for training, - for example, 'ml.c4.xlarge'. Required if instance_groups is + for example, ``'ml.c4.xlarge'``. Required if instance_groups is not set. volume_size (int): Size in GB of the EBS volume to use for storing input data during training (default: 30). Must be large @@ -235,7 +235,6 @@ def __init__( use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for training. If enabled then the ``max_wait`` arg should also be set. - More information: https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html (default: ``False``). @@ -313,19 +312,18 @@ def __init__( when training on Amazon SageMaker. If 'git_config' is provided, 'source_dir' should be a relative location to a directory in the Git repo. + With the following GitHub repo directory structure: - .. admonition:: Example - - With the following GitHub repo directory structure: + .. code:: - >>> |----- README.md - >>> |----- src - >>> |----- train.py - >>> |----- test.py + |----- README.md + |----- src + |----- train.py + |----- test.py - if you need 'train.py' as the entry point and 'test.py' as - the training source code, you can assign - entry_point='train.py' and source_dir='src'. + if you need 'train.py' as the entry point and 'test.py' as + the training source code, you can assign + entry_point='train.py' and source_dir='src'. git_config (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch``, ``commit``, ``2FA_enabled``, ``username``, ``password``, and ``token``. The @@ -333,20 +331,19 @@ def __init__( ``repo`` specifies the Git repository where your training script is stored. If you don't provide ``branch``, the default value 'master' is used. 
If you don't provide ``commit``, the latest - commit in the specified branch is used. + commit in the specified branch is used. For example, the following config: - .. admonition:: Example - - The following config: - - >>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git', - >>> 'branch': 'test-branch-git-config', - >>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'} + .. code:: python - results in cloning the repo specified in 'repo', then - checking out the 'master' branch, and checking out the specified - commit. + git_config = { + 'repo': 'https://github.com/aws/sagemaker-python-sdk.git', + 'branch': 'test-branch-git-config', + 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a' + } + results in cloning the repo specified in 'repo', then + checking out the 'master' branch, and checking out the specified + commit. ``2FA_enabled``, ``username``, ``password``, and ``token`` are used for authentication. For GitHub (or other Git) accounts, set ``2FA_enabled`` to 'True' if two-factor authentication is @@ -427,10 +424,25 @@ def __init__( >>> |------ virtual-env This is not supported with "local code" in Local Mode. - instance_groups (list[InstanceGroup]): Optional. List of InstanceGroup - for specifying different instance groups for heterogeneous cluster. - For example: [sagemaker.InstanceGroup('worker','ml.p3dn.24xlarge',64), - sagemaker.InstanceGroup('server','ml.c5n.18xlarge',64)] + instance_groups (list[:class:`sagemaker.instance_group.InstanceGroup`]): + Optional. A list of ``InstanceGroup`` objects + for launching a training job with a heterogeneous cluster. + For example: + + .. code:: python + + instance_groups=[ + sagemaker.InstanceGroup( + 'instance_group_name_1', 'ml.p3dn.24xlarge', 64), + sagemaker.InstanceGroup( + 'instance_group_name_2', 'ml.c5n.18xlarge', 64)] + + For instructions on how to use ``InstanceGroup`` objects + to configure a heterogeneous cluster + through the SageMaker generic and framework estimator classes, see + `Train Using a Heterogeneous Cluster + `_ + in the *Amazon SageMaker developer guide*. """ instance_count = renamed_kwargs( "train_instance_count", "instance_count", instance_count, kwargs @@ -2418,10 +2430,25 @@ def __init__( >>> |------ virtual-env This is not supported with "local code" in Local Mode. - instance_groups (list[InstanceGroup]): Optional. List of InstanceGroup - for specifying different instance groups for heterogeneous cluster. - For example: [sagemaker.InstanceGroup('worker','ml.p3dn.24xlarge',64), - sagemaker.InstanceGroup('server','ml.c5n.18xlarge',64)] + instance_groups (list[:class:`sagemaker.instance_group.InstanceGroup`]): + Optional. A list of ``InstanceGroup`` objects + for launching a training job with a heterogeneous cluster. + For example: + + .. code:: python + + instance_groups=[ + sagemaker.InstanceGroup( + 'instance_group_name_1', 'ml.p3dn.24xlarge', 64), + sagemaker.InstanceGroup( + 'instance_group_name_2', 'ml.c5n.18xlarge', 64)] + + For instructions on how to use ``InstanceGroup`` objects + to configure a heterogeneous cluster + through the SageMaker generic and framework estimator classes, see + `Train Using a Heterogeneous Cluster + `_ + in the *Amazon SageMaker developer guide*. 
""" self.image_uri = image_uri self._hyperparameters = hyperparameters.copy() if hyperparameters else {} diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index d2bd94d232..855488d33a 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -41,28 +41,37 @@ def __init__( target_attribute_name=None, shuffle_config=None, ): - """Create a definition for input data used by an SageMaker training job. + r"""Create a definition for input data used by an SageMaker training job. - See AWS documentation on the ``CreateTrainingJob`` API for more details on the parameters. + See AWS documentation on the ``CreateTrainingJob`` API for more details + on the parameters. Args: - s3_data (str): Defines the location of s3 data to train on. - distribution (str): Valid values: 'FullyReplicated', 'ShardedByS3Key' - (default: 'FullyReplicated'). - compression (str): Valid values: 'Gzip', None (default: None). This is used only in + s3_data (str): Defines the location of S3 data to train on. + distribution (str): Valid values: ``'FullyReplicated'``, + ``'ShardedByS3Key'`` + (default: ``'FullyReplicated'``). + compression (str): Valid values: ``'Gzip'``, ``None`` (default: None). + This is used only in Pipe input mode. content_type (str): MIME type of the input data (default: None). record_wrapping (str): Valid values: 'RecordIO' (default: None). - s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile', 'AugmentedManifestFile'. - If 'S3Prefix', ``s3_data`` defines a prefix of s3 objects to train on. + s3_data_type (str): Valid values: ``'S3Prefix'``, ``'ManifestFile'``, + ``'AugmentedManifestFile'``. + If ``'S3Prefix'``, ``s3_data`` defines a prefix of s3 objects to train on. All objects with s3 keys beginning with ``s3_data`` will be used to train. - If 'ManifestFile' or 'AugmentedManifestFile', then ``s3_data`` defines a - single S3 manifest file or augmented manifest file (respectively), + If ``'ManifestFile'`` or ``'AugmentedManifestFile'``, + then ``s3_data`` defines a + single S3 manifest file or augmented manifest file respectively, listing the S3 data to train on. Both the ManifestFile and - AugmentedManifestFile formats are described in the SageMaker API documentation: - https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html - instance_groups (list[str]): Optional. List of InstanceGroupNames to send data to - (default: None). By default, data will be sent to all groups. + AugmentedManifestFile formats are described at `S3DataSource + `_ + in the `Amazon SageMaker API reference`. + instance_groups (list[str]): Optional. A list of ``instance_group_name``\ s + of a heterogeneous cluster that's configured using the + :class:`sagemaker.instance_group.InstanceGroup`. + S3 data will be sent to all instance groups in the specified list. + (default: None) input_mode (str): Optional override for this channel's input mode (default: None). By default, channels will use the input mode defined on ``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore diff --git a/src/sagemaker/instance_group.py b/src/sagemaker/instance_group.py index 005a39433b..5042787be5 100644 --- a/src/sagemaker/instance_group.py +++ b/src/sagemaker/instance_group.py @@ -10,15 +10,12 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
-"""This file defines instance group for heterogeneous cluster.""" +"""Defines the InstanceGroup class that configures a heterogeneous cluster.""" from __future__ import absolute_import class InstanceGroup(object): - """Accepts instance group parameters for conversion to request dict. - - The `_to_request_dict` provides a method to turn the parameters into a dict. - """ + """The class to create instance groups for a heterogeneous cluster.""" def __init__( self, @@ -26,16 +23,30 @@ def __init__( instance_type=None, instance_count=None, ): - """Initialize a ``InstanceGroup`` instance. + """It initializes an ``InstanceGroup`` instance. + + You can create instance group object of the ``InstanceGroup`` class + by specifying the instance group configuration arguments. - InstanceGroup accepts instance group parameters and provides a method to turn - these parameters into a dictionary. + For instructions on how to use InstanceGroup objects + to configure a heterogeneous cluster + through the SageMaker generic and framework estimator classes, see + `Train Using a Heterogeneous Cluster + `_ + in the *Amazon SageMaker developer guide*. Args: - instance_group_name (str): Name of the instance group. - instance_type (str): Type of EC2 instance to use in the instance group, - for example, 'ml.c4.xlarge'. - instance_count (int): Number of EC2 instances to use in the instance group. + instance_group_name (str): The name of the instance group. + instance_type (str): The instance type to use in the instance group. + instance_count (int): The number of instances to use in the instance group. + + .. tip:: + + For more information about available values for the arguments, + see `InstanceGroup + `_ + API in the `Amazon SageMaker API reference`. + """ self.instance_group_name = instance_group_name self.instance_type = instance_type From bfebe63428c7d762aca6b8a8ccfe3ed8a5f915bc Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 8 Jul 2022 05:22:37 +0000 Subject: [PATCH 113/526] prepare release v2.99.0 --- CHANGELOG.md | 20 ++++++++++++++++++++ VERSION | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d2e8a82d85..2987f13ac0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,25 @@ # Changelog +## v2.99.0 (2022-07-08) + +### Features + + * heterogeneous cluster set up in distribution config + * support heterogeneous cluster for training + * include fields to work with inference recommender + +### Bug Fixes and Other Changes + + * Moving the newly added field instance_group to the end of method + * image_uri does not need to be specified with instance_groups + * Loosen version of attrs dependency + * Add PipelineVariable annotation in estimatory, processing, tuner, transformer base classes + * model table link + +### Documentation Changes + + * documentation for heterogeneous cluster + ## v2.98.0 (2022-07-05) ### Features diff --git a/VERSION b/VERSION index 06d7460288..47d8253c5e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.98.1.dev0 +2.99.0 From 712c963692d7c811756c8cd90dbe804bb308fa5e Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 8 Jul 2022 05:22:38 +0000 Subject: [PATCH 114/526] update development version to v2.99.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 47d8253c5e..60c21330a6 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.99.0 +2.99.1.dev0 From 1164c60848c319eaec746d8165e5b529eba53b27 Mon Sep 17 00:00:00 2001 From: Samsara Counts Date: Fri, 8 Jul 2022 10:16:23 -0700 Subject: [PATCH 
115/526] documentation: add detail & links to clarify docstrings (#3216) --- src/sagemaker/clarify.py | 604 +++++++++++++++++++++++---------------- 1 file changed, 352 insertions(+), 252 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 5e2922c395..eaf78069c3 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -10,7 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""This module configures the SageMaker Clarify bias and model explainability processor job.""" +"""This module configures the SageMaker Clarify bias and model explainability processor jobs. + +SageMaker Clarify +================== +""" from __future__ import absolute_import, print_function import copy @@ -47,22 +51,25 @@ def __init__( Args: s3_data_input_path (str): Dataset S3 prefix/object URI. s3_output_path (str): S3 prefix to store the output. - s3_analysis_config_output_path (str): S3 prefix to store the analysis_config output - If this field is None, then the s3_output_path will be used - to store the analysis_config output - label (str): Target attribute of the model required by bias metrics (optional for SHAP) + s3_analysis_config_output_path (str): S3 prefix to store the analysis config output. + If this field is None, then the ``s3_output_path`` will be used + to store the ``analysis_config`` output. + label (str): Target attribute of the model **required** for bias metrics (both pre- + and post-training). Optional when running SHAP explainability. Specified as column name or index for CSV dataset, or as JSONPath for JSONLines. headers (list[str]): A list of column names in the input dataset. features (str): JSONPath for locating the feature columns for bias metrics if the dataset format is JSONLines. - dataset_type (str): Format of the dataset. Valid values are "text/csv" for CSV, - "application/jsonlines" for JSONLines, and "application/x-parquet" for Parquet. - s3_compression_type (str): Valid options are "None" or "Gzip". + dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV, + ``"application/jsonlines"`` for JSONLines, and + ``"application/x-parquet"`` for Parquet. + s3_compression_type (str): Valid options are "None" or ``"Gzip"``. joinsource (str): The name or index of the column in the dataset that acts as an identifier column (for instance, while performing a join). This column is only used as an identifier, and not used for any other computations. This is an optional field in all cases except when the dataset contains more than one file, - and `save_local_shap_values` is set to true in SHAPConfig. + and ``save_local_shap_values`` is set to True + in :class:`~sagemaker.clarify.SHAPConfig`. """ if dataset_type not in [ "text/csv", @@ -96,7 +103,7 @@ def get_config(self): class BiasConfig: - """Config object related to bias configurations of the input dataset.""" + """Config object with user-defined bias configurations of the input dataset.""" def __init__( self, @@ -109,36 +116,44 @@ def __init__( Args: label_values_or_threshold ([int or float or str]): List of label value(s) or threshold - to indicate positive outcome used for bias metrics. Dependency on the problem type, + to indicate positive outcome used for bias metrics. + The appropriate threshold depends on the problem type: - * Binary problem: The list shall include one positive value. 
- * Categorical problem: The list shall include one or more (but not all) categories + * Binary: The list has one positive value. + * Categorical:The list has one or more (but not all) categories which are the positive values. - * Regression problem: The list shall include one threshold that defines the lower - bound of positive values. + * Regression: The list should include one threshold that defines the **exclusive** + lower bound of positive values. - facet_name (str or int or [str] or [int]): Sensitive attribute column name (or index in - the input data) for which you like to compute bias metrics. It can also be a list - of names (or indexes) if you like to compute for multiple sensitive attributes. + facet_name (str or int or list[str] or list[int]): Sensitive attribute column name + (or index in the input data) to use when computing bias metrics. It can also be a + list of names (or indexes) for computing metrics for multiple sensitive attributes. facet_values_or_threshold ([int or float or str] or [[int or float or str]]): - The parameter indicates the sensitive group. If facet_name is a scalar, then it can - be None or a list. Depending on the data type of the facet column, - - * Binary: None means computing the bias metrics for each binary value. Or add one - binary value to the list, to compute its bias metrics only. - * Categorical: None means computing the bias metrics for each category. Or add one - or more (but not all) categories to the list, to compute their bias metrics v.s. - the other categories. - * Continuous: The list shall include one and only one threshold which defines the - lower bound of a sensitive group. - - If facet_name is a list, then it can be None if all facets are of binary type or - categorical type. Otherwise it shall be a list, and each element is the values or - threshold of the corresponding facet. + The parameter controls the values of the sensitive group. + If ``facet_name`` is a scalar, then it can be None or a list. + Depending on the data type of the facet column, the values mean: + + * Binary data: None means computing the bias metrics for each binary value. + Or add one binary value to the list, to compute its bias metrics only. + * Categorical data: None means computing the bias metrics for each category. Or add + one or more (but not all) categories to the list, to compute their + bias metrics v.s. the other categories. + * Continuous data: The list should include one and only one threshold which defines + the **exclusive** lower bound of a sensitive group. + + If ``facet_name`` is a list, then ``facet_values_or_threshold`` can be None + if all facets are of binary or categorical type. + Otherwise, ``facet_values_or_threshold`` should be a list, and each element + is the value or threshold of the corresponding facet. group_name (str): Optional column name or index to indicate a group column to be used - for the bias metric 'Conditional Demographic Disparity in Labels - CDDL' or - 'Conditional Demographic Disparity in Predicted Labels - CDDPL'. - """ + for the bias metric + `Conditional Demographic Disparity in Labels `(CDDL) `_ + or + `Conditional Demographic Disparity in Predicted Labels (CDDPL) `_. 
+ + Raises: + ValueError: If the number of ``facet_names`` doesn't equal number of ``facet values`` + """ # noqa E501 # pylint: disable=c0301 if isinstance(facet_name, list): assert len(facet_name) > 0, "Please provide at least one facet" if facet_values_or_threshold is None: @@ -167,7 +182,7 @@ def __init__( _set(group_name, "group_variable", self.analysis_config) def get_config(self): - """Returns part of an analysis config dictionary.""" + """Returns a dictionary of bias detection configurations, part of the analysis config""" return copy.deepcopy(self.analysis_config) @@ -192,7 +207,7 @@ def __init__( model_name (str): Model name (as created by 'CreateModel'). instance_count (int): The number of instances of a new endpoint for model inference. instance_type (str): The type of EC2 instance to use for model inference, - for example, 'ml.c5.xlarge'. + for example, ``"ml.c5.xlarge"``. accept_type (str): The model output format to be used for getting inferences with the shadow endpoint. Valid values are "text/csv" for CSV and "application/jsonlines". Default is the same as content_type. @@ -200,9 +215,9 @@ def __init__( shadow endpoint. Valid values are "text/csv" for CSV and "application/jsonlines". Default is the same as dataset format. content_template (str): A template string to be used to construct the model input from - dataset instances. It is only used when "model_content_type" is - "application/jsonlines". The template should have one and only one placeholder - $features which will be replaced by a features list for to form the model inference + dataset instances. It is only used when ``model_content_type`` is + ``"application/jsonlines"``. The template should have one and only one placeholder, + "features", which will be replaced by a features list to form the model inference input. custom_attributes (str): Provides additional information about a request for an inference submitted to a model hosted at an Amazon SageMaker endpoint. The @@ -210,14 +225,19 @@ def __init__( value, for example, to provide an ID that you can use to track a request or to provide other metadata that a service endpoint was programmed to process. The value must consist of no more than 1024 visible US-ASCII characters as specified in - Section 3.3.6. Field Value Components ( - https://tools.ietf.org/html/rfc7230#section-3.2.6) of the Hypertext Transfer - Protocol (HTTP/1.1). - accelerator_type (str): The Elastic Inference accelerator type to deploy to the model - endpoint instance for making inferences to the model, see - https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html. + Section 3.3.6. + `Field Value Components `_ + of the Hypertext Transfer Protocol (HTTP/1.1). + accelerator_type (str): SageMaker + `Elastic Inference `_ + accelerator type to deploy to the model endpoint instance + for making inferences to the model. endpoint_name_prefix (str): The endpoint name prefix of a new endpoint. Must follow - pattern "^[a-zA-Z0-9](-\*[a-zA-Z0-9]". + pattern ``^[a-zA-Z0-9](-\*[a-zA-Z0-9]``. + + Raises: + ValueError: when the ``endpoint_name_prefix`` is invalid, ``accept_type`` is invalid, + ``content_type`` is invalid, or ``content_template`` has no placeholder "features" """ self.predictor_config = { "model_name": model_name, @@ -280,28 +300,30 @@ def __init__( """Initializes a model output config to extract the predicted label or predicted score(s). The following examples show different parameter configurations depending on the endpoint: - * Regression Task: The model returns the score, e.g. 1.2. 
we don't need to specify - anything. For json output, e.g. {'score': 1.2} we can set 'label='score''. - - * Binary classification: - * The model returns a single probability and we would like to classify as 'yes' - those with a probability exceeding 0.2. - We can set 'probability_threshold=0.2, label_headers='yes''. - * The model returns {'probability': 0.3}, for which we would like to apply a - threshold of 0.5 to obtain a predicted label in {0, 1}. In this case we can set - 'label='probability''. - * The model returns a tuple of the predicted label and the probability. - In this case we can set 'label=0'. - - * Multiclass classification: - * The model returns - {'labels': ['cat', 'dog', 'fish'], 'probabilities': [0.35, 0.25, 0.4]}. - In this case we would set the 'probability='probabilities'' and - 'label='labels'' and infer the predicted label to be 'fish.' - * The model returns {'predicted_label': 'fish', 'probabilities': [0.35, 0.25, 0.4]}. - In this case we would set the 'label='predicted_label''. - * The model returns [0.35, 0.25, 0.4]. In this case, we can set - 'label_headers=['cat','dog','fish']' and infer the predicted label to be 'fish.' + + * **Regression task:** + The model returns the score, e.g. ``1.2``. We don't need to specify + anything. For json output, e.g. ``{'score': 1.2}``, we can set ``label='score'``. + * **Binary classification:** + + * The model returns a single probability score. We want to classify as ``"yes"`` + predictions with a probability score over ``0.2``. + We can set ``probability_threshold=0.2`` and ``label_headers="yes"``. + * The model returns ``{"probability": 0.3}``, for which we would like to apply a + threshold of ``0.5`` to obtain a predicted label in ``{0, 1}``. + In this case we can set ``label="probability"``. + * The model returns a tuple of the predicted label and the probability. + In this case we can set ``label = 0``. + * **Multiclass classification:** + + * The model returns ``{'labels': ['cat', 'dog', 'fish'], + 'probabilities': [0.35, 0.25, 0.4]}``. In this case we would set + ``probability='probabilities'``, ``label='labels'``, + and infer the predicted label to be ``'fish'``. + * The model returns ``{'predicted_label': 'fish', 'probabilities': [0.35, 0.25, 0.4]}``. + In this case we would set the ``label='predicted_label'``. + * The model returns ``[0.35, 0.25, 0.4]``. In this case, we can set + ``label_headers=['cat','dog','fish']`` and infer the predicted label to be ``'fish'``. Args: label (str or int): Index or JSONPath location in the model output for the prediction. @@ -311,11 +333,14 @@ def __init__( for the predicted score(s). probability_threshold (float): An optional value for binary prediction tasks in which the model returns a probability, to indicate the threshold to convert the - prediction to a boolean value. Default is 0.5. + prediction to a boolean value. Default is ``0.5``. label_headers (list[str]): List of headers, each for a predicted score in model output. For bias analysis, it is used to extract the label value with the highest score as - predicted label. For explainability job, It is used to beautify the analysis report - by replacing placeholders like "label0". + predicted label. For explainability jobs, it is used to beautify the analysis report + by replacing placeholders like ``'label0'``. 
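[Editorial illustration, not part of the patch] A sketch of how the shadow-endpoint and prediction-extraction settings documented above fit together; the model name and data formats are hypothetical.

    from sagemaker import clarify

    model_config = clarify.ModelConfig(
        model_name="my-model",                # existing SageMaker model (hypothetical)
        instance_type="ml.c5.xlarge",
        instance_count=1,
        content_type="text/csv",              # shadow endpoint input format
        accept_type="application/jsonlines",  # shadow endpoint output format
    )

    # Binary case from the examples above: the model returns {"probability": 0.3} and a
    # 0.5 threshold converts the score into a 0/1 predicted label.
    predictions_config = clarify.ModelPredictedLabelConfig(
        label="probability",
        probability_threshold=0.5,
    )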
+ + Raises: + TypeError: when the ``probability_threshold`` cannot be cast to a float """ self.label = label self.probability = probability @@ -335,7 +360,7 @@ def __init__( _set(label_headers, "label_headers", self.predictor_config) def get_predictor_config(self): - """Returns probability_threshold, predictor config.""" + """Returns ``probability_threshold`` and predictor config dictionary.""" return self.probability_threshold, copy.deepcopy(self.predictor_config) @@ -351,41 +376,52 @@ def get_explainability_config(self): class PDPConfig(ExplainabilityConfig): """Config class for Partial Dependence Plots (PDP). - If PDP is requested, the Partial Dependence Plots will be included in the report, and the - corresponding values will be included in the analysis output. - """ + `PDPs `_ + show the marginal effect (the dependence) a subset of features has on the predicted + outcome of an ML model. + + When PDP is requested (by passing in a :class:`~sagemaker.clarify.PDPConfig` to the + ``explainability_config`` parameter of :class:`~sagemaker.clarify.SageMakerClarifyProcessor`), + the Partial Dependence Plots are included in the output + `report `__ + and the corresponding values are included in the analysis output. + """ # noqa E501 def __init__(self, features=None, grid_resolution=15, top_k_features=10): - """Initializes config for PDP. + """Initializes PDP config. Args: - features (None or list): List of features names or indices for which partial dependence - plots must be computed and plotted. When ShapConfig is provided, this parameter is - optional as Clarify will try to compute the partial dependence plots for top - feature based on SHAP attributions. When ShapConfig is not provided, 'features' - must be provided. - grid_resolution (int): In case of numerical features, this number represents that - number of buckets that range of values must be divided into. This decides the + features (None or list): List of feature names or indices for which partial dependence + plots are computed and plotted. When :class:`~sagemaker.clarify.ShapConfig` + is provided, this parameter is optional, as Clarify will compute the + partial dependence plots for top features based on + `SHAP `__ + attributions. When :class:`~sagemaker.clarify.ShapConfig` is not provided, + ``features`` must be provided. + grid_resolution (int): When using numerical features, this integer represents the + number of buckets that the range of values must be divided into. This decides the granularity of the grid in which the PDP are plotted. - top_k_features (int): Set the number of top SHAP attributes to be selected to compute + top_k_features (int): Sets the number of top SHAP attributes used to compute partial dependence plots. - """ + """ # noqa E501 self.pdp_config = {"grid_resolution": grid_resolution, "top_k_features": top_k_features} if features is not None: self.pdp_config["features"] = features def get_explainability_config(self): - """Returns config.""" + """Returns PDP config dictionary.""" return copy.deepcopy({"pdp": self.pdp_config}) class TextConfig: - """Config object to handle text features. + """Config object to handle text features for text explainability - The SHAP analysis will break down longer text into chunks (e.g. tokens, sentences, or paragraphs - ) and replace them with the strings specified in the baseline for that feature. The shap value + `SHAP analysis `__ + breaks down longer text into chunks (e.g. 
tokens, sentences, or paragraphs) + and replaces them with the strings specified in the baseline for that feature. + The `shap value `_ of a chunk then captures how much replacing it affects the prediction. - """ + """ # noqa E501 # pylint: disable=c0301 _SUPPORTED_GRANULARITIES = ["token", "sentence", "paragraph"] _SUPPORTED_LANGUAGES = [ @@ -461,19 +497,28 @@ def __init__( ): """Initializes a text configuration. - Args: granularity (str): Determines the granularity in which text features are broken down - to, can be "token", "sentence", or "paragraph". Shap values are computed for these units. - language (str): Specifies the language of the text features, can be "chinese", "danish", - "dutch", "english", "french", "german", "greek", "italian", "japanese", "lithuanian", - "multi-language", "norwegian bokmål", "polish", "portuguese", "romanian", "russian", - "spanish", "afrikaans", "albanian", "arabic", "armenian", "basque", "bengali", "bulgarian", - "catalan", "croatian", "czech", "estonian", "finnish", "gujarati", "hebrew", "hindi", - "hungarian", "icelandic", "indonesian", "irish", "kannada", "kyrgyz", "latvian", "ligurian", - "luxembourgish", "macedonian", "malayalam", "marathi", "nepali", "persian", "sanskrit", - "serbian", "setswana", "sinhala", "slovak", "slovenian", "swedish", "tagalog", "tamil", - "tatar", "telugu", "thai", "turkish", "ukrainian", "urdu", "vietnamese", "yoruba". Use - "multi-language" for a mix of mulitple languages. - """ + Args: + granularity (str): Determines the granularity in which text features are broken down + to. Accepted values are ``"token"``, ``"sentence"``, or ``"paragraph"``. + Computes `shap values `_ + for these units. + language (str): Specifies the language of the text features. Accepted values are + one of the following: + "chinese", "danish", "dutch", "english", "french", "german", "greek", "italian", + "japanese", "lithuanian", "multi-language", "norwegian bokmål", "polish", + "portuguese", "romanian", "russian", "spanish", "afrikaans", "albanian", "arabic", + "armenian", "basque", "bengali", "bulgarian", "catalan", "croatian", "czech", + "estonian", "finnish", "gujarati", "hebrew", "hindi", "hungarian", "icelandic", + "indonesian", "irish", "kannada", "kyrgyz", "latvian", "ligurian", "luxembourgish", + "macedonian", "malayalam", "marathi", "nepali", "persian", "sanskrit", "serbian", + "setswana", "sinhala", "slovak", "slovenian", "swedish", "tagalog", "tamil", + "tatar", "telugu", "thai", "turkish", "ukrainian", "urdu", "vietnamese", "yoruba". + Use "multi-language" for a mix of multiple languages. + + Raises: + ValueError: when ``granularity`` is not in list of supported values + or ``language`` is not in list of supported values + """ # noqa E501 # pylint: disable=c0301 if granularity not in TextConfig._SUPPORTED_GRANULARITIES: raise ValueError( f"Invalid granularity {granularity}. Please choose among " @@ -490,7 +535,7 @@ def __init__( } def get_text_config(self): - """Returns part of an analysis config dictionary.""" + """Returns a text config dictionary, part of the analysis config dictionary.""" return copy.deepcopy(self.text_config) @@ -507,32 +552,46 @@ def __init__( iou_threshold=None, context=None, ): - """Initializes all configuration parameters needed for SHAP CV explainability + """Initializes a config object for Computer Vision (CV) Image explainability. + + `SHAP for CV explainability `__. + generating heat maps that visualize feature attributions for input images. 
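[Editorial illustration, not part of the patch] A sketch of a ``TextConfig`` plugged into a ``SHAPConfig`` for text explainability; the baseline replacement token is an assumption chosen for illustration.

    from sagemaker import clarify

    text_config = clarify.TextConfig(granularity="sentence", language="english")

    text_shap_config = clarify.SHAPConfig(
        baseline=[["<UNK>"]],     # string that replaces masked-out sentences (hypothetical)
        num_samples=100,
        agg_method="mean_abs",
        text_config=text_config,
    )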
+ These heat maps highlight the image's features according + to how much they contribute to the CV model prediction. + + ``"IMAGE_CLASSIFICATION"`` and ``"OBJECT_DETECTION"`` are the two supported CV use cases. Args: - model_type (str): Specifies the type of CV model. Options: - (IMAGE_CLASSIFICATION | OBJECT_DETECTION). - num_segments (None or int): Clarify uses SKLearn's SLIC method for image segmentation - to generate features/superpixels. num_segments specifies approximate - number of segments to be generated. Default is None. SLIC will default to - 100 segments. + model_type (str): Specifies the type of CV model and use case. Accepted options: + ``"IMAGE_CLASSIFICATION"`` or ``"OBJECT_DETECTION"``. + num_segments (None or int): Approximate number of segments to generate when running + SKLearn's `SLIC method `_ + for image segmentation to generate features/superpixels. + The default is None. When set to None, runs SLIC with 20 segments. feature_extraction_method (None or str): method used for extracting features from the - image.ex. "segmentation". Default is segmentation. + image (ex: "segmentation"). Default is ``"segmentation"``. segment_compactness (None or float): Balances color proximity and space proximity. - Higher values give more weight to space proximity, making superpixel - shapes more square/cubic. We recommend exploring possible values on a log - scale, e.g., 0.01, 0.1, 1, 10, 100, before refining around a chosen value. - max_objects (None or int): maximum number of objects displayed. Object detection - algorithm may detect more than max_objects number of objects in a single - image. The top max_objects number of objects according to confidence score - will be displayed. - iou_threshold (None or float): minimum intersection over union for the object - bounding box to consider its confidence score for computing SHAP values [0.0, 1.0]. - This parameter is used for the object detection case. - context (None or float): refers to the portion of the image outside of the bounding box. - Scale is [0.0, 1.0]. If set to 1.0, whole image is considered, if set to - 0.0 only the image inside bounding box is considered. - """ + Higher values give more weight to space proximity, making superpixel + shapes more square/cubic. We recommend exploring possible values on a log + scale, e.g., 0.01, 0.1, 1, 10, 100, before refining around a chosen value. + The default is None. When set to None, runs with the default value of ``5``. + max_objects (None or int): Maximum number of objects displayed when running SHAP + with an ``"OBJECT_DETECTION"`` model. The Object detection algorithm may detect + more than the ``max_objects`` number of objects in a single image. + In that case, the algorithm displays the top ``max_objects`` number of objects + according to confidence score. Default value is None. In the ``"OBJECT_DETECTION"`` + case, passing in None leads to a default value of ``3``. + iou_threshold (None or float): Minimum intersection over union for the object + bounding box to consider its confidence score for computing SHAP values, + in the range ``[0.0, 1.0]``. Used only for the ``"OBJECT_DETECTION"`` case, + where passing in None sets the default value of ``0.5``. + context (None or float): The portion of the image outside the bounding box used + in SHAP analysis, in the range ``[0.0, 1.0]``. If set to ``1.0``, the whole image + is considered; if set to ``0.0`` only the image inside bounding box is considered. 
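[Editorial illustration, not part of the patch] A sketch of an ``ImageConfig`` for an object-detection model, with the defaults described above made explicit.

    from sagemaker import clarify

    image_config = clarify.ImageConfig(
        model_type="OBJECT_DETECTION",
        num_segments=20,       # approximate number of SLIC superpixels
        max_objects=3,         # keep only the top-3 detections by confidence score
        iou_threshold=0.5,     # minimum IoU for a bounding box to be scored
        context=1.0,           # consider the whole image around each bounding box
    )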
+ Only used for the ``"OBJECT_DETECTION"`` case, + when passing in None sets the default value of ``1.0``. + + """ # noqa E501 # pylint: disable=c0301 self.image_config = {} if model_type not in ["OBJECT_DETECTION", "IMAGE_CLASSIFICATION"]: @@ -554,7 +613,15 @@ def get_image_config(self): class SHAPConfig(ExplainabilityConfig): - """Config class of SHAP.""" + """Config class for `SHAP `__. + + The SHAP algorithm calculates feature attributions by computing + the contribution of each feature to the prediction outcome, using the concept of + `Shapley values `_. + + These attributions can be provided for specific predictions (locally) + and at a global level for the model as a whole. + """ # noqa E501 # pylint: disable=c0301 def __init__( self, @@ -568,38 +635,41 @@ def __init__( text_config=None, image_config=None, ): - """Initializes config for SHAP. + """Initializes config for SHAP analysis. Args: - baseline (None or str or list): None or S3 object Uri or A list of rows (at least one) - to be used asthe baseline dataset in the Kernel SHAP algorithm. The format should - be the same as the dataset format. Each row should contain only the feature - columns/values and omit the label column/values. If None a baseline will be - calculated automatically by using K-means or K-prototypes in the input dataset. + baseline (None or str or list): `Baseline dataset `_ + for the Kernel SHAP algorithm, accepted in the form of: + S3 object URI, a list of rows (with at least one element), + or None (for no input baseline). The baseline dataset must have the same format + as the input dataset specified in :class:`~sagemaker.clarify.DataConfig`. + Each row must have only the feature columns/values and omit the label column/values. + If None, a baseline will be calculated automatically on the input dataset + using K-means (for numerical data) or K-prototypes (if there is categorical data). num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm. This number determines the size of the generated synthetic dataset to compute the SHAP values. If not provided then Clarify job will choose a proper value according to the count of features. agg_method (None or str): Aggregation method for global SHAP values. Valid values are - "mean_abs" (mean of absolute SHAP values for all instances), - "median" (median of SHAP values for all instances) and - "mean_sq" (mean of squared SHAP values for all instances). - If not provided then Clarify job uses method "mean_abs" - use_logit (bool): Indicator of whether the logit function is to be applied to the model - predictions. Default is False. If "use_logit" is true then the SHAP values will + ``"mean_abs"`` (mean of absolute SHAP values for all instances), + ``"median"`` (median of SHAP values for all instances) and + ``"mean_sq"`` (mean of squared SHAP values for all instances). + If None is provided, then Clarify job uses the method ``"mean_abs"``. + use_logit (bool): Indicates whether to apply the logit function to model predictions. + Default is False. If ``use_logit`` is true then the SHAP values will have log-odds units. - save_local_shap_values (bool): Indicator of whether to save the local SHAP values + save_local_shap_values (bool): Indicates whether to save the local SHAP values in the output location. Default is True. - seed (int): seed value to get deterministic SHAP values. Default is None. 
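[Editorial illustration, not part of the patch] A sketch of a tabular ``SHAPConfig`` with an explicit baseline row; the feature values are hypothetical and must match the dataset's feature columns, with the label omitted.

    from sagemaker import clarify

    shap_config = clarify.SHAPConfig(
        baseline=[[35, 1, 50000]],   # one baseline row, feature columns only (hypothetical)
        num_samples=100,
        agg_method="mean_abs",
        use_logit=False,
        save_local_shap_values=True,
        seed=123,
    )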
- num_clusters (None or int): If a baseline is not provided, Clarify automatically - computes a baseline dataset via a clustering algorithm (K-means/K-prototypes). - num_clusters is a parameter for this algorithm. num_clusters will be the resulting - size of the baseline dataset. If not provided, Clarify job will use a default value. - text_config (:class:`~sagemaker.clarify.TextConfig`): Config to handle text features. - Default is None - image_config (:class:`~sagemaker.clarify.ImageConfig`): Config to handle image features. - Default is None - """ + seed (int): Seed value to get deterministic SHAP values. Default is None. + num_clusters (None or int): If a ``baseline`` is not provided, Clarify automatically + computes a baseline dataset via a clustering algorithm (K-means/K-prototypes), which + takes ``num_clusters`` as a parameter. ``num_clusters`` will be the resulting size + of the baseline dataset. If not provided, Clarify job uses a default value. + text_config (:class:`~sagemaker.clarify.TextConfig`): Config object for handling + text features. Default is None. + image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image + features. Default is None. + """ # noqa E501 # pylint: disable=c0301 if agg_method is not None and agg_method not in ["mean_abs", "median", "mean_sq"]: raise ValueError( f"Invalid agg_method {agg_method}." f" Please choose mean_abs, median, or mean_sq." @@ -630,12 +700,12 @@ def __init__( _set(image_config.get_image_config(), "image_config", self.shap_config) def get_explainability_config(self): - """Returns config.""" + """Returns a shap config dictionary.""" return copy.deepcopy({"shap": self.shap_config}) class SageMakerClarifyProcessor(Processor): - """Handles SageMaker Processing task to compute bias metrics and explain a model.""" + """Handles SageMaker Processing tasks to compute bias metrics and model explanations.""" _CLARIFY_DATA_INPUT = "/opt/ml/processing/input/data" _CLARIFY_CONFIG_INPUT = "/opt/ml/processing/input/config" @@ -657,7 +727,9 @@ def __init__( job_name_prefix=None, version=None, ): - """Initializes a ``Processor`` instance, computing bias metrics and model explanations. + """Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations. + + Instance of :class:`~sagemaker.processing.Processor`. Args: role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing @@ -666,7 +738,7 @@ def __init__( instance_count (int): The number of instances to run a processing job with. instance_type (str): The type of EC2 instance to use for - processing, for example, 'ml.c4.xlarge'. + processing, for example, ``'ml.c4.xlarge'``. volume_size_in_gb (int): Size in GB of the EBS volume to use for storing data during processing (default: 30). volume_kms_key (str): A KMS key for the processing @@ -674,12 +746,13 @@ def __init__( output_kms_key (str): The KMS key ID for processing job outputs (default: None). max_runtime_in_seconds (int): Timeout in seconds (default: None). After this amount of time, Amazon SageMaker terminates the job, - regardless of its current status. If `max_runtime_in_seconds` is not - specified, the default value is 24 hours. + regardless of its current status. If ``max_runtime_in_seconds`` is not + specified, the default value is ``86400`` seconds (24 hours). sagemaker_session (:class:`~sagemaker.session.Session`): - Session object which manages interactions with Amazon SageMaker and - any other AWS services needed. 
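[Editorial illustration, not part of the patch] A sketch of constructing the processor itself; the role ARN is a placeholder.

    from sagemaker import Session, clarify

    clarify_processor = clarify.SageMakerClarifyProcessor(
        role="arn:aws:iam::111122223333:role/SageMakerClarifyRole",  # hypothetical role
        instance_count=1,
        instance_type="ml.c4.xlarge",
        sagemaker_session=Session(),
        job_name_prefix="clarify",
    )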
If not specified, the processor creates - one using the default AWS configuration chain. + :class:`~sagemaker.session.Session` object which manages interactions + with Amazon SageMaker and any other AWS services needed. If not specified, + the Processor creates a :class:`~sagemaker.session.Session` + using the default AWS configuration chain. env (dict[str, str]): Environment variables to be passed to the processing jobs (default: None). tags (list[dict]): List of tags to be passed to the processing job @@ -690,7 +763,7 @@ def __init__( object that configures network isolation, encryption of inter-container traffic, security group IDs, and subnets. job_name_prefix (str): Processing job name prefix. - version (str): Clarify version want to be used. + version (str): Clarify version to use. """ container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version) self.job_name_prefix = job_name_prefix @@ -728,7 +801,9 @@ def _run( kms_key, experiment_config, ): - """Runs a ProcessingJob with the Sagemaker Clarify container and an analysis config. + """Runs a :class:`~sagemaker.processing.ProcessingJob` with the SageMaker Clarify container + + and analysis config. Args: data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data. @@ -741,15 +816,16 @@ def _run( user code file (default: None). experiment_config (dict[str, str]): Experiment management configuration. Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``. + The behavior of setting these keys is as follows: - * If `ExperimentName` is supplied but `TrialName` is not a Trial will be - automatically created and the job's Trial Component associated with the Trial. - * If `TrialName` is supplied and the Trial already exists the job's Trial Component - will be associated with the Trial. - * If both `ExperimentName` and `TrialName` are not supplied the trial component - will be unassociated. - * `TrialComponentDisplayName` is used for display in Studio. + * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be + automatically created and the job's Trial Component associated with the Trial. + * If ``'TrialName'`` is supplied and the Trial already exists, + the job's Trial Component will be associated with the Trial. + * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied, + the Trial Component will be unassociated. + * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. """ analysis_config["methods"]["report"] = { "name": "report", @@ -810,15 +886,15 @@ def run_pre_training_bias( kms_key=None, experiment_config=None, ): - """Runs a ProcessingJob to compute the pre-training bias methods of the input data. + """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute pre-training bias methods - Computes the requested methods that compare 'methods' (e.g. fraction of examples) for the - sensitive group vs the other examples. + Computes the requested ``methods`` on the input data. The ``methods`` compare + metrics (e.g. fraction of examples) for the sensitive group(s) vs. the other examples. Args: data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data. data_bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups. 
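[Editorial illustration, not part of the patch] A sketch of a pre-training bias run; bucket paths, headers, and the chosen metric subset are hypothetical, and ``bias_config`` and ``clarify_processor`` are the objects sketched earlier.

    from sagemaker import clarify

    data_config = clarify.DataConfig(
        s3_data_input_path="s3://my-bucket/clarify/dataset.csv",  # hypothetical path
        s3_output_path="s3://my-bucket/clarify/output",
        label="Label",
        headers=["Label", "Age", "Gender", "Income"],
        dataset_type="text/csv",
    )

    clarify_processor.run_pre_training_bias(
        data_config=data_config,
        data_bias_config=bias_config,
        methods=["CI", "DPL", "KL", "JS"],
    )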
- methods (str or list[str]): Selector of a subset of potential metrics: + methods (str or list[str]): Selects a subset of potential metrics: ["`CI `_", "`DPL `_", "`KL `_", @@ -831,24 +907,26 @@ def run_pre_training_bias( wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. Only meaningful when ``wait`` is True (default: True). - job_name (str): Processing job name. When ``job_name`` is not specified, if - ``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name - will be composed of ``job_name_prefix`` and current timestamp; otherwise use - "Clarify-Pretraining-Bias" as prefix. + job_name (str): Processing job name. When ``job_name`` is not specified, + if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor` is + specified, the job name will be the ``job_name_prefix`` and current timestamp; + otherwise use ``"Clarify-Pretraining-Bias"`` as prefix. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). experiment_config (dict[str, str]): Experiment management configuration. Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``. + The behavior of setting these keys is as follows: - * If `ExperimentName` is supplied but `TrialName` is not a Trial will be - automatically created and the job's Trial Component associated with the Trial. - * If `TrialName` is supplied and the Trial already exists the job's Trial Component - will be associated with the Trial. - * If both `ExperimentName` and `TrialName` are not supplied the trial component - will be unassociated. - * `TrialComponentDisplayName` is used for display in Studio. - """ # noqa E501 + + * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be + automatically created and the job's Trial Component associated with the Trial. + * If ``'TrialName'`` is supplied and the Trial already exists, + the job's Trial Component will be associated with the Trial. + * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied, + the Trial Component will be unassociated. + * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. + """ # noqa E501 # pylint: disable=c0301 analysis_config = data_config.get_config() analysis_config.update(data_bias_config.get_config()) analysis_config["methods"] = {"pre_training_bias": {"methods": methods}} @@ -880,12 +958,13 @@ def run_post_training_bias( kms_key=None, experiment_config=None, ): - """Runs a ProcessingJob to compute the post-training bias methods of the model predictions. + """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute posttraining bias - Spins up a model endpoint, runs inference over the input example in the - 's3_data_input_path' to obtain predicted labels. Computes a the requested methods that - compare 'methods' (e.g. accuracy, precision, recall) for the sensitive group vs the other - examples. + Spins up a model endpoint and runs inference over the input dataset in + the ``s3_data_input_path`` (from the :class:`~sagemaker.clarify.DataConfig`) to obtain + predicted labels. Using model predictions, computes the requested posttraining bias + ``methods`` that compare metrics (e.g. accuracy, precision, recall) for the + sensitive group(s) versus the other examples. 
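[Editorial illustration, not part of the patch] A sketch of a post-training bias run, reusing the configs sketched earlier; the metric selection is hypothetical.

    clarify_processor.run_post_training_bias(
        data_config=data_config,
        data_bias_config=bias_config,
        model_config=model_config,
        model_predicted_label_config=predictions_config,
        methods=["DPPL", "DI", "AD"],
    )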
Args: data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data. @@ -910,24 +989,26 @@ def run_post_training_bias( wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. Only meaningful when ``wait`` is True (default: True). - job_name (str): Processing job name. When ``job_name`` is not specified, if - ``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name - will be composed of ``job_name_prefix`` and current timestamp; otherwise use - "Clarify-Posttraining-Bias" as prefix. + job_name (str): Processing job name. When ``job_name`` is not specified, + if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor` + is specified, the job name will be the ``job_name_prefix`` and current timestamp; + otherwise use ``"Clarify-Posttraining-Bias"`` as prefix. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). experiment_config (dict[str, str]): Experiment management configuration. Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``. + The behavior of setting these keys is as follows: - * If `ExperimentName` is supplied but `TrialName` is not a Trial will be - automatically created and the job's Trial Component associated with the Trial. - * If `TrialName` is supplied and the Trial already exists the job's Trial Component - will be associated with the Trial. - * If both `ExperimentName` and `TrialName` are not supplied the trial component - will be unassociated. - * `TrialComponentDisplayName` is used for display in Studio. - """ + + * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be + automatically created and the job's Trial Component associated with the Trial. + * If ``'TrialName'`` is supplied and the Trial already exists, + the job's Trial Component will be associated with the Trial. + * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied, + the Trial Component will be unassociated. + * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. + """ # noqa E501 # pylint: disable=c0301 analysis_config = data_config.get_config() analysis_config.update(data_bias_config.get_config()) ( @@ -967,11 +1048,12 @@ def run_bias( kms_key=None, experiment_config=None, ): - """Runs a ProcessingJob to compute the requested bias methods. + """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute the requested bias methods - It computes the metrics of both the pre-training methods and the post-training methods. - To calculate post-training methods, it needs to spin up a model endpoint, runs inference - over the input example in the 's3_data_input_path' to obtain predicted labels. + Computes metrics for both the pre-training and the post-training methods. + To calculate post-training methods, it spins up a model endpoint and runs inference over the + input examples in 's3_data_input_path' (from the :class:`~sagemaker.clarify.DataConfig`) + to obtain predicted labels. Args: data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data. @@ -1006,24 +1088,26 @@ def run_bias( wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. Only meaningful when ``wait`` is True (default: True). 
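[Editorial illustration, not part of the patch] A sketch of a combined run that computes both metric families in one processing job, reusing the configs sketched earlier.

    clarify_processor.run_bias(
        data_config=data_config,
        bias_config=bias_config,
        model_config=model_config,
        model_predicted_label_config=predictions_config,
        pre_training_methods="all",
        post_training_methods="all",
    )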
- job_name (str): Processing job name. When ``job_name`` is not specified, if - ``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name - will be composed of ``job_name_prefix`` and current timestamp; otherwise use - "Clarify-Bias" as prefix. + job_name (str): Processing job name. When ``job_name`` is not specified, + if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor` is + specified, the job name will be ``job_name_prefix`` and the current timestamp; + otherwise use ``"Clarify-Bias"`` as prefix. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). experiment_config (dict[str, str]): Experiment management configuration. Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``. + The behavior of setting these keys is as follows: - * If `ExperimentName` is supplied but `TrialName` is not a Trial will be - automatically created and the job's Trial Component associated with the Trial. - * If `TrialName` is supplied and the Trial already exists the job's Trial Component - will be associated with the Trial. - * If both `ExperimentName` and `TrialName` are not supplied the trial component - will be unassociated. - * `TrialComponentDisplayName` is used for display in Studio. - """ # noqa E501 + + * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be + automatically created and the job's Trial Component associated with the Trial. + * If ``'TrialName'`` is supplied and the Trial already exists, + the job's Trial Component will be associated with the Trial. + * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied, + the Trial Component will be unassociated. + * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. + """ # noqa E501 # pylint: disable=c0301 analysis_config = data_config.get_config() analysis_config.update(bias_config.get_config()) analysis_config["predictor"] = model_config.get_predictor_config() @@ -1068,50 +1152,65 @@ def run_explainability( kms_key=None, experiment_config=None, ): - """Runs a ProcessingJob computing for each example in the input the feature importance. - - Currently, only SHAP is supported as explainability method. + """Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions. Spins up a model endpoint. - For each input example in the 's3_data_input_path' the SHAP algorithm determines - feature importance, by creating 'num_samples' copies of the example with a subset - of features replaced with values from the 'baseline'. - Model inference is run to see how the prediction changes with the replaced features. - If the model output returns multiple scores importance is computed for each of them. - Across examples, feature importance is aggregated using 'agg_method'. + + Currently, only SHAP and Partial Dependence Plots (PDP) are supported + as explainability methods. + + When SHAP is requested in the ``explainability_config``, + the SHAP algorithm calculates the feature importance for each input example + in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`, + by creating ``num_samples`` copies of the example with a subset of features + replaced with values from the ``baseline``. + It then runs model inference to see how the model's prediction changes with the replaced + features. 
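[Editorial illustration, not part of the patch] A sketch of an explainability run that requests both SHAP and PDP in one job; the feature names and ``model_scores`` JSONPath are hypothetical, and ``shap_config`` is the tabular config sketched earlier.

    from sagemaker import clarify

    pdp_config = clarify.PDPConfig(
        features=["Age", "Income"],  # hypothetical feature names
        grid_resolution=15,
        top_k_features=10,
    )

    clarify_processor.run_explainability(
        data_config=data_config,
        model_config=model_config,
        explainability_config=[shap_config, pdp_config],
        model_scores="score",  # JSONPath of the predicted score (hypothetical)
    )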
If the model output returns multiple scores importance is computed for each score. + Across examples, feature importance is aggregated using ``agg_method``. + + When PDP is requested in the ``explainability_config``, + the PDP algorithm calculates the dependence of the target response + on the input features and marginalizes over the values of all other input features. + The Partial Dependence Plots are included in the output + `report `__ + and the corresponding values are included in the analysis output. Args: data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data. model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its endpoint to be created. explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list): - Config of the specific explainability method or a list of ExplainabilityConfig - objects. Currently, SHAP and PDP are the two methods supported. + Config of the specific explainability method or a list of + :class:`~sagemaker.clarify.ExplainabilityConfig` objects. + Currently, SHAP and PDP are the two methods supported. model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`): Index or JSONPath to locate the predicted scores in the model output. This is not required if the model output is a single score. Alternatively, it can be an instance - of ModelPredictedLabelConfig to provide more parameters like label_headers. + of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` + to provide more parameters like ``label_headers``. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. Only meaningful when ``wait`` is True (default: True). - job_name (str): Processing job name. When ``job_name`` is not specified, if - ``job_name_prefix`` in :class:`SageMakerClarifyProcessor` specified, the job name - will be composed of ``job_name_prefix`` and current timestamp; otherwise use - "Clarify-Explainability" as prefix. + job_name (str): Processing job name. When ``job_name`` is not specified, + if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor` + is specified, the job name will be composed of ``job_name_prefix`` and current + timestamp; otherwise use ``"Clarify-Explainability"`` as prefix. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). experiment_config (dict[str, str]): Experiment management configuration. Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``. + The behavior of setting these keys is as follows: - * If `ExperimentName` is supplied but `TrialName` is not a Trial will be - automatically created and the job's Trial Component associated with the Trial. - * If `TrialName` is supplied and the Trial already exists the job's Trial Component - will be associated with the Trial. - * If both `ExperimentName` and `TrialName` are not supplied the trial component - will be unassociated. - * `TrialComponentDisplayName` is used for display in Studio. - """ + + * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be + automatically created and the job's Trial Component associated with the Trial. + * If ``'TrialName'`` is supplied and the Trial already exists, + the job's Trial Component will be associated with the Trial. 
+ * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied, + the Trial Component will be unassociated. + * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. + """ # noqa E501 # pylint: disable=c0301 analysis_config = data_config.get_config() predictor_config = model_config.get_predictor_config() if isinstance(model_scores, ModelPredictedLabelConfig): @@ -1165,20 +1264,21 @@ def run_explainability( def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key): - """Uploads the local analysis_config_file to the s3_output_path. + """Uploads the local ``analysis_config_file`` to the ``s3_output_path``. Args: analysis_config_file (str): File path to the local analysis config file. s3_output_path (str): S3 prefix to store the analysis config file. sagemaker_session (:class:`~sagemaker.session.Session`): - Session object which manages interactions with Amazon SageMaker and - any other AWS services needed. If not specified, the processor creates - one using the default AWS configuration chain. + :class:`~sagemaker.session.Session` object which manages interactions with + Amazon SageMaker and any other AWS services needed. If not specified, + the processor creates a :class:`~sagemaker.session.Session` + using the default AWS configuration chain. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). Returns: - The S3 uri of the uploaded file. + The S3 URI of the uploaded file. """ return s3.S3Uploader.upload( local_path=analysis_config_file, From 24bbbdfbe15c233e917c37292e39242141ac5bb6 Mon Sep 17 00:00:00 2001 From: HappyAmazonian <91216626+HappyAmazonian@users.noreply.github.com> Date: Fri, 8 Jul 2022 15:47:37 -0700 Subject: [PATCH 116/526] fix: fix: neo inferentia as compilation target not using framework ver (#3183) --- src/sagemaker/model.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 60c766379b..fa30e4a27c 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -17,7 +17,6 @@ import json import logging import os -import re import copy from typing import List, Dict @@ -50,6 +49,8 @@ ["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"] ) +NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"] + class ModelBase(abc.ABC): """An object that encapsulates a trained model. 
@@ -763,13 +764,22 @@ def _compilation_job_config( "Framework": framework.upper(), } - multiple_version_supported_framework_list = ["pytorch", "tensorflow"] - if ( - framework.lower() in multiple_version_supported_framework_list - and target_instance_type is not None - and re.match("(?=^ml_)(?!ml_inf)", target_instance_type) is not None - and framework_version is not None + def multi_version_compilation_supported( + target_instance_type: str, framework: str, framework_version: str ): + if target_instance_type and framework and framework_version: + framework = framework.lower() + multi_version_frameworks_support_mapping = { + "inferentia": ["pytorch", "tensorflow", "mxnet"], + "neo_ioc_targets": ["pytorch", "tensorflow"], + } + if target_instance_type in NEO_IOC_TARGET_DEVICES: + return framework in multi_version_frameworks_support_mapping["neo_ioc_targets"] + if target_instance_type == "ml_inf": + return framework in multi_version_frameworks_support_mapping["inferentia"] + return False + + if multi_version_compilation_supported(target_instance_type, framework, framework_version): input_model_config["FrameworkVersion"] = utils.get_short_version(framework_version) role = self.sagemaker_session.expand_role(role) From cf061e5dd54f035125452c3dde27b74c377c57aa Mon Sep 17 00:00:00 2001 From: Md Mizanur Rahman <105268921+mizanfiu@users.noreply.github.com> Date: Fri, 8 Jul 2022 18:29:20 -0700 Subject: [PATCH 117/526] feature: Added support for feature group schema change and feature parameters (#3206) Co-authored-by: Mizanur Rahman --- src/sagemaker/feature_store/feature_group.py | 59 ++++++++++++ src/sagemaker/feature_store/inputs.py | 24 +++++ src/sagemaker/session.py | 68 +++++++++++++- tests/integ/test_feature_store.py | 92 ++++++++++++++++++- .../feature_store/test_feature_store.py | 47 ++++++++++ .../sagemaker/feature_store/test_inputs.py | 6 ++ tests/unit/test_session.py | 53 +++++++++++ 7 files changed, 347 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index fcaaec362c..6e6caa6988 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -53,6 +53,7 @@ OfflineStoreConfig, DataCatalogConfig, FeatureValue, + FeatureParameter, ) logger = logging.getLogger(__name__) @@ -543,6 +544,64 @@ def describe(self, next_token: str = None) -> Dict[str, Any]: feature_group_name=self.name, next_token=next_token ) + def update(self, feature_additions: Sequence[FeatureDefinition]) -> Dict[str, Any]: + """Update a FeatureGroup and add new features from the given feature definitions. + + Args: + feature_additions (Sequence[Dict[str, str]): list of feature definitions to be updated. + + Returns: + Response dict from service. + """ + + return self.sagemaker_session.update_feature_group( + feature_group_name=self.name, + feature_additions=[ + feature_addition.to_dict() for feature_addition in feature_additions + ], + ) + + def update_feature_metadata( + self, + feature_name: str, + description: str = None, + parameter_additions: Sequence[FeatureParameter] = None, + parameter_removals: Sequence[str] = None, + ) -> Dict[str, Any]: + """Update a feature metadata and add/remove metadata. + + Args: + feature_name (str): name of the feature to update. + description (str): description of the feature to update. + parameter_additions (Sequence[Dict[str, str]): list of feature parameter to be added. + parameter_removals (Sequence[str]): list of feature parameter key to be removed. 
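[Editorial illustration, not part of the patch] A sketch of the new FeatureGroup schema-change and feature-metadata APIs added by this commit; the group, feature, and parameter names are hypothetical.

    from sagemaker import Session
    from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition
    from sagemaker.feature_store.feature_group import FeatureGroup
    from sagemaker.feature_store.inputs import FeatureParameter

    feature_group = FeatureGroup(name="my-feature-group", sagemaker_session=Session())

    # Add a new feature definition to the group's schema.
    feature_group.update([FractionalFeatureDefinition(feature_name="new_feature")])

    # Attach a description and key/value parameters to an existing feature.
    feature_group.update_feature_metadata(
        feature_name="feature1",
        description="Customer age in years",
        parameter_additions=[FeatureParameter(key="team", value="data-science")],
        parameter_removals=["stale-key"],
    )

    print(feature_group.describe_feature_metadata(feature_name="feature1"))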
+ + Returns: + Response dict from service. + """ + return self.sagemaker_session.update_feature_metadata( + feature_group_name=self.name, + feature_name=feature_name, + description=description, + parameter_additions=[ + parameter_addition.to_dict() for parameter_addition in (parameter_additions or []) + ], + parameter_removals=(parameter_removals or []), + ) + + def describe_feature_metadata(self, feature_name: str) -> Dict[str, Any]: + """Describe feature metadata by feature name. + + Args: + feature_name (str): name of the feature. + Returns: + Response dict from service. + """ + + return self.sagemaker_session.describe_feature_metadata( + feature_group_name=self.name, feature_name=feature_name + ) + def load_feature_definitions( self, data_frame: DataFrame, diff --git a/src/sagemaker/feature_store/inputs.py b/src/sagemaker/feature_store/inputs.py index 1f31caa4ae..75cb99b5f6 100644 --- a/src/sagemaker/feature_store/inputs.py +++ b/src/sagemaker/feature_store/inputs.py @@ -207,3 +207,27 @@ def to_dict(self) -> Dict[str, Any]: FeatureName=self.feature_name, ValueAsString=self.value_as_string, ) + + +@attr.s +class FeatureParameter(Config): + """FeatureParameter for FeatureStore. + + Attributes: + key (str): key of the parameter. + value (str): value of the parameter. + """ + + key: str = attr.ib(default=None) + value: str = attr.ib(default=None) + + def to_dict(self) -> Dict[str, Any]: + """Construct a dictionary based on the attributes provided. + + Returns: + dict represents the attributes. + """ + return Config.construct_dict( + Key=self.key, + Value=self.value, + ) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index eb158eab3d..f426724b6c 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4085,7 +4085,7 @@ def describe_feature_group( """Describe a FeatureGroup by name in FeatureStore service. Args: - feature_group_name (str): name of the FeatureGroup to descibe. + feature_group_name (str): name of the FeatureGroup to describe. next_token (str): next_token to get next page of features. Returns: Response dict from service. @@ -4095,6 +4095,72 @@ def describe_feature_group( update_args(kwargs, NextToken=next_token) return self.sagemaker_client.describe_feature_group(**kwargs) + def update_feature_group( + self, feature_group_name: str, feature_additions: Sequence[Dict[str, str]] + ) -> Dict[str, Any]: + """Update a FeatureGroup and add new features from the given feature definitions. + + Args: + feature_group_name (str): name of the FeatureGroup to update. + feature_additions (Sequence[Dict[str, str]): list of feature definitions to be updated. + Returns: + Response dict from service. + """ + + return self.sagemaker_client.update_feature_group( + FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions + ) + + def update_feature_metadata( + self, + feature_group_name: str, + feature_name: str, + description: str = None, + parameter_additions: Sequence[Dict[str, str]] = None, + parameter_removals: Sequence[str] = None, + ) -> Dict[str, Any]: + """Update a feature metadata and add/remove metadata. + + Args: + feature_group_name (str): name of the FeatureGroup to update. + feature_name (str): name of the feature to update. + description (str): description of the feature to update. + parameter_additions (Sequence[Dict[str, str]): list of feature parameter to be added. + parameter_removals (Sequence[Dict[str, str]): list of feature parameter to be removed. + Returns: + Response dict from service. 
+ """ + + request = { + "FeatureGroupName": feature_group_name, + "FeatureName": feature_name, + } + + if description is not None: + request["Description"] = description + if parameter_additions is not None: + request["ParameterAdditions"] = parameter_additions + if parameter_removals is not None: + request["ParameterRemovals"] = parameter_removals + + return self.sagemaker_client.update_feature_metadata(**request) + + def describe_feature_metadata( + self, feature_group_name: str, feature_name: str + ) -> Dict[str, Any]: + """Describe feature metadata by feature name in FeatureStore service. + + Args: + feature_group_name (str): name of the FeatureGroup. + feature_name (str): name of the feature. + Returns: + Response dict from service. + """ + + return self.sagemaker_client.describe_feature_metadata( + FeatureGroupName=feature_group_name, FeatureName=feature_name + ) + def put_record( self, feature_group_name: str, diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index 15c1db41ab..73f6cc9104 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -22,8 +22,9 @@ import pytest from pandas import DataFrame +from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition from sagemaker.feature_store.feature_group import FeatureGroup -from sagemaker.feature_store.inputs import FeatureValue +from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter from sagemaker.session import get_execution_role, Session from tests.integ.timeout import timeout @@ -237,6 +238,83 @@ def test_create_feature_store( assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}") +def test_update_feature_group( + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=offline_store_s3_uri, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + + new_feature_name = "new_feature" + new_features = [FractionalFeatureDefinition(feature_name=new_feature_name)] + feature_group.update(new_features) + _wait_for_feature_group_update(feature_group) + feature_definitions = feature_group.describe().get("FeatureDefinitions") + assert any([True for elem in feature_definitions if new_feature_name in elem.values()]) + + +def test_feature_metadata( + feature_store_session, + role, + feature_group_name, + offline_store_s3_uri, + pandas_data_frame, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=offline_store_s3_uri, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + + parameter_additions = [ + FeatureParameter(key="key1", value="value1"), + FeatureParameter(key="key2", value="value2"), + ] + description = "test description" + feature_name = "feature1" + feature_group.update_feature_metadata( + feature_name=feature_name, + description=description, + parameter_additions=parameter_additions, + 
) + describe_feature_metadata = feature_group.describe_feature_metadata( + feature_name=feature_name + ) + print(describe_feature_metadata) + assert description == describe_feature_metadata.get("Description") + assert 2 == len(describe_feature_metadata.get("Parameters")) + + parameter_removals = ["key1"] + feature_group.update_feature_metadata( + feature_name=feature_name, parameter_removals=parameter_removals + ) + describe_feature_metadata = feature_group.describe_feature_metadata( + feature_name=feature_name + ) + assert description == describe_feature_metadata.get("Description") + assert 1 == len(describe_feature_metadata.get("Parameters")) + + def test_ingest_without_string_feature( feature_store_session, role, @@ -304,6 +382,18 @@ def _wait_for_feature_group_create(feature_group: FeatureGroup): print(f"FeatureGroup {feature_group.name} successfully created.") +def _wait_for_feature_group_update(feature_group: FeatureGroup): + status = feature_group.describe().get("LastUpdateStatus").get("Status") + while status == "InProgress": + print("Waiting for Feature Group Update") + time.sleep(5) + status = feature_group.describe().get("LastUpdateStatus").get("Status") + if status != "Successful": + print(feature_group.describe()) + raise RuntimeError(f"Failed to update feature group {feature_group.name}") + print(f"FeatureGroup {feature_group.name} successfully updated.") + + @contextmanager def cleanup_feature_group(feature_group: FeatureGroup): try: diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py index ef6a36980b..8f2f0eb3f9 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_store.py +++ b/tests/unit/sagemaker/feature_store/test_feature_store.py @@ -31,6 +31,7 @@ AthenaQuery, IngestionError, ) +from sagemaker.feature_store.inputs import FeatureParameter class PicklableMock(Mock): @@ -154,6 +155,52 @@ def test_feature_store_describe(sagemaker_session_mock): ) +def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.update(feature_group_dummy_definitions) + sagemaker_session_mock.update_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions], + ) + + +def test_feature_metadata_update(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + + parameter_additions = [FeatureParameter(key="key1", value="value1")] + parameter_removals = ["key2"] + + feature_group.update_feature_metadata( + feature_name="Feature1", + description="TestDescription", + parameter_additions=parameter_additions, + parameter_removals=parameter_removals, + ) + sagemaker_session_mock.update_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_name="Feature1", + description="TestDescription", + parameter_additions=[pa.to_dict() for pa in parameter_additions], + parameter_removals=parameter_removals, + ) + feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription") + sagemaker_session_mock.update_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_name="Feature1", + description="TestDescription", + parameter_additions=[], + parameter_removals=[], + ) + + +def test_feature_metadata_describe(sagemaker_session_mock): + feature_group = 
FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.describe_feature_metadata(feature_name="Feature1") + sagemaker_session_mock.describe_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", feature_name="Feature1" + ) + + def test_put_record(sagemaker_session_mock): feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) feature_group.put_record(record=[]) diff --git a/tests/unit/sagemaker/feature_store/test_inputs.py b/tests/unit/sagemaker/feature_store/test_inputs.py index d111cc0c00..322a049309 100644 --- a/tests/unit/sagemaker/feature_store/test_inputs.py +++ b/tests/unit/sagemaker/feature_store/test_inputs.py @@ -19,6 +19,7 @@ S3StorageConfig, DataCatalogConfig, OfflineStoreConfig, + FeatureParameter, ) @@ -83,3 +84,8 @@ def test_offline_data_store_config(): "DisableGlueTableCreation": False, } ) + + +def test_feature_metadata(): + config = FeatureParameter(key="key", value="value") + assert ordered(config.to_dict()) == ordered({"Key": "key", "Value": "value"}) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index a02ea6eeca..1fd58ea531 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2535,6 +2535,59 @@ def test_feature_group_describe(sagemaker_session): ) +def test_feature_group_update(sagemaker_session, feature_group_dummy_definitions): + sagemaker_session.update_feature_group( + feature_group_name="MyFeatureGroup", + feature_additions=feature_group_dummy_definitions, + ) + assert sagemaker_session.sagemaker_client.update_feature_group.called_with( + FeatureGroupName="MyFeatureGroup", + FeatureAdditions=feature_group_dummy_definitions, + ) + + +def test_feature_metadata_update(sagemaker_session): + parameter_additions = [ + { + "key": "TestKey", + "value": "TestValue", + } + ] + parameter_removals = ["TestKey"] + + sagemaker_session.update_feature_metadata( + feature_group_name="TestFeatureGroup", + feature_name="TestFeature", + description="TestDescription", + parameter_additions=parameter_additions, + parameter_removals=parameter_removals, + ) + assert sagemaker_session.sagemaker_client.update_feature_group.called_with( + feature_group_name="TestFeatureGroup", + FeatureName="TestFeature", + Description="TestDescription", + ParameterAdditions=parameter_additions, + ParameterRemovals=parameter_removals, + ) + sagemaker_session.update_feature_metadata( + feature_group_name="TestFeatureGroup", + feature_name="TestFeature", + ) + assert sagemaker_session.sagemaker_client.update_feature_group.called_with( + feature_group_name="TestFeatureGroup", + FeatureName="TestFeature", + ) + + +def test_feature_metadata_describe(sagemaker_session): + sagemaker_session.describe_feature_metadata( + feature_group_name="MyFeatureGroup", feature_name="TestFeature" + ) + assert sagemaker_session.sagemaker_client.describe_feature_metadata.called_with( + FeatureGroupName="MyFeatureGroup", FeatureName="TestFeature" + ) + + def test_start_query_execution(sagemaker_session): athena_mock = Mock() sagemaker_session.boto_session.client( From e12381ee8059cf24d78f235aa382a28d7b4f16d7 Mon Sep 17 00:00:00 2001 From: jerrypeng7773 <50377760+jerrypeng7773@users.noreply.github.com> Date: Mon, 11 Jul 2022 15:32:41 -0700 Subject: [PATCH 118/526] fix: Fix processing image uri param (#3158) --- src/sagemaker/estimator.py | 9 +++- src/sagemaker/model.py | 2 +- src/sagemaker/processing.py | 6 ++- src/sagemaker/transformer.py | 4 +- src/sagemaker/tuner.py | 4 +- 
src/sagemaker/utils.py | 18 +++++-- src/sagemaker/workflow/__init__.py | 12 +++++ src/sagemaker/workflow/airflow.py | 7 ++- tests/unit/sagemaker/model/test_model.py | 4 +- .../workflow/test_pipeline_session.py | 50 ++++++++++++++++++- .../workflow/test_processing_step.py | 16 ++++-- tests/unit/test_utils.py | 42 ++++++++++++++++ 12 files changed, 156 insertions(+), 18 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 9620ab7408..b6fd68f472 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -105,6 +105,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options" SM_DDP_CUSTOM_MPI_OPTIONS = "sagemaker_distributed_dataparallel_custom_mpi_options" CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz" + JOB_CLASS_NAME = "training-job" def __init__( self, @@ -594,7 +595,9 @@ def _ensure_base_job_name(self): self.base_job_name = ( self.base_job_name or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri) - or base_name_from_image(self.training_image_uri()) + or base_name_from_image( + self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME + ) ) def _get_or_create_name(self, name=None): @@ -1007,7 +1010,9 @@ def fit( def _compilation_job_name(self): """Placeholder docstring""" - base_name = self.base_job_name or base_name_from_image(self.training_image_uri()) + base_name = self.base_job_name or base_name_from_image( + self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME + ) return name_from_base("compilation-" + base_name) def compile_model( diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index fa30e4a27c..8f128fe3f4 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -679,7 +679,7 @@ def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri): self._base_name = ( self._base_name or get_jumpstart_base_name_if_jumpstart_model(script_uri, model_uri) - or utils.base_name_from_image(image_uri) + or utils.base_name_from_image(image_uri, default_base_name=Model.__name__) ) def _set_model_name_if_needed(self): diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 1e4cfae4ff..8da6e04768 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -47,6 +47,8 @@ class Processor(object): """Handles Amazon SageMaker Processing tasks.""" + JOB_CLASS_NAME = "processing-job" + def __init__( self, role: str, @@ -282,7 +284,9 @@ def _generate_current_job_name(self, job_name=None): if self.base_job_name: base_name = self.base_job_name else: - base_name = base_name_from_image(self.image_uri) + base_name = base_name_from_image( + self.image_uri, default_base_name=Processor.JOB_CLASS_NAME + ) return name_from_base(base_name) diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 7bd2f09063..dbe54c8d57 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -28,6 +28,8 @@ class Transformer(object): """A class for handling creating and interacting with Amazon SageMaker transform jobs.""" + JOB_CLASS_NAME = "transform-job" + def __init__( self, model_name: Union[str, PipelineVariable], @@ -243,7 +245,7 @@ def _retrieve_base_name(self): image_uri = self._retrieve_image_uri() if image_uri: - return base_name_from_image(image_uri) + return base_name_from_image(image_uri, default_base_name=Transformer.JOB_CLASS_NAME) return self.model_name diff --git 
a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 76337b8b4f..58c875f8d9 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -373,7 +373,9 @@ def _prepare_job_name_for_tuning(self, job_name=None): estimator = ( self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]] ) - base_name = base_name_from_image(estimator.training_image_uri()) + base_name = base_name_from_image( + estimator.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME + ) jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model( getattr(estimator, "source_dir", None), diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index ed5b3c5e75..2bcfab1bd5 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -33,6 +33,7 @@ from sagemaker import deprecations from sagemaker.session_settings import SessionSettings +from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string ECR_URI_PATTERN = r"^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(.*)(/)(.*:.*)$" @@ -90,18 +91,27 @@ def unique_name_from_base(base, max_length=63): return "{}-{}-{}".format(trimmed, ts, unique) -def base_name_from_image(image): +def base_name_from_image(image, default_base_name=None): """Extract the base name of the image to use as the 'algorithm name' for the job. Args: image (str): Image name. + default_base_name (str): The default base name Returns: str: Algorithm name, as extracted from the image name. """ - m = re.match("^(.+/)?([^:/]+)(:[^:]+)?$", image) - algo_name = m.group(2) if m else image - return algo_name + if is_pipeline_variable(image): + if is_pipeline_parameter_string(image) and image.default_value: + image_str = image.default_value + else: + return default_base_name if default_base_name else "base_name" + else: + image_str = image + + m = re.match("^(.+/)?([^:/]+)(:[^:]+)?$", image_str) + base_name = m.group(2) if m else image_str + return base_name def base_from_name(name): diff --git a/src/sagemaker/workflow/__init__.py b/src/sagemaker/workflow/__init__.py index b4d9e53808..a6961be164 100644 --- a/src/sagemaker/workflow/__init__.py +++ b/src/sagemaker/workflow/__init__.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from sagemaker.workflow.entities import Expression +from sagemaker.workflow.parameters import ParameterString def is_pipeline_variable(var: object) -> bool: @@ -29,3 +30,14 @@ def is_pipeline_variable(var: object) -> bool: # as well as PipelineExperimentConfigProperty and PropertyFile # TODO: We should deprecate the Expression and replace it with PipelineVariable return isinstance(var, Expression) + + +def is_pipeline_parameter_string(var: object) -> bool: + """Check if the variable is a pipeline parameter string + + Args: + var (object): The variable to be verified. + Returns: + bool: True if it is, False otherwise. 
+ """ + return isinstance(var, ParameterString) diff --git a/src/sagemaker/workflow/airflow.py b/src/sagemaker/workflow/airflow.py index 7c78543702..a3565ba9c1 100644 --- a/src/sagemaker/workflow/airflow.py +++ b/src/sagemaker/workflow/airflow.py @@ -20,6 +20,8 @@ from sagemaker import fw_utils, job, utils, s3, session, vpc_utils from sagemaker.amazon import amazon_estimator from sagemaker.tensorflow import TensorFlow +from sagemaker.estimator import EstimatorBase +from sagemaker.processing import Processor def prepare_framework(estimator, s3_operations): @@ -151,7 +153,8 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size= estimator._current_job_name = job_name else: base_name = estimator.base_job_name or utils.base_name_from_image( - estimator.training_image_uri() + estimator.training_image_uri(), + default_base_name=EstimatorBase.JOB_CLASS_NAME, ) estimator._current_job_name = utils.name_from_base(base_name) @@ -1138,7 +1141,7 @@ def processing_config( processor._current_job_name = ( utils.name_from_base(base_name) if base_name is not None - else utils.base_name_from_image(processor.image_uri) + else utils.base_name_from_image(processor.image_uri, Processor.JOB_CLASS_NAME) ) config = { diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index edcbfa7d9f..30f4a20f49 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -287,7 +287,7 @@ def test_create_sagemaker_model_generates_model_name( ) model._create_sagemaker_model(INSTANCE_TYPE) - base_name_from_image.assert_called_with(MODEL_IMAGE) + base_name_from_image.assert_called_with(MODEL_IMAGE, default_base_name="Model") name_from_base.assert_called_with(base_name_from_image.return_value) sagemaker_session.create_model.assert_called_with( @@ -317,7 +317,7 @@ def test_create_sagemaker_model_generates_model_name_each_time( model._create_sagemaker_model(INSTANCE_TYPE) model._create_sagemaker_model(INSTANCE_TYPE) - base_name_from_image.assert_called_once_with(MODEL_IMAGE) + base_name_from_image.assert_called_once_with(MODEL_IMAGE, default_base_name="Model") name_from_base.assert_called_with(base_name_from_image.return_value) assert 2 == name_from_base.call_count diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index 90a9116c07..96e4032a74 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -17,8 +17,16 @@ from mock import Mock, PropertyMock from sagemaker import Model -from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.pipeline_context import PipelineSession +from sagemaker.workflow import is_pipeline_variable, is_pipeline_parameter_string +from sagemaker.workflow.parameters import ( + ParameterString, + ParameterInteger, + ParameterBoolean, + ParameterFloat, +) +from sagemaker.workflow.functions import Join, JsonGet +from tests.unit.sagemaker.workflow.helpers import CustomStep from botocore.config import Config @@ -130,6 +138,46 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock): assert len(register_step_args.need_runtime_repack) == 0 +@pytest.mark.parametrize( + "item", + [ + (ParameterString(name="my-str"), True), + (ParameterBoolean(name="my-bool"), True), + (ParameterFloat(name="my-float"), True), + (ParameterInteger(name="my-int"), True), + (Join(on="/", values=["my", "value"]), True), + 
(JsonGet(step_name="my-step", property_file="pf", json_path="path"), True), + (CustomStep(name="my-step").properties.OutputDataConfig.S3OutputPath, True), + ("my-str", False), + (1, False), + (CustomStep(name="my-ste"), False), + ], +) +def test_is_pipeline_variable(item): + var, assertion = item + assert is_pipeline_variable(var) == assertion + + +@pytest.mark.parametrize( + "item", + [ + (ParameterString(name="my-str"), True), + (ParameterBoolean(name="my-bool"), False), + (ParameterFloat(name="my-float"), False), + (ParameterInteger(name="my-int"), False), + (Join(on="/", values=["my", "value"]), False), + (JsonGet(step_name="my-step", property_file="pf", json_path="path"), False), + (CustomStep(name="my-step").properties.OutputDataConfig.S3OutputPath, False), + ("my-str", False), + (1, False), + (CustomStep(name="my-ste"), False), + ], +) +def test_is_pipeline_parameter_string(item): + var, assertion = item + assert is_pipeline_parameter_string(var) == assertion + + def test_pipeline_session_context_for_model_step_without_instance_types( pipeline_session_mock, ): diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py index e1b02c17d4..f2347bbf11 100644 --- a/tests/unit/sagemaker/workflow/test_processing_step.py +++ b/tests/unit/sagemaker/workflow/test_processing_step.py @@ -336,9 +336,20 @@ def test_processing_step_with_processor(pipeline_session, processing_input): ) -def test_processing_step_with_processor_and_step_args(pipeline_session, processing_input): +@pytest.mark.parametrize( + "image_uri", + [ + IMAGE_URI, + ParameterString(name="MyImage"), + ParameterString(name="MyImage", default_value="my-image-uri"), + Join(on="/", values=["docker", "my-fake-image"]), + ], +) +def test_processing_step_with_processor_and_step_args( + pipeline_session, processing_input, image_uri +): processor = Processor( - image_uri=IMAGE_URI, + image_uri=image_uri, role=ROLE, instance_count=1, instance_type=INSTANCE_TYPE, @@ -346,7 +357,6 @@ def test_processing_step_with_processor_and_step_args(pipeline_session, processi ) step_args = processor.run(inputs=processing_input) - try: ProcessingStep( name="MyProcessingStep", diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 893935542f..4e6ba92730 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -29,6 +29,8 @@ import sagemaker from sagemaker.session_settings import SessionSettings +from tests.unit.sagemaker.workflow.helpers import CustomStep +from sagemaker.workflow.parameters import ParameterString BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission" @@ -82,6 +84,46 @@ def test_name_from_image(base_name_from_image, name_from_base): name_from_base.assert_called_with(base_name_from_image.return_value, max_length=max_length) +@pytest.mark.parametrize( + "inputs", + [ + ( + CustomStep(name="test-custom-step").properties.OutputDataConfig.S3OutputPath, + None, + "base_name", + ), + ( + CustomStep(name="test-custom-step").properties.OutputDataConfig.S3OutputPath, + "whatever", + "whatever", + ), + (ParameterString(name="image_uri"), None, "base_name"), + (ParameterString(name="image_uri"), "whatever", "whatever"), + ( + ParameterString( + name="image_uri", + default_value="922956235488.dkr.ecr.us-west-2.amazonaws.com/analyzer", + ), + None, + "analyzer", + ), + ( + ParameterString( + name="image_uri", + default_value="922956235488.dkr.ecr.us-west-2.amazonaws.com/analyzer", + ), + "whatever", + "analyzer", + ), + ], +) +def 
test_base_name_from_image_with_pipeline_param(inputs): + image, default_base_name, expected = inputs + assert expected == sagemaker.utils.base_name_from_image( + image=image, default_base_name=default_base_name + ) + + @patch("sagemaker.utils.sagemaker_timestamp") def test_name_from_base(sagemaker_timestamp): sagemaker.utils.name_from_base(NAME, short=False) From 8cf8fa021b375e63f8babe4d7116541de38fc20f Mon Sep 17 00:00:00 2001 From: keerthanvasist Date: Mon, 11 Jul 2022 15:33:07 -0700 Subject: [PATCH 119/526] feat: Add target_model to support multi-model endpoints (#3215) --- src/sagemaker/clarify.py | 5 +++++ tests/unit/test_clarify.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index eaf78069c3..24fe1f0a48 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -200,6 +200,7 @@ def __init__( custom_attributes=None, accelerator_type=None, endpoint_name_prefix=None, + target_model=None, ): r"""Initializes a configuration of a model and the endpoint to be created for it. @@ -234,6 +235,9 @@ def __init__( for making inferences to the model. endpoint_name_prefix (str): The endpoint name prefix of a new endpoint. Must follow pattern ``^[a-zA-Z0-9](-\*[a-zA-Z0-9]``. + target_model (str): Sets the target model name when using a multi-model endpoint. For + more information about multi-model endpoints, see + https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html Raises: ValueError: when the ``endpoint_name_prefix`` is invalid, ``accept_type`` is invalid, @@ -281,6 +285,7 @@ def __init__( self.predictor_config["content_template"] = content_template _set(custom_attributes, "custom_attributes", self.predictor_config) _set(accelerator_type, "accelerator_type", self.predictor_config) + _set(target_model, "target_model", self.predictor_config) def get_predictor_config(self): """Returns part of the predictor dictionary of the analysis config.""" diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 0a1d90d74c..1e3ae47f63 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -240,6 +240,7 @@ def test_model_config(): accept_type = "text/csv" content_type = "application/jsonlines" custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4" + target_model = "target_model_name" accelerator_type = "ml.eia1.medium" model_config = ModelConfig( model_name=model_name, @@ -249,6 +250,7 @@ def test_model_config(): content_type=content_type, custom_attributes=custom_attributes, accelerator_type=accelerator_type, + target_model=target_model, ) expected_config = { "model_name": model_name, @@ -258,6 +260,7 @@ def test_model_config(): "content_type": content_type, "custom_attributes": custom_attributes, "accelerator_type": accelerator_type, + "target_model": target_model, } assert expected_config == model_config.get_predictor_config() From d94cf5ee15970811a75c3fcc691b122ac0f8bd90 Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh <105655261+rahven14@users.noreply.github.com> Date: Tue, 12 Jul 2022 22:44:56 +0530 Subject: [PATCH 120/526] fix: make 'ModelInput' field optional for inference recommendation (#3220) * fix: make 'ModelInput' field optional for inference recommendation * fix: refactor code to conditionally update container object --- src/sagemaker/utils.py | 71 ++++++++++++------- .../test_model_create_and_registration.py | 6 -- .../workflow/test_pipeline_session.py | 4 -- .../workflow/test_step_collections.py | 9 --- tests/unit/test_estimator.py | 4 -- 5 files changed, 47 
insertions(+), 47 deletions(-) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 2bcfab1bd5..65de8981aa 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -759,32 +759,55 @@ def update_container_with_inference_params( dict: dict with inference recommender params """ - if ( - framework is not None - and framework_version is not None - and nearest_model_name is not None - and data_input_configuration is not None - ): + if framework is not None and framework_version is not None and nearest_model_name is not None: if container_list is not None: for obj in container_list: - obj.update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } + construct_container_object( + obj, data_input_configuration, framework, framework_version, nearest_model_name ) if container_obj is not None: - container_obj.update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - "ModelInput": { - "DataInputConfig": data_input_configuration, - }, - } + construct_container_object( + container_obj, + data_input_configuration, + framework, + framework_version, + nearest_model_name, ) + + +def construct_container_object( + obj, data_input_configuration, framework, framework_version, nearest_model_name +): + """Function to construct container object. + + Args: + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). + container_obj (dict): object to be updated. + container_list (list): list to be updated. 
+ + Returns: + dict: container object + """ + + obj.update( + { + "Framework": framework, + "FrameworkVersion": framework_version, + "NearestModelName": nearest_model_name, + } + ) + + if data_input_configuration is not None: + obj.update( + { + "ModelInput": { + "DataInputConfig": data_input_configuration, + }, + } + ) diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index d0f617a266..56611fb696 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -99,7 +99,6 @@ def test_conditional_pytorch_training_model_registration( framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" - data_input_configuration = '{"input_1":[1,224,224,3]}' # If image_uri is not provided, the instance_type should not be a pipeline variable # since instance_type is used to retrieve image_uri in compile time (PySDK) @@ -132,7 +131,6 @@ def test_conditional_pytorch_training_model_registration( framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, - data_input_configuration=data_input_configuration, ) model = Model( @@ -219,7 +217,6 @@ def test_mxnet_model_registration( framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" - data_input_configuration = '{"input_1":[1,224,224,3]}' model = MXNetModel( entry_point=entry_point, @@ -244,7 +241,6 @@ def test_mxnet_model_registration( framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, - data_input_configuration=data_input_configuration, ) pipeline = Pipeline( @@ -293,7 +289,6 @@ def test_sklearn_xgboost_sip_model_registration( framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" - data_input_configuration = '{"input_1":[1,224,224,3]}' # The instance_type should not be a pipeline variable # since it is used to retrieve image_uri in compile time (PySDK) @@ -450,7 +445,6 @@ def test_sklearn_xgboost_sip_model_registration( framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, - data_input_configuration=data_input_configuration, ) pipeline = Pipeline( diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index 96e4032a74..13af00cf6a 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -199,7 +199,6 @@ def test_pipeline_session_context_for_model_step_without_instance_types( framework="TENSORFLOW", framework_version="2.9", nearest_model_name="resnet50", - data_input_configuration='{"input_1":[1,224,224,3]}', ) expected_output = { @@ -221,9 +220,6 @@ def test_pipeline_session_context_for_model_step_without_instance_types( "Framework": "TENSORFLOW", "FrameworkVersion": "2.9", "NearestModelName": "resnet50", - "ModelInput": { - "DataInputConfig": '{"input_1":[1,224,224,3]}', - }, } ], "SupportedContentTypes": ["text/csv"], diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 4aa55fd068..fd84bf4b77 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -446,7 +446,6 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): 
framework="TENSORFLOW", framework_version="2.9", nearest_model_name="resnet50", - data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -523,7 +522,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): framework="TENSORFLOW", framework_version="2.9", nearest_model_name="resnet50", - data_input_configuration='{"input_1":[1,224,224,3]}', ) assert ordered(register_model.request_dicts()) == ordered( [ @@ -542,9 +540,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): "Framework": "TENSORFLOW", "FrameworkVersion": "2.9", "NearestModelName": "resnet50", - "ModelInput": { - "DataInputConfig": '{"input_1":[1,224,224,3]}', - }, }, { "Image": "fakeimage2", @@ -553,9 +548,6 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): "Framework": "TENSORFLOW", "FrameworkVersion": "2.9", "NearestModelName": "resnet50", - "ModelInput": { - "DataInputConfig": '{"input_1":[1,224,224,3]}', - }, }, ], "SupportedContentTypes": ["content_type"], @@ -619,7 +611,6 @@ def test_register_model_with_model_repack_with_estimator( framework="TENSORFLOW", framework_version="2.9", nearest_model_name="resnet50", - data_input_configuration='{"input_1":[1,224,224,3]}', ) request_dicts = register_model.request_dicts() diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index d402a509fc..859cdb941f 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -3260,7 +3260,6 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" - data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, @@ -3271,7 +3270,6 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, - data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() @@ -3319,7 +3317,6 @@ def test_register_inference_image(sagemaker_session): framework = "TENSORFLOW" framework_version = "2.9" nearest_model_name = "resnet50" - data_input_config = '{"input_1":[1,224,224,3]}' estimator.register( content_types=content_types, @@ -3333,7 +3330,6 @@ def test_register_inference_image(sagemaker_session): framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, - data_input_configuration=data_input_config, ) sagemaker_session.create_model.assert_not_called() From e91845490aa236f4ee782da76b506f7c1edbd486 Mon Sep 17 00:00:00 2001 From: jerrypeng7773 <50377760+jerrypeng7773@users.noreply.github.com> Date: Tue, 12 Jul 2022 12:23:52 -0700 Subject: [PATCH 121/526] fix: support pipeline variables for spark processors run arguments (#3167) --- src/sagemaker/spark/processing.py | 153 +++++++----- .../workflow/test_processing_step.py | 227 ++++++++++++++++-- 2 files changed, 293 insertions(+), 87 deletions(-) diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index aab9279f78..90f6a3d8ae 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -31,13 +31,20 @@ from io import BytesIO from urllib.parse import urlparse +from typing import Union, List, Dict, Optional + from sagemaker import image_uris from sagemaker.local.image import _ecr_login_if_needed, _pull_image from sagemaker.processing import 
ProcessingInput, ProcessingOutput, ScriptProcessor from sagemaker.s3 import S3Uploader from sagemaker.session import Session +from sagemaker.network import NetworkConfig from sagemaker.spark import defaults +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.functions import Join + logger = logging.getLogger(__name__) @@ -249,6 +256,12 @@ def run( """ self._current_job_name = self._generate_current_job_name(job_name=job_name) + if is_pipeline_variable(submit_app): + raise ValueError( + "submit_app argument has to be a valid S3 URI or local file path " + + "rather than a pipeline variable" + ) + return super().run( submit_app, inputs, @@ -437,9 +450,14 @@ def _stage_submit_deps(self, submit_deps, input_channel_name): use_input_channel = False spark_opt_s3_uris = [] + spark_opt_s3_uris_has_pipeline_var = False with tempfile.TemporaryDirectory() as tmpdir: for dep_path in submit_deps: + if is_pipeline_variable(dep_path): + spark_opt_s3_uris.append(dep_path) + spark_opt_s3_uris_has_pipeline_var = True + continue dep_url = urlparse(dep_path) # S3 URIs are included as-is in the spark-submit argument if dep_url.scheme in ["s3", "s3a"]: @@ -482,11 +500,19 @@ def _stage_submit_deps(self, submit_deps, input_channel_name): destination=f"{self._conf_container_base_path}{input_channel_name}", input_name=input_channel_name, ) - spark_opt = ",".join(spark_opt_s3_uris + [input_channel.destination]) + spark_opt = ( + Join(on=",", values=spark_opt_s3_uris + [input_channel.destination]) + if spark_opt_s3_uris_has_pipeline_var + else ",".join(spark_opt_s3_uris + [input_channel.destination]) + ) # If no local files were uploaded, form the spark-submit option from a list of S3 URIs else: input_channel = None - spark_opt = ",".join(spark_opt_s3_uris) + spark_opt = ( + Join(on=",", values=spark_opt_s3_uris) + if spark_opt_s3_uris_has_pipeline_var + else ",".join(spark_opt_s3_uris) + ) return input_channel, spark_opt @@ -592,6 +618,9 @@ def _validate_s3_uri(self, spark_output_s3_path): Args: spark_output_s3_path (str): The URI of the Spark output S3 Path. """ + if is_pipeline_variable(spark_output_s3_path): + return + if urlparse(spark_output_s3_path).scheme != "s3": raise ValueError( f"Invalid s3 path: {spark_output_s3_path}. 
Please enter something like " @@ -650,22 +679,22 @@ class PySparkProcessor(_SparkProcessorBase): def __init__( self, - role, - instance_type, - instance_count, - framework_version=None, - py_version=None, - container_version=None, - image_uri=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + role: str, + instance_type: Union[int, PipelineVariable], + instance_count: Union[str, PipelineVariable], + framework_version: Optional[str] = None, + py_version: Optional[str] = None, + container_version: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """Initialize an ``PySparkProcessor`` instance. @@ -795,20 +824,20 @@ def get_run_args( def run( self, - submit_app, - submit_py_files=None, - submit_jars=None, - submit_files=None, - inputs=None, - outputs=None, - arguments=None, - wait=True, - logs=True, - job_name=None, - experiment_config=None, - configuration=None, - spark_event_logs_s3_uri=None, - kms_key=None, + submit_app: str, + submit_py_files: Optional[List[Union[str, PipelineVariable]]] = None, + submit_jars: Optional[List[Union[str, PipelineVariable]]] = None, + submit_files: Optional[List[Union[str, PipelineVariable]]] = None, + inputs: Optional[List[ProcessingInput]] = None, + outputs: Optional[List[ProcessingOutput]] = None, + arguments: Optional[List[Union[str, PipelineVariable]]] = None, + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + configuration: Optional[Union[List[Dict], Dict]] = None, + spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None, + kms_key: Optional[str] = None, ): """Runs a processing job. 
@@ -907,22 +936,22 @@ class SparkJarProcessor(_SparkProcessorBase): def __init__( self, - role, - instance_type, - instance_count, - framework_version=None, - py_version=None, - container_version=None, - image_uri=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + role: str, + instance_type: Union[int, PipelineVariable], + instance_count: Union[str, PipelineVariable], + framework_version: Optional[str] = None, + py_version: Optional[str] = None, + container_version: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """Initialize a ``SparkJarProcessor`` instance. @@ -1052,20 +1081,20 @@ def get_run_args( def run( self, - submit_app, - submit_class=None, - submit_jars=None, - submit_files=None, - inputs=None, - outputs=None, - arguments=None, - wait=True, - logs=True, - job_name=None, - experiment_config=None, - configuration=None, - spark_event_logs_s3_uri=None, - kms_key=None, + submit_app: str, + submit_class: Union[str, PipelineVariable], + submit_jars: Optional[List[Union[str, PipelineVariable]]] = None, + submit_files: Optional[List[Union[str, PipelineVariable]]] = None, + inputs: Optional[List[ProcessingInput]] = None, + outputs: Optional[List[ProcessingOutput]] = None, + arguments: Optional[List[Union[str, PipelineVariable]]] = None, + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + configuration: Optional[Union[List[Dict], Dict]] = None, + spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None, + kms_key: Optional[str] = None, ): """Runs a processing job. 
diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py index f2347bbf11..262d0eb558 100644 --- a/tests/unit/sagemaker/workflow/test_processing_step.py +++ b/tests/unit/sagemaker/workflow/test_processing_step.py @@ -46,6 +46,7 @@ from sagemaker.workflow.properties import PropertyFile from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.functions import Join +from sagemaker.workflow import is_pipeline_variable from sagemaker.network import NetworkConfig from sagemaker.pytorch.estimator import PyTorch @@ -149,31 +150,6 @@ ), {}, ), - ( - SparkJarProcessor( - role=ROLE, - framework_version="2.4", - instance_count=1, - instance_type=INSTANCE_TYPE, - ), - { - "submit_app": "s3://my-jar", - "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp", - "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"], - }, - ), - ( - PySparkProcessor( - role=ROLE, - framework_version="2.4", - instance_count=1, - instance_type=INSTANCE_TYPE, - ), - { - "submit_app": "s3://my-jar", - "arguments": ["--input", "input-data-uri", "--output", "output-data-uri"], - }, - ), ] PROCESSING_INPUT = [ @@ -641,3 +617,204 @@ def test_insert_wrong_step_args_into_processing_step(inputs, pipeline_session): assert "The step_args of ProcessingStep must be obtained from processor.run()" in str( error.value ) + + +@pytest.mark.parametrize( + "spark_processor", + [ + ( + SparkJarProcessor( + role=ROLE, + framework_version="2.4", + instance_count=1, + instance_type=INSTANCE_TYPE, + ), + { + "submit_app": "s3://my-jar", + "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp", + "arguments": [ + "--input", + "input-data-uri", + "--output", + ParameterString("MyArgOutput"), + ], + "submit_jars": [ + "s3://my-jar", + ParameterString("MyJars"), + "s3://her-jar", + ParameterString("OurJar"), + ], + "submit_files": [ + "s3://my-files", + ParameterString("MyFiles"), + "s3://her-files", + ParameterString("OurFiles"), + ], + "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"), + }, + ), + ( + PySparkProcessor( + role=ROLE, + framework_version="2.4", + instance_count=1, + instance_type=INSTANCE_TYPE, + ), + { + "submit_app": "s3://my-jar", + "arguments": [ + "--input", + "input-data-uri", + "--output", + ParameterString("MyArgOutput"), + ], + "submit_py_files": [ + "s3://my-py-files", + ParameterString("MyPyFiles"), + "s3://her-pyfiles", + ParameterString("OurPyFiles"), + ], + "submit_jars": [ + "s3://my-jar", + ParameterString("MyJars"), + "s3://her-jar", + ParameterString("OurJar"), + ], + "submit_files": [ + "s3://my-files", + ParameterString("MyFiles"), + "s3://her-files", + ParameterString("OurFiles"), + ], + "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"), + }, + ), + ], +) +def test_spark_processor(spark_processor, processing_input, pipeline_session): + + processor, run_inputs = spark_processor + processor.sagemaker_session = pipeline_session + processor.role = ROLE + + run_inputs["inputs"] = processing_input + + step_args = processor.run(**run_inputs) + step = ProcessingStep( + name="MyProcessingStep", + step_args=step_args, + ) + + step_args = step_args.args + + assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"] + + entry_points = step_args["AppSpecification"]["ContainerEntrypoint"] + entry_points_expr = [] + for entry_point in entry_points: + if is_pipeline_variable(entry_point): + entry_points_expr.append(entry_point.expr) 
+ else: + entry_points_expr.append(entry_point) + + if "submit_py_files" in run_inputs: + expected = [ + "smspark-submit", + "--py-files", + { + "Std:Join": { + "On": ",", + "Values": [ + "s3://my-py-files", + {"Get": "Parameters.MyPyFiles"}, + "s3://her-pyfiles", + {"Get": "Parameters.OurPyFiles"}, + ], + } + }, + "--jars", + { + "Std:Join": { + "On": ",", + "Values": [ + "s3://my-jar", + {"Get": "Parameters.MyJars"}, + "s3://her-jar", + {"Get": "Parameters.OurJar"}, + ], + } + }, + "--files", + { + "Std:Join": { + "On": ",", + "Values": [ + "s3://my-files", + {"Get": "Parameters.MyFiles"}, + "s3://her-files", + {"Get": "Parameters.OurFiles"}, + ], + } + }, + "--local-spark-event-logs-dir", + "/opt/ml/processing/spark-events/", + "/opt/ml/processing/input/code", + ] + # py spark + else: + expected = [ + "smspark-submit", + "--class", + "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp", + "--jars", + { + "Std:Join": { + "On": ",", + "Values": [ + "s3://my-jar", + {"Get": "Parameters.MyJars"}, + "s3://her-jar", + {"Get": "Parameters.OurJar"}, + ], + } + }, + "--files", + { + "Std:Join": { + "On": ",", + "Values": [ + "s3://my-files", + {"Get": "Parameters.MyFiles"}, + "s3://her-files", + {"Get": "Parameters.OurFiles"}, + ], + } + }, + "--local-spark-event-logs-dir", + "/opt/ml/processing/spark-events/", + "/opt/ml/processing/input/code", + ] + + assert entry_points_expr == expected + for output in step_args["ProcessingOutputConfig"]["Outputs"]: + if is_pipeline_variable(output["S3Output"]["S3Uri"]): + output["S3Output"]["S3Uri"] = output["S3Output"]["S3Uri"].expr + + assert step_args["ProcessingOutputConfig"]["Outputs"] == [ + { + "OutputName": "output-1", + "AppManaged": False, + "S3Output": { + "S3Uri": {"Get": "Parameters.MySparkEventLogS3Uri"}, + "LocalPath": "/opt/ml/processing/spark-events/", + "S3UploadMode": "Continuous", + }, + } + ] + + pipeline = Pipeline( + name="MyPipeline", + steps=[step], + sagemaker_session=pipeline_session, + ) + pipeline.definition() From 31c5d65b8964349d93499d1405c1464b1261556a Mon Sep 17 00:00:00 2001 From: Ao Guo <72373287+aoguo64@users.noreply.github.com> Date: Tue, 12 Jul 2022 16:59:27 -0700 Subject: [PATCH 122/526] change: remove primitive_or_expr() from conditions (#3212) Co-authored-by: Ao Guo --- src/sagemaker/workflow/condition_step.py | 36 +--- src/sagemaker/workflow/conditions.py | 28 +-- .../sagemaker/workflow/test_condition_step.py | 159 +++++++++++++++++- .../sagemaker/workflow/test_conditions.py | 32 ++-- 4 files changed, 183 insertions(+), 72 deletions(-) diff --git a/src/sagemaker/workflow/condition_step.py b/src/sagemaker/workflow/condition_step.py index e5797b5b63..624c42fa66 100644 --- a/src/sagemaker/workflow/condition_step.py +++ b/src/sagemaker/workflow/condition_step.py @@ -15,20 +15,17 @@ from typing import List, Union, Optional -import attr from sagemaker.deprecations import deprecated_class from sagemaker.workflow.conditions import Condition from sagemaker.workflow.step_collections import StepCollection +from sagemaker.workflow.functions import JsonGet as NewJsonGet from sagemaker.workflow.steps import ( Step, StepTypeEnum, ) from sagemaker.workflow.utilities import list_to_request -from sagemaker.workflow.entities import ( - RequestType, - PipelineVariable, -) +from sagemaker.workflow.entities import RequestType from sagemaker.workflow.properties import ( Properties, PropertyFile, @@ -93,7 +90,7 @@ def arguments(self) -> RequestType: @property def step_only_arguments(self): """Argument dict pertaining to the step only, and 
not the `if_steps` or `else_steps`.""" - return self.conditions + return [condition.to_request() for condition in self.conditions] @property def properties(self): @@ -101,8 +98,7 @@ def properties(self): return self._properties -@attr.s -class JsonGet(PipelineVariable): # pragma: no cover +class JsonGet(NewJsonGet): # pragma: no cover """Get JSON properties from PropertyFiles. Attributes: @@ -112,28 +108,8 @@ class JsonGet(PipelineVariable): # pragma: no cover json_path (str): The JSON path expression to the requested value. """ - step: Step = attr.ib() - property_file: Union[PropertyFile, str] = attr.ib() - json_path: str = attr.ib() - - @property - def expr(self): - """The expression dict for a `JsonGet` function.""" - if isinstance(self.property_file, PropertyFile): - name = self.property_file.name - else: - name = self.property_file - return { - "Std:JsonGet": { - "PropertyFile": {"Get": f"Steps.{self.step.name}.PropertyFiles.{name}"}, - "Path": self.json_path, - } - } - - @property - def _referenced_steps(self) -> List[str]: - """List of step names that this function depends on.""" - return [self.step.name] + def __init__(self, step: Step, property_file: Union[PropertyFile, str], json_path: str): + super().__init__(step_name=step.name, property_file=property_file, json_path=json_path) JsonGet = deprecated_class(JsonGet, "JsonGet") diff --git a/src/sagemaker/workflow/conditions.py b/src/sagemaker/workflow/conditions.py index 67a3ea5396..40d38e7339 100644 --- a/src/sagemaker/workflow/conditions.py +++ b/src/sagemaker/workflow/conditions.py @@ -20,15 +20,13 @@ import abc from enum import Enum -from typing import Dict, List, Union +from typing import List, Union import attr -from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import ( DefaultEnumMeta, Entity, - Expression, PrimitiveType, RequestType, ) @@ -88,8 +86,8 @@ def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" return { "Type": self.condition_type.value, - "LeftValue": primitive_or_expr(self.left), - "RightValue": primitive_or_expr(self.right), + "LeftValue": self.left, + "RightValue": self.right, } @property @@ -227,8 +225,8 @@ def to_request(self) -> RequestType: """Get the request structure for workflow service calls.""" return { "Type": self.condition_type.value, - "QueryValue": self.value.expr, - "Values": [primitive_or_expr(in_value) for in_value in self.in_values], + "QueryValue": self.value, + "Values": self.in_values, } @property @@ -291,19 +289,3 @@ def _referenced_steps(self) -> List[str]: for condition in self.conditions: steps.extend(condition._referenced_steps) return steps - - -def primitive_or_expr( - value: Union[ExecutionVariable, Expression, PrimitiveType, Parameter, Properties] -) -> Union[Dict[str, str], PrimitiveType]: - """Provide the expression of the value or return value if it is a primitive. - - Args: - value (Union[ConditionValueType, PrimitiveType]): The value to evaluate. - - Returns: - Either the expression of the value or the primitive value. - """ - if is_pipeline_variable(value): - return value.expr - return value diff --git a/tests/unit/sagemaker/workflow/test_condition_step.py b/tests/unit/sagemaker/workflow/test_condition_step.py index f3d6209f23..6478b35b15 100644 --- a/tests/unit/sagemaker/workflow/test_condition_step.py +++ b/tests/unit/sagemaker/workflow/test_condition_step.py @@ -11,13 +11,25 @@ # ANY KIND, either express or implied. 
See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import +import json import pytest from mock import Mock, MagicMock -from sagemaker.workflow.conditions import ConditionEquals -from sagemaker.workflow.parameters import ParameterInteger +from sagemaker.workflow.conditions import ( + ConditionEquals, + ConditionGreaterThan, + ConditionGreaterThanOrEqualTo, + ConditionIn, + ConditionLessThan, + ConditionLessThanOrEqualTo, + ConditionNot, + ConditionOr, +) +from sagemaker.workflow.execution_variables import ExecutionVariables +from sagemaker.workflow.parameters import ParameterInteger, ParameterString from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.pipeline import Pipeline, PipelineGraph +from sagemaker.workflow.properties import Properties from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered @@ -56,7 +68,7 @@ def test_condition_step(): "Conditions": [ { "Type": "Equals", - "LeftValue": {"Get": "Parameters.MyInt"}, + "LeftValue": param, "RightValue": 1, }, ], @@ -79,6 +91,147 @@ def test_condition_step(): assert cond_step.properties.Outcome.expr == {"Get": "Steps.MyConditionStep.Outcome"} +def test_pipeline_condition_step_interpolated(sagemaker_session): + param1 = ParameterInteger(name="MyInt1") + param2 = ParameterInteger(name="MyInt2") + param3 = ParameterString(name="MyStr") + var = ExecutionVariables.START_DATETIME + prop = Properties("foo") + + cond_eq = ConditionEquals(left=param1, right=param2) + cond_gt = ConditionGreaterThan(left=var, right="2020-12-01") + cond_gte = ConditionGreaterThanOrEqualTo(left=var, right=param3) + cond_lt = ConditionLessThan(left=var, right="2020-12-01") + cond_lte = ConditionLessThanOrEqualTo(left=var, right=param3) + cond_in = ConditionIn(value=param3, in_values=["abc", "def"]) + cond_in_mixed = ConditionIn(value=param3, in_values=["abc", prop, var]) + cond_not_eq = ConditionNot(expression=cond_eq) + cond_not_in = ConditionNot(expression=cond_in) + cond_or = ConditionOr(conditions=[cond_gt, cond_in]) + + step1 = CustomStep(name="MyStep1") + step2 = CustomStep(name="MyStep2") + cond_step = ConditionStep( + name="MyConditionStep", + conditions=[ + cond_eq, + cond_gt, + cond_gte, + cond_lt, + cond_lte, + cond_in, + cond_in_mixed, + cond_not_eq, + cond_not_in, + cond_or, + ], + if_steps=[step1], + else_steps=[step2], + ) + + pipeline = Pipeline( + name="MyPipeline", + parameters=[param1, param2, param3], + steps=[cond_step], + sagemaker_session=sagemaker_session, + ) + assert json.loads(pipeline.definition()) == { + "Version": "2020-12-01", + "Metadata": {}, + "Parameters": [ + {"Name": "MyInt1", "Type": "Integer"}, + {"Name": "MyInt2", "Type": "Integer"}, + {"Name": "MyStr", "Type": "String"}, + ], + "PipelineExperimentConfig": { + "ExperimentName": {"Get": "Execution.PipelineName"}, + "TrialName": {"Get": "Execution.PipelineExecutionId"}, + }, + "Steps": [ + { + "Name": "MyConditionStep", + "Type": "Condition", + "Arguments": { + "Conditions": [ + { + "Type": "Equals", + "LeftValue": {"Get": "Parameters.MyInt1"}, + "RightValue": {"Get": "Parameters.MyInt2"}, + }, + { + "Type": "GreaterThan", + "LeftValue": {"Get": "Execution.StartDateTime"}, + "RightValue": "2020-12-01", + }, + { + "Type": "GreaterThanOrEqualTo", + "LeftValue": {"Get": "Execution.StartDateTime"}, + "RightValue": {"Get": "Parameters.MyStr"}, + }, + { + "Type": "LessThan", + "LeftValue": {"Get": "Execution.StartDateTime"}, + "RightValue": "2020-12-01", + }, + { 
+ "Type": "LessThanOrEqualTo", + "LeftValue": {"Get": "Execution.StartDateTime"}, + "RightValue": {"Get": "Parameters.MyStr"}, + }, + { + "Type": "In", + "QueryValue": {"Get": "Parameters.MyStr"}, + "Values": ["abc", "def"], + }, + { + "Type": "In", + "QueryValue": {"Get": "Parameters.MyStr"}, + "Values": [ + "abc", + {"Get": "Steps.foo"}, + {"Get": "Execution.StartDateTime"}, + ], + }, + { + "Type": "Not", + "Expression": { + "Type": "Equals", + "LeftValue": {"Get": "Parameters.MyInt1"}, + "RightValue": {"Get": "Parameters.MyInt2"}, + }, + }, + { + "Type": "Not", + "Expression": { + "Type": "In", + "QueryValue": {"Get": "Parameters.MyStr"}, + "Values": ["abc", "def"], + }, + }, + { + "Type": "Or", + "Conditions": [ + { + "Type": "GreaterThan", + "LeftValue": {"Get": "Execution.StartDateTime"}, + "RightValue": "2020-12-01", + }, + { + "Type": "In", + "QueryValue": {"Get": "Parameters.MyStr"}, + "Values": ["abc", "def"], + }, + ], + }, + ], + "IfSteps": [{"Name": "MyStep1", "Type": "Training", "Arguments": {}}], + "ElseSteps": [{"Name": "MyStep2", "Type": "Training", "Arguments": {}}], + }, + } + ], + } + + def test_pipeline(sagemaker_session): param = ParameterInteger(name="MyInt", default_value=2) cond = ConditionEquals(left=param, right=1) diff --git a/tests/unit/sagemaker/workflow/test_conditions.py b/tests/unit/sagemaker/workflow/test_conditions.py index f5afce9de0..a7ec9c0c11 100644 --- a/tests/unit/sagemaker/workflow/test_conditions.py +++ b/tests/unit/sagemaker/workflow/test_conditions.py @@ -36,7 +36,7 @@ def test_condition_equals(): cond = ConditionEquals(left=param, right=1) assert cond.to_request() == { "Type": "Equals", - "LeftValue": {"Get": "Parameters.MyInt"}, + "LeftValue": param, "RightValue": 1, } @@ -47,8 +47,8 @@ def test_condition_equals_parameter(): cond = ConditionEquals(left=param1, right=param2) assert cond.to_request() == { "Type": "Equals", - "LeftValue": {"Get": "Parameters.MyInt1"}, - "RightValue": {"Get": "Parameters.MyInt2"}, + "LeftValue": param1, + "RightValue": param2, } @@ -57,7 +57,7 @@ def test_condition_greater_than(): cond = ConditionGreaterThan(left=var, right="2020-12-01") assert cond.to_request() == { "Type": "GreaterThan", - "LeftValue": {"Get": "Execution.StartDateTime"}, + "LeftValue": var, "RightValue": "2020-12-01", } @@ -68,8 +68,8 @@ def test_condition_greater_than_or_equal_to(): cond = ConditionGreaterThanOrEqualTo(left=var, right=param) assert cond.to_request() == { "Type": "GreaterThanOrEqualTo", - "LeftValue": {"Get": "Execution.StartDateTime"}, - "RightValue": {"Get": "Parameters.StartDateTime"}, + "LeftValue": var, + "RightValue": param, } @@ -78,7 +78,7 @@ def test_condition_less_than(): cond = ConditionLessThan(left=var, right="2020-12-01") assert cond.to_request() == { "Type": "LessThan", - "LeftValue": {"Get": "Execution.StartDateTime"}, + "LeftValue": var, "RightValue": "2020-12-01", } @@ -89,8 +89,8 @@ def test_condition_less_than_or_equal_to(): cond = ConditionLessThanOrEqualTo(left=var, right=param) assert cond.to_request() == { "Type": "LessThanOrEqualTo", - "LeftValue": {"Get": "Execution.StartDateTime"}, - "RightValue": {"Get": "Parameters.StartDateTime"}, + "LeftValue": var, + "RightValue": param, } @@ -99,7 +99,7 @@ def test_condition_in(): cond_in = ConditionIn(value=param, in_values=["abc", "def"]) assert cond_in.to_request() == { "Type": "In", - "QueryValue": {"Get": "Parameters.MyStr"}, + "QueryValue": param, "Values": ["abc", "def"], } @@ -111,8 +111,8 @@ def test_condition_in_mixed(): cond_in = 
ConditionIn(value=param, in_values=["abc", prop, var]) assert cond_in.to_request() == { "Type": "In", - "QueryValue": {"Get": "Parameters.MyStr"}, - "Values": ["abc", {"Get": "Steps.foo"}, {"Get": "Execution.StartDateTime"}], + "QueryValue": param, + "Values": ["abc", prop, var], } @@ -124,7 +124,7 @@ def test_condition_not(): "Type": "Not", "Expression": { "Type": "Equals", - "LeftValue": {"Get": "Parameters.MyStr"}, + "LeftValue": param, "RightValue": "foo", }, } @@ -138,7 +138,7 @@ def test_condition_not_in(): "Type": "Not", "Expression": { "Type": "In", - "QueryValue": {"Get": "Parameters.MyStr"}, + "QueryValue": param, "Values": ["abc", "def"], }, } @@ -155,12 +155,12 @@ def test_condition_or(): "Conditions": [ { "Type": "GreaterThan", - "LeftValue": {"Get": "Execution.StartDateTime"}, + "LeftValue": var, "RightValue": "2020-12-01", }, { "Type": "In", - "QueryValue": {"Get": "Parameters.MyStr"}, + "QueryValue": param, "Values": ["abc", "def"], }, ], From 623b0b72adbae9f5364cef1fe8ac6988c69d13d8 Mon Sep 17 00:00:00 2001 From: Namrata Madan Date: Thu, 14 Jul 2022 16:51:23 -0700 Subject: [PATCH 123/526] fix: rename RegisterModel inner steps to prevent duplicate step names (#3240) Co-authored-by: Namrata Madan --- src/sagemaker/workflow/step_collections.py | 9 ++-- .../test_model_create_and_registration.py | 8 ++-- .../workflow/test_step_collections.py | 43 ++++++++++++------- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index dd9529916e..270b838164 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -57,6 +57,9 @@ def properties(self): class RegisterModel(StepCollection): # pragma: no cover """Register Model step collection for workflow.""" + _REGISTER_MODEL_NAME_BASE = "RegisterModel" + _REPACK_MODEL_NAME_BASE = "RepackModel" + def __init__( self, name: str, @@ -168,7 +171,7 @@ def __init__( kwargs = dict(**kwargs, output_kms_key=kwargs.pop("model_kms_key", None)) repack_model_step = _RepackModelStep( - name=f"{name}RepackModel", + name="{}-{}".format(self.name, self._REPACK_MODEL_NAME_BASE), depends_on=depends_on, retry_policies=repack_model_step_retry_policies, sagemaker_session=estimator.sagemaker_session, @@ -212,7 +215,7 @@ def __init__( model_name = model_entity.name or model_entity._framework_name repack_model_step = _RepackModelStep( - name=f"{model_name}RepackModel", + name="{}-{}".format(model_name, self._REPACK_MODEL_NAME_BASE), depends_on=depends_on, retry_policies=repack_model_step_retry_policies, sagemaker_session=sagemaker_session, @@ -256,7 +259,7 @@ def __init__( ) register_model_step = _RegisterModelStep( - name=name, + name="{}-{}".format(self.name, self._REGISTER_MODEL_NAME_BASE), estimator=estimator, model_data=model_data, content_types=content_types, diff --git a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py index 56611fb696..1045a8ef0c 100644 --- a/tests/integ/sagemaker/workflow/test_model_create_and_registration.py +++ b/tests/integ/sagemaker/workflow/test_model_create_and_registration.py @@ -123,8 +123,8 @@ def test_conditional_pytorch_training_model_registration( model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts, content_types=["*"], response_types=["*"], - inference_instances=["*"], - transform_instances=["*"], + inference_instances=["ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], 
description="test-description", sample_payload_url=sample_payload_url, task=task, @@ -234,7 +234,7 @@ def test_mxnet_model_registration( content_types=["*"], response_types=["*"], inference_instances=["ml.m5.xlarge"], - transform_instances=["*"], + transform_instances=["ml.m5.xlarge"], description="test-description", sample_payload_url=sample_payload_url, task=task, @@ -670,7 +670,7 @@ def test_model_registration_with_drift_check_baselines( ) continue assert execution_steps[0]["StepStatus"] == "Succeeded" - assert execution_steps[0]["StepName"] == "MyRegisterModelStep" + assert execution_steps[0]["StepName"] == "MyRegisterModelStep-RegisterModel" response = sagemaker_session.sagemaker_client.describe_model_package( ModelPackageName=execution_steps[0]["Metadata"]["RegisterModel"]["Arn"] diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index fd84bf4b77..d3b2a19fe3 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -26,7 +26,7 @@ _REPACK_MODEL_NAME_BASE, ) from sagemaker.workflow.parameters import ParameterString -from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.pipeline_context import PipelineSession from sagemaker.workflow.utilities import list_to_request from tests.unit import DATA_DIR @@ -268,7 +268,7 @@ def test_step_collection_properties(pipeline_session, sagemaker_session): steps = register_model.steps assert len(steps) == 1 assert register_model.properties.ModelPackageName.expr == { - "Get": f"Steps.{register_model_step_name}.ModelPackageName" + "Get": f"Steps.{register_model_step_name}-RegisterModel.ModelPackageName" } # Custom StepCollection @@ -330,10 +330,9 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session): step_list = json.loads(pipeline.definition())["Steps"] assert len(step_list) == 7 for step in step_list: - if step["Name"] not in ["MyStep2", "MyStep3", f"{model_name}RepackModel"]: + if step["Name"] not in ["MyStep2", "MyStep3", f"{model_name}-RepackModel"]: assert "DependsOn" not in step - continue - if step["Name"] == f"{model_name}RepackModel": + elif step["Name"] == f"{model_name}-RepackModel": assert set(step["DependsOn"]) == { "MyStep1", f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", @@ -344,9 +343,21 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session): "MyStep1", f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", - f"{model_name}RepackModel", - register_model_name, + f"{model_name}-RepackModel", + f"{register_model_name}-RegisterModel", } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "MyStep1": ["MyStep2", "MyStep3", "MyModel-RepackModel"], + "MyStep2": [], + "MyStep3": [], + "MyModelStep-RepackModel-MyModel": ["MyModelStep-CreateModel"], + "MyModelStep-CreateModel": ["MyStep2", "MyStep3", "MyModel-RepackModel"], + "MyModel-RepackModel": [], + "RegisterModelStep-RegisterModel": ["MyStep2", "MyStep3"], + } + ) def test_register_model(estimator, model_metrics, drift_check_baselines): @@ -378,7 +389,7 @@ def test_register_model(estimator, model_metrics, drift_check_baselines): assert ordered(register_model.request_dicts()) == ordered( [ { - "Name": "RegisterModelStep", + "Name": "RegisterModelStep-RegisterModel", "Type": 
"RegisterModel", "DependsOn": ["TestStep"], "DisplayName": "RegisterModelStep", @@ -450,7 +461,7 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines): assert ordered(register_model.request_dicts()) == ordered( [ { - "Name": "RegisterModelStep", + "Name": "RegisterModelStep-RegisterModel", "Type": "RegisterModel", "Description": "description", "Arguments": { @@ -526,7 +537,7 @@ def test_register_model_sip(estimator, model_metrics, drift_check_baselines): assert ordered(register_model.request_dicts()) == ordered( [ { - "Name": "RegisterModelStep", + "Name": "RegisterModelStep-RegisterModel", "Type": "RegisterModel", "Description": "description", "DependsOn": ["TestStep"], @@ -618,7 +629,7 @@ def test_register_model_with_model_repack_with_estimator( for request_dict in request_dicts: if request_dict["Type"] == "Training": - assert request_dict["Name"] == "RegisterModelStepRepackModel" + assert request_dict["Name"] == "RegisterModelStep-RepackModel" assert len(request_dict["DependsOn"]) == 1 assert request_dict["DependsOn"][0] == "TestStep" arguments = request_dict["Arguments"] @@ -671,7 +682,7 @@ def test_register_model_with_model_repack_with_estimator( } ) elif request_dict["Type"] == "RegisterModel": - assert request_dict["Name"] == "RegisterModelStep" + assert request_dict["Name"] == "RegisterModelStep-RegisterModel" assert "DependsOn" not in request_dict arguments = request_dict["Arguments"] assert len(arguments["InferenceSpecification"]["Containers"]) == 1 @@ -745,7 +756,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift for request_dict in request_dicts: if request_dict["Type"] == "Training": - assert request_dict["Name"] == "modelNameRepackModel" + assert request_dict["Name"] == "modelName-RepackModel" assert len(request_dict["DependsOn"]) == 1 assert request_dict["DependsOn"][0] == "TestStep" arguments = request_dict["Arguments"] @@ -798,7 +809,7 @@ def test_register_model_with_model_repack_with_model(model, model_metrics, drift } ) elif request_dict["Type"] == "RegisterModel": - assert request_dict["Name"] == "RegisterModelStep" + assert request_dict["Name"] == "RegisterModelStep-RegisterModel" assert "DependsOn" not in request_dict arguments = request_dict["Arguments"] assert len(arguments["InferenceSpecification"]["Containers"]) == 1 @@ -874,7 +885,7 @@ def test_register_model_with_model_repack_with_pipeline_model( for request_dict in request_dicts: if request_dict["Type"] == "Training": - assert request_dict["Name"] == "modelNameRepackModel" + assert request_dict["Name"] == "modelName-RepackModel" assert len(request_dict["DependsOn"]) == 1 assert request_dict["DependsOn"][0] == "TestStep" arguments = request_dict["Arguments"] @@ -927,7 +938,7 @@ def test_register_model_with_model_repack_with_pipeline_model( } ) elif request_dict["Type"] == "RegisterModel": - assert request_dict["Name"] == "RegisterModelStep" + assert request_dict["Name"] == "RegisterModelStep-RegisterModel" assert "DependsOn" not in request_dict arguments = request_dict["Arguments"] assert len(arguments["InferenceSpecification"]["Containers"]) == 1 From ca597fd06f4cf2747342b4bf3646e0fdbfaa7c31 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Fri, 15 Jul 2022 10:55:03 -0700 Subject: [PATCH 124/526] feature: upgrade to support python 3.10 (#3227) --- requirements/extras/test_requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 2247394441..0fa7b16471 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -12,9 +12,10 @@ awslogs==0.14.0 black==22.3.0 stopit==1.1.2 apache-airflow==2.2.4 -apache-airflow-providers-amazon==3.0.0 +apache-airflow-providers-amazon==4.0.0 attrs==20.3.0 fabric==2.6.0 requests==2.27.1 sagemaker-experiments==0.1.35 Jinja2==3.0.3 +pandas>=1.3.5,<1.5 From 0cb83ce8b77f76b5322e04b72d8951579b026472 Mon Sep 17 00:00:00 2001 From: Miyoung Date: Fri, 15 Jul 2022 15:47:54 -0700 Subject: [PATCH 125/526] documentation: SageMaker model parallel library v1.10.0 documentation (#3237) * archive doc for past versions * fix indexing * add new smp cpu memory apis * add new params * add dynamic scale params, add reference * minor fix * minor fixes * rm temp methods * add new checkpoint save/load functions, doc improvement * pass doc8 * Trigger Build * archive doc for past versions * fix indexing * add new smp cpu memory apis * add new params * add dynamic scale params, add reference * minor fix * minor fixes * rm temp methods * add new checkpoint save/load functions, doc improvement * pass doc8 * Trigger Build * remove dist word embedding option Co-authored-by: Shreya Pandit --- .../training/smd_model_parallel_general.rst | 10 + doc/api/training/smp_versions/archives.rst | 1 + doc/api/training/smp_versions/latest.rst | 2 +- .../latest/smd_model_parallel_pytorch.rst | 334 +++++-- ...model_parallel_pytorch_tensor_parallel.rst | 196 ++-- .../v1.9.0/smd_model_parallel_common_api.rst | 538 +++++++++++ .../v1.9.0/smd_model_parallel_pytorch.rst | 677 ++++++++++++++ ...model_parallel_pytorch_tensor_parallel.rst | 876 ++++++++++++++++++ .../v1.9.0/smd_model_parallel_tensorflow.rst | 171 ++++ doc/api/training/smp_versions/v1_9_0.rst | 13 + 10 files changed, 2665 insertions(+), 153 deletions(-) create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst create mode 100644 doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst create mode 100644 doc/api/training/smp_versions/v1_9_0.rst diff --git a/doc/api/training/smd_model_parallel_general.rst b/doc/api/training/smd_model_parallel_general.rst index a35e0d60bc..fbb99f5224 100644 --- a/doc/api/training/smd_model_parallel_general.rst +++ b/doc/api/training/smd_model_parallel_general.rst @@ -178,6 +178,16 @@ PyTorch-specific Parameters - 1 - The number of devices over which the tensor parallel modules will be distributed. If ``tensor_parallel_degree`` is greater than 1, then ``ddp`` must be set to ``True``. + * - ``fp16`` (**smdistributed-modelparallel**>=v1.10) + - bool + - ``False`` + - To run FP16 training, add ``"fp16"'": True`` to the smp configuration. + Other APIs remain the same between FP16 and FP32. + If ``fp16`` is enabled and when user calls ``smp.DistributedModel``, + the model will be wrapped with ``FP16_Module``, which converts the model + to FP16 dtype and deals with forward pass in FP16. + If ``fp16`` is enabled and when user calls ``smp.DistributedOptimizer``, + the optimizer will be wrapped with ``FP16_Optimizer``. 
* - ``fp16_params`` (**smdistributed-modelparallel**>=v1.6) - bool - ``False`` diff --git a/doc/api/training/smp_versions/archives.rst b/doc/api/training/smp_versions/archives.rst index fe893928ef..8c87476e99 100644 --- a/doc/api/training/smp_versions/archives.rst +++ b/doc/api/training/smp_versions/archives.rst @@ -3,6 +3,7 @@ .. toctree:: :maxdepth: 1 + v1_9_0.rst v1_6_0.rst v1_5_0.rst v1_4_0.rst diff --git a/doc/api/training/smp_versions/latest.rst b/doc/api/training/smp_versions/latest.rst index 49085d9347..ee606b8c34 100644 --- a/doc/api/training/smp_versions/latest.rst +++ b/doc/api/training/smp_versions/latest.rst @@ -10,7 +10,7 @@ depending on which version of the library you need to use. To use the library, reference the **Common API** documentation alongside the framework specific API documentation. -Version 1.7.0, 1.8.0, 1.8.1, 1.9.0 (Latest) +Version 1.10.0 (Latest) =========================================== To use the library, reference the Common API documentation alongside the framework specific API documentation. diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst index b05413965c..f6d1db6f21 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch.rst @@ -2,7 +2,7 @@ PyTorch API =========== To use the PyTorch-specific APIs for SageMaker distributed model parallism, -you need to add the following import statement at the top of your training script. +import the ``smdistributed.modelparallel.torch`` package at the top of your training script. .. code:: python @@ -16,24 +16,33 @@ you need to add the following import statement at the top of your training scrip `_ to learn how to use the following API in your PyTorch training script. -.. class:: smp.DistributedModel +.. contents:: Topics + :depth: 1 + :local: + +smdistributed.modelparallel.torch.DistributedModel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. class:: smdistributed.modelparallel.torch.DistributedModel A sub-class of ``torch.nn.Module`` which specifies the model to be partitioned. Accepts a ``torch.nn.Module`` object ``module`` which is the model to be partitioned. The returned ``DistributedModel`` object internally manages model parallelism and data parallelism. Only one model in the training script can be wrapped with - ``smp.DistributedModel``. + ``smdistributed.modelparallel.torch.DistributedModel``. **Example:** .. code:: python + import smdistributed.modelparallel.torch as smp + model = smp.DistributedModel(model) **Important**: The ``__call__`` and  ``backward`` method calls on the - ``smp.DistributedModel`` object (in the following example, the object - is \ ``model``) can only be made inside a ``smp.step``-decorated + ``smdistributed.modelparallel.torch.DistributedModel`` object (in the following example, the object + is \ ``model``) can only be made inside a ``smdistributed.modelparallel.torch.step``-decorated function. Since ``DistributedModel``  is a ``torch.nn.Module``, a forward pass can @@ -78,7 +87,7 @@ you need to add the following import statement at the top of your training scrip In these examples, all ``__call__``  and ``backward`` method calls on the model objects (``model(inputs)`` and ``model.backward(loss)``) must be made inside - a ``smp.step``-decorated function. + a ``smdistributed.modelparallel.torch.step``-decorated function. 
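For illustration, a minimal training step in this pattern could look like the following sketch (the model, optimizer, loss function, and data names are placeholders rather than library APIs): the forward pass and ``model.backward`` run inside the ``smp.step``-decorated function, while the optimizer update is applied outside of it.

.. code:: python

   import torch.nn.functional as F
   import smdistributed.modelparallel.torch as smp

   @smp.step
   def train_step(model, data, target):
       # Executed once per microbatch; the forward and backward calls go
       # through the DistributedModel wrapper.
       output = model(data)
       loss = F.nll_loss(output, target, reduction="mean")
       model.backward(loss)
       return loss

   # Outside the smp.step-decorated function: the returned loss is a
   # StepOutput with one value per microbatch, and the optimizer step is
   # applied once on the accumulated gradients.
   optimizer.zero_grad()
   loss_mb = train_step(model, data, target)
   loss = loss_mb.reduce_mean()
   optimizer.step()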
**Using DDP** @@ -89,7 +98,7 @@ you need to add the following import statement at the top of your training scrip Unlike the original DDP wrapper, when you use ``DistributedModel``, model parameters and buffers are not immediately broadcast across processes when the wrapper is called. Instead, the broadcast is deferred to the first call of the - ``smp.step``-decorated function when the partition is done. + ``smdistributed.modelparallel.torch.step``-decorated function when the partition is done. **Parameters** @@ -160,7 +169,7 @@ you need to add the following import statement at the top of your training scrip - ``partitioned``: Is ``True`` if the model is partitioned, ``False`` otherwise. Initialized to ``False`` when ``DistributedModel`` is first created. It becomes be ``True`` during the first call - to ``smp.step``-decorated function. Once the model is partitioned, the + to ``smdistributed.modelparallel.torch.step``-decorated function. Once the model is partitioned, the local parameters or local ``state_dict`` can be fetched using the following methods. @@ -240,7 +249,7 @@ you need to add the following import statement at the top of your training scrip Registers a callable ``hook`` to be executed after the model is partitioned. This is useful in situations where an operation needs to be executed after the model partition during - the first call to ``smp.step``, but before the actual execution of the + the first call to ``smdistributed.modelparallel.torch.step``, but before the actual execution of the first forward pass. Returns a ``RemovableHandle`` object ``handle``, which can be used to remove the hook by calling ``handle.remove()``. @@ -252,7 +261,7 @@ you need to add the following import statement at the top of your training scrip .. function:: join( ) A context manager to be used in conjunction with an instance of - ``smp.DistributedModel`` to be able to train with uneven inputs across + ``smdistributed.modelparallel.torch.DistributedModel`` to be able to train with uneven inputs across participating processes. This is only supported when ``ddp=True``. This will use the join with the wrapped ``DistributedDataParallel`` instance. For more information, see: `join `__ @@ -276,9 +285,9 @@ you need to add the following import statement at the top of your training scrip `register_comm_hook `__ in the PyTorch documentation. - **Behavior of** ``smp.DistributedModel`` **with Tensor Parallelism** + **Behavior of** ``smdistributed.modelparallel.torch.DistributedModel`` **with Tensor Parallelism** - When a model is wrapped by ``smp.DistributedModel``, the library + When a model is wrapped by ``smdistributed.modelparallel.torch.DistributedModel``, the library immediately traverses the modules of the model object, and replaces the modules that are supported for tensor parallelism with their distributed counterparts. This replacement happens in place. 
If there are no other @@ -293,6 +302,8 @@ you need to add the following import statement at the top of your training scrip # register DistributedSubmodule as the distributed version of Submodule # (note this is a hypothetical example, smp.nn.DistributedSubmodule does not exist) + import smdistributed.modelparallel.torch as smp + smp.tp_register_with_module(Submodule, smp.nn.DistributedSubmodule) class MyModule(nn.Module): @@ -319,20 +330,20 @@ you need to add the following import statement at the top of your training scrip placement of model partitions into GPUs and the initial broadcast of model parameters and buffers across data-parallel ranks take place immediately. This is because it does not need to wait for the model - partition when ``smp.DistributedModel`` wrapper is called. For other + partition when ``smdistributed.modelparallel.torch.DistributedModel`` wrapper is called. For other cases with ``pipeline_parallel_degree`` greater than 1, the broadcast and device placement will be deferred until the first call of an - ``smp.step``-decorated function happens. This is because the first - ``smp.step``-decorated function call is when the model partitioning + ``smdistributed.modelparallel.torch.step``-decorated function happens. This is because the first + ``smdistributed.modelparallel.torch.step``-decorated function call is when the model partitioning happens if pipeline parallelism is enabled. - Because of the module replacement during the ``smp.DistributedModel`` + Because of the module replacement during the ``smdistributed.modelparallel.torch.DistributedModel`` call, any ``load_state_dict`` calls on the model, as well as any direct access to model parameters, such as during the optimizer creation, - should be done **after** the ``smp.DistributedModel`` call. + should be done **after** the ``smdistributed.modelparallel.torch.DistributedModel`` call. Since the broadcast of the model parameters and buffers happens - immediately during ``smp.DistributedModel`` call when the degree of + immediately during ``smdistributed.modelparallel.torch.DistributedModel`` call when the degree of pipeline parallelism is 1, using ``@smp.step`` decorators is not required when tensor parallelism is used by itself (without pipeline parallelism). @@ -340,9 +351,9 @@ you need to add the following import statement at the top of your training scrip For more information about the library's tensor parallelism APIs for PyTorch, see :ref:`smdmp-pytorch-tensor-parallel`. - **Additional Methods of** ``smp.DistributedModel`` **for Tensor Parallelism** + **Additional Methods of** ``smdistributed.modelparallel.torch.DistributedModel`` **for Tensor Parallelism** - The following are the new methods of ``smp.DistributedModel``, in + The following are the new methods of ``smdistributed.modelparallel.torch.DistributedModel``, in addition to the ones listed in the `documentation `__. @@ -383,24 +394,89 @@ you need to add the following import statement at the top of your training scrip - Returns an iterator that runs over ``(name, param)`` tuples, for ``param`` that is allreduced over the ``RDP_GROUP``. +smdistributed.modelparallel.torch.DistributedOptimizer +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. class:: smdistributed.modelparallel.torch.DistributedOptimizer(optimizer, static_loss_scale=1.0, dynamic_loss_scale=False, **dynamic_loss_args) -.. class:: smp.DistributedOptimizer + An optimizer wrapper for saving and loading optimizer states. 
- **Parameters** - - ``optimizer`` + :param optimizer: An optimizer object. + :type optimizer: object + :param static_loss_scale: Effective only for FP16 training. The default value is ``1.0``. + :type static_loss_scale: float + :param dynamic_loss_scale: Effective only for FP16 training. Set to ``True`` to + use dynamic loss scale. The default value is ``False``. + :type dynamic_loss_scale: boolean + :param dynamic_loss_args: Effective only for FP16 training. + If ``dynamic_loss_scale=True``, you can configure additional scale + parameters for dynamic loss scale. + The following list shows available parameters. - An optimizer wrapper for saving/loading optimizer states. This wrapper - returns ``optimizer`` with the following methods overridden: + * ``"init_scale"``: Default is ``2**32`` + * ``"scale_factor"``: Default is ``2.`` + * ``"scale_window"``: Default is ``1000`` + * ``"min_scale"``: Default is ``1`` + * ``"delayed_shift"``: Default is ``1`` + * ``"consecutive_hysteresis"``: Default is ``False`` + :type dynamic_loss_args: dict - .. function:: state_dict( ) + **Example usage of an FP32 Optimizer:** + + .. code:: python + + optimizer = torch.optim.AdaDelta(...) + optimizer = smdistributed.modelparallel.torch.DistributedOptimizer(optimizer) + + **Example usage of an FP16 Optimizer with static loss scale:** + + .. code:: python + + optimizer = torch.optim.AdaDelta(...) + optimizer = smdistributed.modelparallel.torch.DistributedOptimizer( + optimizer, + static_loss_scale=1.0 + ) + + **Example usage of an FP16 Optimizer with dynamic loss scale:** + + .. code:: python + + optimizer = torch.optim.AdaDelta(...) + optimizer = smdistributed.modelparallel.torch.DistributedOptimizer( + optimizer, + static_loss_scale=None, + dynamic_loss_scale=True, + dynamic_loss_args={ + "scale_window": 1000, + "min_scale": 1, + "delayed_shift": 2 + } + ) + + .. tip:: + + After you modify training scripts with + :class:`smdistributed.modelparallel.torch.DistributedModel` and + :class:`smdistributed.modelparallel.torch.DistributedOptimizer`, + use the SageMaker PyTorch estimator's distribution configuration to enable FP16 training. + You simply need to add ``"fp16": True`` to the ``smp_options`` config dictionary's + ``"parameters"`` key as shown in + `Using the SageMaker TensorFlow and PyTorch Estimators + `_. + For more information about available parameters for the ``smp_options`` config, + see :ref:`sm-sdk-modelparallel-general`. + + This wrapper returns an ``optimizer`` object with the following methods overridden: + + .. method:: state_dict( ) Returns the ``state_dict`` that contains optimizer state for the entire model. It first collects the ``local_state_dict`` and gathers and merges - the ``local_state_dict`` from all ``mp_rank``s to create a full + the ``local_state_dict`` from all ``mp_rank``\ s to create a full ``state_dict``. - .. function:: load_state_dict( ) + .. method:: load_state_dict( ) Same as the ``torch.optimizer.load_state_dict()`` , except: @@ -409,7 +485,7 @@ you need to add the following import statement at the top of your training scrip - The actual loading happens after the model partition so that each rank knows its local parameters. - .. function:: local_state_dict( ) + .. method:: local_state_dict( ) Returns the ``state_dict`` that contains the local optimizer state that belongs to the current \ ``mp_rank``. 
This @@ -418,34 +494,77 @@ you need to add the following import statement at the top of your training scrip ``state_dict`` contains elements corresponding to only the current partition, or to the entire model. - ​ -.. function:: smp.partition(index) - :noindex: - **Inputs** +smdistributed.modelparallel.torch Context Managers and Util Functions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. function:: smdistributed.modelparallel.torch.model_creation(tensor_parallelism=False, dtype=None, **tensor_parallel_config) + + Context manager to create a ``torch`` model. This API combines both the + :class:`smdistributed.modelparallel.torch.tensor_parallelism` and + :class:`smdistributed.modelparallel.torch.delay_param_initialization` decorators, + so you can simply use this single context when creating the torch model. + + :param tensor_parallelism: Whether to enable tensor parallelism during model creation. + :type tensor_parallelism: boolean + :param dtype: The dtype to use when creating the model. It has the following rules. + + * If dtype is specified, it will be used during model creation. + * If dtype is not specified, the default dtype will be used during model creation, + which is usually FP32. This is for the best performance on CPU. + * Any model that causes out-of-memory problems with FP32 initialization + is recommended to be created with + :class:`smdistributed.modelparallel.torch.delayed_parameter_initialization`. + * ``FP16_Module`` casts the model back to FP16 if FP16 training is enabled + with the ``smp`` config. For more inforamtion about FP16 training + in SageMaker with the model parallel library, see `FP16 Training + `_ + in the *Amazon SageMaker Developer Guide*. + + :type dtype: ``torch.dtype`` + :param tensor_parallel_config: kwargs to specifiy other tensor parallel configs. + This is not used if ``tensor_parallelism`` is ``False``. + :type tensor_parallel_config: dict + + **Example Usage:** + + .. code:: python + + import smdistributed.modelparallel.torch as smp + + with smp.model_creation( + tensor_parallelism=smp.tp_size() > 1, + dtype=torch.float16 if args.fp16 else torch.get_default_dtype() + ): + model = MyModel(...) + +.. function:: smdistributed.modelparallel.torch.partition(index) - - ``index`` (int) - The index of the partition. + :param index: The index of the partition. + :type index: int A context manager which places all modules defined inside into the partition with ID ``index``.  The ``index`` argument must be less than the number of partitions. - Use ``smp.partition`` to implement manual partitioning. + Use ``smdistributed.modelparallel.torch.partition`` to implement manual partitioning. If ``"auto_partition"`` is ``True``, then the - ``smp.partition`` contexts are ignored. Any module that is not placed in - any ``smp.partition`` context is placed in the + ``smdistributed.modelparallel.torch.partition`` contexts are ignored. Any module that is not placed in + any ``smdistributed.modelparallel.torch.partition`` context is placed in the ``default_partition`` defined through the SageMaker Python SDK. - When ``smp.partition`` contexts are nested, the innermost context + When ``smdistributed.modelparallel.torch.partition`` contexts are nested, the innermost context overrides the rest (see the following example). In PyTorch, manual partitioning should be done inside the module \ ``__init__``, and the partition assignment applies to the modules that are *created* inside - the ``smp.partition`` context. 
+ the ``smdistributed.modelparallel.torch.partition`` context. Example: .. code:: python + import smdistributed.modelparallel.torch as smp + class Model(torch.nn.Module):     def __init__(self):         with smp.partition(1): @@ -455,29 +574,41 @@ you need to add the following import statement at the top of your training scrip             self.child2 = Child2()            # child2 on partition 1         self.child3 = Child3()                # child3 on default_partition -.. function:: smp.get_world_process_group( ) +.. data:: smdistributed.modelparallel.torch.amp.GradScaler + + `Torch AMP Gradscaler `__ + currently doesn’t work with the library. ``smdistributed.modelparallel.torch.amp.GradScaler`` replaces + ``torch.amp.GradScaler`` and provides the same functionality. + +.. function:: smdistributed.modelparallel.torch.delay_param_initialization(enabled=True) + + If enabled, it delays the initialization of parameters + to save CPU memory. That is, parameter initialization takes place + after the model is partitioned on GPUs. + +.. function:: smdistributed.modelparallel.torch.get_world_process_group( ) Returns a ``torch.distributed`` ``ProcessGroup`` that consists of all processes, which can be used with the ``torch.distributed`` API. Requires ``"ddp": True`` in SageMaker Python SDK parameters. -.. function:: smp.get_mp_process_group( ) +.. function:: smdistributed.modelparallel.torch.get_mp_process_group( ) Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the processes in the ``MP_GROUP`` which contains the current process, which can be used with the \ ``torch.distributed`` API. Requires ``"ddp": True`` in SageMaker Python SDK parameters. -.. function:: smp.get_dp_process_group( ) +.. function:: smdistributed.modelparallel.torch.get_dp_process_group( ) Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the processes in the ``DP_GROUP`` which contains the current process, which can be used with the \ ``torch.distributed`` API. Requires ``"ddp": True`` in SageMaker Python SDK parameters. -.. function:: smp.is_initialized( ) +.. function:: smdistributed.modelparallel.torch.is_initialized( ) - Returns ``True`` if ``smp.init`` has already been called for the + Returns ``True`` if ``smdistributed.modelparallel.torch.init`` has already been called for the process, and ``False`` otherwise. .. function::smp.is_tracing( ) @@ -485,43 +616,38 @@ you need to add the following import statement at the top of your training scrip Returns ``True`` if the current process is running the tracing step, and ``False`` otherwise. -.. data:: smp.nn.FusedLayerNorm +.. data:: smdistributed.modelparallel.torch.nn.FusedLayerNorm `Apex Fused Layer Norm `__ is currently not - supported by the library. ``smp.nn.FusedLayerNorm`` replaces ``apex`` + supported by the library. ``smdistributed.modelparallel.torch.nn.FusedLayerNorm`` replaces ``apex`` ``FusedLayerNorm`` and provides the same functionality. This requires ``apex`` to be installed on the system. -.. data:: smp.optimizers.FusedNovoGrad +.. data:: smdistributed.modelparallel.torch.optimizers.FusedNovoGrad `Fused Novo Grad optimizer `__ is - currently not supported by the library. ``smp.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad`` + currently not supported by the library. ``smdistributed.modelparallel.torch.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad`` optimizer and provides the same functionality. This requires ``apex`` to be installed on the system. -.. data:: smp.optimizers.FusedLamb +.. 
data:: smdistributed.modelparallel.torch.optimizers.FusedLamb `FusedLamb optimizer `__ - currently doesn’t work with the library. ``smp.optimizers.FusedLamb`` replaces + currently doesn’t work with the library. ``smdistributed.modelparallel.torch.optimizers.FusedLamb`` replaces ``apex`` ``FusedLamb`` optimizer and provides the same functionality. This requires ``apex`` to be installed on the system. -.. data:: smp.amp.GradScaler - - `Torch AMP Gradscaler `__ - currently doesn’t work with the library. ``smp.amp.GradScaler`` replaces - ``torch.amp.GradScaler`` and provides the same functionality. - .. _pytorch_saving_loading: -APIs for Saving and Loading -^^^^^^^^^^^^^^^^^^^^^^^^^^^ +smdistributed.modelparallel.torch APIs for Saving and Loading +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. function:: smp.save( ) +.. function:: smdistributed.modelparallel.torch.save(obj, f, partial=True, pickel_module=picklemodule, pickle_protocol=2, ) - Saves an object. This operation is similar to ``torch.save()``, except + Saves an object. This operation is similar to `torch.save() + `_, except that it has an additional keyword argument, ``partial``, and accepts only string type for the argument ``f`` (file). If ``partial=True``, each ``mp_rank`` saves a separate checkpoint file and the library adds an ``mp_rank`` @@ -534,16 +660,16 @@ APIs for Saving and Loading - ``partial`` (bool, default= ``True``):  When set to ``True``, each ``mp_rank`` saves a separate checkpoint file and the library adds an ``mp_rank`` index to the saved file. If you want to be able to load - and further train a model that you save with ``smp.save()``, you must + and further train a model that you save with ``smdistributed.modelparallel.torch.save()``, you must set ``partial=True``. - ``pickle_module`` (picklemodule, default = module ``"pickle"`` from ``"/opt/conda/lib/python3.6/pickle.py"``): A module used for pickling metadata and objects. - ``pickle_protocol``  (int, default=2): Can be specified to override the defaultprotocol. -.. function:: smp.load( ) +.. function:: smdistributed.modelparallel.torch.load(f, map_location, pickle_module, pickle_load_args, partial=True) - Loads an object saved with ``smp.save()`` from a file. + Loads an object saved with ``smdistributed.modelparallel.torch.save()`` from a file. Similar to, `torch.load() `__, except it has an additional keyword argument, ``partial``, and accepts @@ -565,10 +691,83 @@ APIs for Saving and Loading ``mp_rank`` loads the checkpoint corresponding to the ``mp_rank``. Should be used when loading a model trained with the library. +.. function:: smdistributed.modelparallel.torch.save_checkpoint(path, tag, partial=True, model=None, optimizer=None, user_content=None, translate_if_full=True, num_kept_partial_checkpoints=None) + + Saves a checkpoint. While :class:`smdistributed.modelparallel.torch.save` saves + model and optimizer objects, + this function checkpoints model and optimizer and saves the checkpoints as separate files. + It creates checkpoint folders in the following structure. + + .. code:: text + + - path + - ${tag}_partial (folder for partial checkpoint) + - model_rankinfo.pt + - optimizer_rankinfo.pt + - fp16_states_rankinfo.pt + - user_content.pt + - $tag (checkpoint file for full checkpoint) + - user_content_$tag (user_content file for full checkpoint) + - newest (a file that indicates the newest checkpoint) + + **Parameters** + + * ``path`` (str) (required): Path to save the checkpoint. 
The library creates + the directory if it does not already exist. + For example, ``/opt/ml/checkpoint/model_parallel``. + * ``tag`` (str) (required): A tag for the current checkpoint, usually the train + steps. Note: tag needs to be the same across all ranks (GPU workers). + When ``partial=False`` this will be the checkpoint file name. + * ``partial`` (boolean) (default: True): Whether to save the partial checkpoint. + * ``model`` (:class:`smdistributed.modelparallel.torch.DistributedModel`) + (default: None): The model to save. It needs to an ``smp.DistributedModel`` object. + * ``optimizer`` (:class:`smdistributed.modelparallel.torch.DistributedOptimizer`) + (default: None): The optimizer to save. It needs to be an ``smp.DistributedOptimizer`` object. + * ``user_content`` (any) (default: None): User-defined content to save. + * ``translate_if_full`` (boolean) (default: True): Whether to translate the + full ``state_dict`` to HF ``state_dict`` if possible. + * ``num_kept_partial_checkpoints`` (int) (default: None): The maximum number + of partial checkpoints to keep on disk. + +.. function:: smdistributed.modelparallel.torch.resume_from_checkpoint(path, tag=None, partial=True, strict=True, load_optimizer_states=True, translate_function=None) + + While :class:`smdistributed.modelparallel.torch.load` loads saved + model and optimizer objects, this function resumes from a saved checkpoint file. + + **Parameters** + + * ``path`` (str) (required): Path to load the checkpoint. + * ``tag`` (str) (default: None): Tag of the checkpoint to resume. If not provided, + the library tries to locate the newest checkpoint from the saved newest file. + * ``partial`` (boolean) (default: True): Whether to load the partial checkpoint. + * ``strict`` (boolean) (default: True): Load with strict load, no extra key or + missing key is allowed. + * ``load_optimizer_states`` (boolean) (default: True): Whether to load ``optimizer_states``. + * ``translate_function`` (function) (default: None): function to translate the full + checkpoint into smdistributed.modelparallel format. + For supported models, this is not required. + + **Example usage** + + .. code:: python + + # Save + smp.save_checkpoint( + checkpoint_dir, + tag=f"total_steps{total_steps}", + partial=True, + model=model, + optimizer=optimizer, + user_content=user_content + num_kept_partial_checkpoints=args.num_kept_checkpoints) + + # Load: this will automatically load the newest checkpoint + user_content = smp.resume_from_checkpoint(path, partial=partial) + .. _pytorch_saving_loading_instructions: -General Instruction For Saving and Loading -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +General instruction on saving and loading +----------------------------------------- The library can save partial or full checkpoints. @@ -577,13 +776,13 @@ The library can save partial or full checkpoints. - For full checkpoints, the library saves a single checkpoint that contains entire model parameters. -When **saving** using ``smp.save()``, each rank only holds its own +When **saving** using ``smdistributed.modelparallel.torch.save()``, each rank only holds its own parameters. If you want to save the full model, there will be some communication between the ranks to create the full model. If you save checkpoints often, you should save partial checkpoints for best performance. 
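To make this recommendation concrete, the following sketch (the checkpoint paths are placeholders) saves frequent checkpoints as partial checkpoints and assembles a full, single-file checkpoint only when it is actually needed; the combined save-and-load example below shows the same calls in a full training flow.

.. code:: python

   import smdistributed.modelparallel.torch as smp

   # Frequent checkpoint: partial save. Each mp_rank keeps only its own
   # parameters, so no cross-rank communication is needed. Save from a
   # single data-parallel rank to avoid duplicate writes.
   if smp.dp_rank() == 0:
       smp.save(
           {"model_state_dict": model.local_state_dict(),
            "optimizer_state_dict": optimizer.local_state_dict()},
           "/opt/ml/checkpoints/model_partial.pt",
           partial=True,
       )

   # Occasional full checkpoint: gathering the complete state dict requires
   # communication between ranks, so use it sparingly, for example only at
   # the end of training.
   if smp.dp_rank() == 0:
       smp.save(
           {"model_state_dict": model.state_dict()},
           "/opt/ml/checkpoints/model_full.pt",
           partial=False,
       )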
-When **loading** using ``smp.load()``, the library can load either partial or | +When **loading** using ``smdistributed.modelparallel.torch.load()``, the library can load either partial or | full checkpoints or full checkpoints saved by a non-model-parallel model. If you want to resume training with a non-model-parallel model or do inference, you need a full checkpoint. @@ -592,6 +791,7 @@ The following is an example of how you can save and load a checkpoint: .. code:: python + import smdistributed.modelparallel.torch as smp # Original model and optimizer model = MyModel(...) optimizer = MyOpt(...) diff --git a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst index e0ea1ba6c8..de7d20aaa2 100644 --- a/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst +++ b/doc/api/training/smp_versions/latest/smd_model_parallel_pytorch_tensor_parallel.rst @@ -17,7 +17,7 @@ place on a best-effort basis for those module supported for tensor parallelism. Alternatively, you can directly import and use the library’s distributed modules in the model definition. -Some of the supported modules (such as ``smp.nn.Transformer``) are high-level +Some of the supported modules (such as ``smdistributed.modelparallel.torch.nn.Transformer``) are high-level blocks that contain many operations. Because custom implementations (as opposed to the built-in PyTorch modules) are typically used for these high-level blocks, the library offers an API that you can use to register @@ -47,20 +47,20 @@ use is functionally equivalent to the distributed module. You can verify this by taking a look at the equivalent reference implementations in the :ref:`smdmp-tp-appendix`. These implementations are functionally equivalent to their distributed -versions in ``smp.nn`` module. +versions in ``smdistributed.modelparallel.torch.nn`` module. -.. decorator:: @smp.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) +.. class:: smdistributed.modelparallel.torch.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None) - - A class decorator that registers the ``dist_module`` class with + - A decorator class that registers the ``dist_module`` class with the module class that it is attached to. The hooks can be used to adapt to different interfaces used with ``__init__`` and ``forward`` methods. - **Arguments:** - - ``dist_module``: A subclass of ``smp.nn.DistributedModule`` + - ``dist_module``: A subclass of ``smdistributed.modelparallel.torch.nn.DistributedModule`` that implements the distributed version of the module class the decorator is attached to. Any distributed module class defined - in ``smp.nn`` module can be used. + in ``smdistributed.modelparallel.torch.nn`` module can be used. - ``init_hook``: A callable that translates the arguments of the original module ``__init__`` method to an ``(args, kwargs)`` tuple compatible with the arguments of the corresponding @@ -89,6 +89,8 @@ versions in ``smp.nn`` module. .. code:: python + import smdistributed.modelparallel.torch as smp + init_hook = lambda config: ((), config.to_dict()) # register smp.nn.DistributedTransformer @@ -101,7 +103,7 @@ versions in ``smp.nn`` module. def forward(self, hidden_states, attention_mask): ... -.. function:: smp.tp_register_with_module(module_cls, dist_module, init_hook=None, forward_hook=None, return_hook=None) +.. 
function:: smdistributed.modelparallel.torch.tp_register_with_module(module_cls, dist_module, init_hook=None, forward_hook=None, return_hook=None) - When you do not have direct access to model definition code, you can use this API to similarly register a distributed module with @@ -111,10 +113,10 @@ versions in ``smp.nn`` module. - ``module_cls``: The existing module class that will be distributed. - - ``dist_module``: A subclass of ``smp.nn.DistributedModule`` + - ``dist_module``: A subclass of ``smdistributed.modelparallel.torch.nn.DistributedModule`` that implements the distributed version of the module class the decorator is attached to. Any distributed module class defined - in ``smp.nn`` module can be used. + in ``smdistributed.modelparallel.torch.nn`` module can be used. - ``init_hook``: A callable that translates the arguments of the original module ``__init__`` method to an ``(args, kwargs)`` tuple compatible with the arguments of the corresponding @@ -143,6 +145,8 @@ versions in ``smp.nn`` module. .. code:: python + import smdistributed.modelparallel.torch as smp + from somelibrary import MyTransformer init_hook = lambda config: ((), config.to_dict()) @@ -157,16 +161,7 @@ versions in ``smp.nn`` module. Supported Modules for Tensor Parallelism ---------------------------------------- -The following modules are supported for tensor -parallelism. - -- ``smp.nn.DistributedLinear`` (implements ``nn.Linear``) -- ``smp.nn.DistributedTransformerLMHead`` -- ``smp.nn.DistributedTransformer`` -- ``smp.nn.DistributedTransformerLayer`` -- ``smp.nn.DistributedAttentionLayer`` -- ``smp.nn.DistributedTransformerOutputLayer`` -- ``smp.nn.DistributedEmbedding`` +The following modules are supported for tensor parallelism. .. contents:: Topics :depth: 3 @@ -177,37 +172,51 @@ parallelism. Tensor Parallelism Module APIs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. class:: smp.nn.DistributedLinear(in_features, out_features) +- :class:`smdistributed.modelparallel.torch.nn.DistributedLinear` (implements ``nn.Linear``) +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead` +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedAttentionLayer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer` +- :class:`smdistributed.modelparallel.torch.nn.DistributedEmbedding` - - Tensor-parallel implementation of the ``nn.Linear`` class. - Functionally equivalent to an ``nn.Linear`` module with the same - ``in_features`` and ``out_features``. In other words, - ``in_features`` and ``out_features`` are the number of *global* - channels across tensor-parallel ranks. - - **Arguments:** +.. class:: smdistributed.modelparallel.torch.nn.DistributedLinear(in_features, out_features) + + Tensor-parallel implementation of the ``nn.Linear`` class. + Functionally equivalent to an ``nn.Linear`` module with the same + ``in_features`` and ``out_features``. In other words, + ``in_features`` and ``out_features`` are the number of *global* + channels across tensor-parallel ranks. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. + + + - **Arguments:** - ``in_features``: The total number of input channels for the linear layer across all tensor-parallel ranks. - ``out_features``: The total number of output channels for the linear layer across all tensor-parallel ranks. -.. 
class:: smp.nn.DistributedTransformerLMHead(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, vocab_size=30522, num_positions=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, num_token_types=0, causal_mask_size=None, add_cross_attention=False, add_lm_head=True, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True) - - - Constructs a distributed transformer model, including embeddings - and a single LM head. A word embedding of size - ``(vocab_size, hidden_size)`` is created, as well as a positional - embedding of size ``(num_positions, hidden_size)``, and the - embeddings are added together. If ``num_token_types`` is larger - than 0, a separate embedding of size - ``(num_token_types, hidden_size)`` is created, and further added - on top. - - The embeddings are fed through a ``DistributedTransformer``, and - if ``add_lm_head`` is ``True``, the output passes through a single - LM head, which is a linear module without bias whose weight is - tied to the word embeddings. - - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the rest - of the arguments. - - **Methods:** +.. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLMHead(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, vocab_size=30522, num_positions=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, num_token_types=0, causal_mask_size=None, add_cross_attention=False, add_lm_head=True, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True) + + Constructs a distributed transformer model, including embeddings + and a single LM head. A word embedding of size + ``(vocab_size, hidden_size)`` is created, as well as a positional + embedding of size ``(num_positions, hidden_size)``, and the + embeddings are added together. If ``num_token_types`` is larger + than 0, a separate embedding of size + ``(num_token_types, hidden_size)`` is created, and further added + on top. + + - The embeddings are fed through a ``DistributedTransformer``, and + if ``add_lm_head`` is ``True``, the output passes through a single + LM head, which is a linear module without bias whose weight is + tied to the word embeddings. + - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the rest + of the arguments. + - **Methods:** - ``forward(self, inputs)`` @@ -223,22 +232,27 @@ Tensor Parallelism Module APIs - ``attention_mask`` is assumed to be a 0-1 tensor of shape ``[N, S]``, where 1 represents a masked position. -.. class:: smp.nn.DistributedTransformer(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) +.. 
class:: smdistributed.modelparallel.torch.nn.DistributedTransformer(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) - - A sequence of ``smp.nn.DistributedTransformerLayer``\ s, whose - number is given by ``num_layers`` argument. For the other - arguments and methods, refer to - ``smp.nn.DistributedTransformerLayer``. - - If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, - layer normalization is applied to both the input and the output of - the ``DistributedTransformer``, in addition to the intermediate - attention and transformer-output layers. + A sequence of :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`\ s, whose + number is given by ``num_layers`` argument. For the other + arguments and methods, refer to + :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`. -.. class:: smp.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) + If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, + layer normalization is applied to both the input and the output of + the ``DistributedTransformer``, in addition to the intermediate + attention and transformer-output layers. + +.. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) + + Tensor-parallel implementation of a single transformer layer. + Number of attention heads, hidden size, and intermediate size + refer to the global quantities across all tensor-parallel ranks. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. - - Tensor-parallel implementation of a single transformer layer. - Number of attention heads, hidden size, and intermediate size - refer to the global quantities across all tensor-parallel ranks. - **Arguments:** - ``num_attention_heads``: The total number of attention heads @@ -336,15 +350,19 @@ Tensor Parallelism Module APIs and the next three tensors are the same as the input arguments. -.. class:: smp.nn.DistributedAttentionLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, cross_attention=False, causal_mask_size=None, pre_layernorm=False, post_layernorm=True) +.. 
class:: smdistributed.modelparallel.torch.nn.DistributedAttentionLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, cross_attention=False, causal_mask_size=None, pre_layernorm=False, post_layernorm=True) + + A distributed implementation for the attention block. Includes the + computation of the self- or cross-attention (context layer), + followed by a linear mapping and dropout, which is optionally + followed by the residual-connection and layer normalization. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. - - A distributed implementation for the attention block. Includes the - computation of the self- or cross-attention (context layer), - followed by a linear mapping and dropout, which is optionally - followed by the residual-connection and layer normalization. - **Arguments:** - - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the arguments. - ``cross_attention``: If ``True``, it computes the attentions with respect to the ``cross_states`` tensor of the ``forward`` @@ -383,30 +401,34 @@ Tensor Parallelism Module APIs - A single tensor that is the output of the attention layer. -.. class:: smp.nn.DistributedTransformerOutputLayer(hidden_size=1024, intermediate_size=4096, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True, fp32_residual_addition=False) +.. class:: smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer(hidden_size=1024, intermediate_size=4096, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True, fp32_residual_addition=False) - Distributed implementation of a single transformer output layer. A - single :class:`smp.nn.DistributedTransformerLayer` with + single :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` with ``add_cross_attention=False`` consists of a single ``DistributedAttentionLayer`` immediately followed by a single ``DistributedTransformerOutputLayer``. The latter linearly maps the last channel of the input tensor from ``hidden_size`` to ``intermediate_size``, and then maps it back to ``hidden_size``. + + For more information about what's the reference implementation of this module, + see :ref:`smdmp-tp-appendix`. + - **Arguments:** - - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + - See :class:`smdistributed.modelparallel.torch.nn.DistributedTransformerLayer` for descriptions of the arguments. - ``fp32_residual_addition``: Set to ``True`` if you want to avoid overflow (NaN loss values) for large models with more than 100 billion parameters when using FP16. (Default: False) -.. class:: smp.nn.DistributedEmbedding(num_embeddings,embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, initializer_range=0.02, _skip_allgather=False,_skip_scatter_and_merge=False,) +.. 
class:: smdistributed.modelparallel.torch.nn.DistributedEmbedding(num_embeddings,embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, initializer_range=0.02, _skip_allgather=False,_skip_scatter_and_merge=False,) - Distributed implementation of a single Embedding Layer. Currently only supports splitting across the embedding_dim. - **Arguments:** - - See :class:`smp.nn.DistributedEmbedding` for descriptions of the + - See :class:`smdistributed.modelparallel.torch.nn.DistributedEmbedding` for descriptions of the arguments. .. _enabling-tp: @@ -417,7 +439,7 @@ Enabling Tensor Parallelism There are two ways tensor parallelism can be enabled. First, you can use -the distributed module implementations in ``smp.nn`` module directly in +the distributed module implementations in ``smdistributed.modelparallel.torch.nn`` module directly in your model definition. See :ref:`smdmp-supported-modules-for-tp` for a complete list of built-in distributed modules. Here is an example of how this can be done: @@ -446,7 +468,7 @@ of code, which will automatically enable tensor parallelism for the supported modules within that scope. To do this, you can use the following API: -.. decorator:: smp.tensor_parallelism(enabled=True, **kwargs) +.. decorator:: smdistributed.modelparallel.torch.tensor_parallelism(enabled=True, **kwargs) - A context manager that enables or disables tensor parallelism for any supported module that is created inside. If there are nested @@ -463,6 +485,8 @@ following API: .. code:: python + import smdistributed.modelparallel.torch as smp + with smp.tensor_parallelism(): self.m0 = nn.Linear(20, 20) # will be distributed with smp.tensor_parallelism(enabled=False): @@ -472,7 +496,7 @@ following API: the distributed modules created inside the context. If a keyword argument provided through it matches any ``__init__`` method arguments of a ``DistributedModule`` that substitutes a module created inside - the ``smp.tensor_parallelism`` context, this keyword will override + the ``smdistributed.modelparallel.torch.tensor_parallelism`` context, this keyword will override the value defined in the ``init_hook``. - (*For v1.7.0 and later*) Through the following additional keyword arguments, @@ -481,21 +505,21 @@ following API: - ``fused_softmax`` (bool) - Fusion of attention masking and softmax. By default, it is set to ``True``. You can deactivate it by setting - ``fused_softmax=False`` in the ``smp.tensor_parallelism`` context manager. + ``fused_softmax=False`` in the ``smdistributed.modelparallel.torch.tensor_parallelism`` context manager. - ``fused_bias_gelu`` (bool) - Fusion of bias addition and Gelu activation. By default, it is set to ``False``. You can activate it by setting - ``fused_bias_gelu=True`` in the ``smp.tensor_parallelism`` context manager. + ``fused_bias_gelu=True`` in the ``smdistributed.modelparallel.torch.tensor_parallelism`` context manager. -.. function:: smp.set_tensor_parallelism(module, enabled=True, **kwargs) +.. function:: smdistributed.modelparallel.torch.set_tensor_parallelism(module, enabled=True, **kwargs) - Enables or disables tensor parallelism for the supported submodules of ``module``. If enabling, the outermost supported modules will be distributed. If disabling, tensor parallelism will be disabled for the entire module subtree of ``module``. 
Unlike the context manager, this API can be used after the model creation - (but before wrapping with :class:`smp.DistributedModel`), so direct + (but before wrapping with :class:`smdistributed.modelparallel.torch.DistributedModel`), so direct access to model definition code is not required. If a supported module shares weights with another (supported or unsupported) module, or if its hyperparameters do not support distribution @@ -504,14 +528,16 @@ following API: - Keyword arguments ``kwargs`` can be used to modify the configurations of the distributed modules created inside the context. If a keyword argument provided here matches any - ``__init__`` method arguments of a :class:`smp.DistributedModel` that - substitutes a module created inside the ``smp.tensor_parallelism`` + ``__init__`` method arguments of a :class:`smdistributed.modelparallel.torch.DistributedModel` that + substitutes a module created inside the ``smdistributed.modelparallel.torch.tensor_parallelism`` context, this keyword will override the value defined in the ``init_hook``. - **Example:** .. code:: python + import smdistributed.modelparallel.torch as smp + model = MyModel() smp.set_tensor_parallelism(model.encoder, True) smp.set_tensor_parallelism(model.encoder.embedding, True) @@ -608,7 +634,7 @@ in the *SageMaker's Distributed Model Parallel developer guide*. any tuples received. If the checkpointed layer takes a tuple as input, then this needs to be set to True. -.. class:: smp.set_activation_checkpointing(module, preserve_rng_state=True, pack_args_as_tuple=False, strategy="each") +.. class:: smdistributed.modelparallel.torch.set_activation_checkpointing(module, preserve_rng_state=True, pack_args_as_tuple=False, strategy="each") - This API is recommended when importing pretrained models from libraries, such as PyTorch and Hugging Face Transformers. This is @@ -673,8 +699,8 @@ parses the arguments to ``__init__`` methods and sets the relevant attributes of the module, such as ``hidden_size`` and ``num_attention_heads``. -``smp.nn.DistributedTransformer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedTransformer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: python @@ -692,8 +718,8 @@ attributes of the module, such as ``hidden_size`` and def forward(self, inp): return self.seq_layers(inp) -``smp.nn.DistributedTransformerLayer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedTransformerLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: python @@ -727,8 +753,8 @@ attributes of the module, such as ``hidden_size`` and else: return output, attention_mask -``smp.nn.DistributedAttentionLayer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedAttentionLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. code:: python @@ -812,8 +838,8 @@ attributes of the module, such as ``hidden_size`` and else: return self_attention -``smp.nn.DistributedTransformerOutputLayer`` -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +``smdistributed.modelparallel.torch.nn.DistributedTransformerOutputLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
code:: python diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst new file mode 100644 index 0000000000..b4713b2707 --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_common_api.rst @@ -0,0 +1,538 @@ +Common API +========== + +The following SageMaker distribute model parallel APIs are common across all frameworks. + +.. contents:: Table of Contents + :depth: 3 + :local: + +The Library's Core APIs +----------------------- + +This API document assumes you use the following import statement in your training scripts. + +**TensorFlow** + +.. code:: python + + import smdistributed.modelparallel.tensorflow as smp + +**PyTorch** + +.. code:: python + + import smdistributed.modelparallel.torch as smp + + +.. function:: smp.init( ) + :noindex: + + Initialize the library. Must be called at the beginning of training script. + +.. function:: @smp.step(non_split_inputs, input_split_axes, [*args, **kwargs]) + :noindex: + + A decorator that must be placed over a function that represents a single + forward and backward pass (for training use cases), or a single forward + pass (for evaluation use cases). Any computation that is defined inside + the ``smp.step``-decorated function is executed in a pipelined manner. + + By default, every tensor input to the function is split across its batch + dimension into a number of microbatches specified while launching the + training job. This behavior can be customized through the arguments to + ``smp.step``, described below. The library then orchestrates the execution of + each microbatch across all partitions, based on the chosen pipeline + type. + + In a typical use case, forward pass and back-propagation are executed + inside an \ ``smp.step``-decorated function and gradients, loss, and + other relevant metrics (such as accuracy, etc.) are returned from + ``smp.step``-decorated function. + + Any gradient post-processing operation, such as gradient clipping and + allreduce, as well as ``optimizer.apply_gradients`` calls (for TF) or + ``optimizer.step`` (for PT) should be applied on the gradients returned + from the ``smp.step`` function, and not inside the ``smp.step`` + function. This is because every operation inside ``smp.step`` is + executed once per microbatch, so having these operations inside + ``smp.step`` can either be inefficient (in the case of allreduce), or + lead to wrong results (in the case of ``apply_gradients`` / + ``optimizer.step``). + + If the objects returned from the ``smp.step``-decorated function contain + ``tf.Tensor``\ s / ``torch.Tensor``\ s, they are converted to + ``StepOutput`` objects. A ``StepOutput`` object encapsulates all + versions of the tensor across different microbatches + (see ``StepOutput`` entry for more information). + + The argument to ``smp.step`` decorated function should either be a tensor + or an instance of list, tuple, dict or set for it to be split across + microbatches. If your object doesn't fall into this category, you can make + the library split your object, by implementing ``smp_slice`` method. + + Below is an example of how to use it with PyTorch. + + .. code:: python + + class CustomType: + def __init__(self, tensor): + self.data = tensor + + # The library will call this to invoke slicing on the object passing in total microbatches (num_mb) + # and the current microbatch index (mb). 
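+        # It should return a new object of the same type, holding only this
+        # microbatch's slice of the data along the given axis.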
+ def smp_slice(self, num_mb, mb, axis): + dim_size = list(self.data.size())[axis] + + split_size = dim_size // num_mb + sliced_tensor = self.data.narrow(axis, mb * split_size, split_size) + return CustomType(sliced_tensor, self.other) + + custom_obj = CustomType(torch.ones(4,)) + + @smp.step() + def step(custom_obj): + loss = model(custom_obj) + model.backward(loss) + return loss + + + **Important:** ``smp.step`` splits the batch into microbatches, and + executes everything inside the decorated function once per microbatch. + This might affect the behavior of batch normalization, any operation + that explicitly uses the batch size information, or any other Python + code that is expected to run once. + + **TensorFlow-specific behavior** + + ``smp.step`` is a wrapper that + inherits from and extends the behavior of ``tf.function``, and as such, + all the caveats that apply to the use of ``tf.function``\ s also apply + to ``smp.step``. In particular, any operation that is inside + ``smp.step`` executes in graph mode, and not eager mode. + + In the first call, ``smp.step`` performs tracing of the wrapped function every time + one of the tensor arguments changes their shape or dtype, or for every + new value of a Python argument, if there is one. Tracing is expensive, + so such scenarios should be avoided as much as possible or, + alternatively, an ``input_signature`` argument must be provided. For + more information on the usage of ``tf.function``, refer to the + TensorFlow documentation: + + - https://www.tensorflow.org/api_docs/python/tf/function\ + - https://www.tensorflow.org/guide/function\ + + Each ``smp.step`` decorated function must have a return value that depends on the + output of ``smp.DistributedModel``. + + **Common parameters** + + - ``non_split_inputs`` (``list``): The list of arguments to the decorated function + that should not be split along the batch dimension. Should be used + for all input tensors that do not have a batch dimension. Should be a + list of argument names as ``str``, as they appear in the signature of + the ``smp.step``-decorated function. By default it is considered an + empty list. + + - ``input_split_axes`` (``dict``): A dict that maps the argument name to its batch + axis. The keys should be the argument names as ``str``, as they + appear in the signature of the ``smp.step``-decorated function.  By + default all batch axes are assumed to be the 0-axis. + + **TensorFlow-only parameters** + + - All arguments of ``tf.function``. Note: + The \ ``experimental_compile`` argument of ``tf.function`` may not + work as expected with ``smp.step``, since it interferes with + pipelining and model partitioning. To enable XLA with the library, you can + instead use \ ``tf.config.optimizer.set_jit(True)``. + + **PyTorch-only parameters** + + - ``detach_outputs`` (``bool``) : If ``True``, calls ``torch.Tensor.detach()`` on + all returned ``torch.Tensor`` outputs. Setting it to ``False`` + increases memory consumption, unless ``detach()`` is manually called + on the returned tensors, because the model graph is not cleared from + memory after the training step. Set to \ ``True`` by default. + + **Returns** + + - The same object(s) returned from the decorated function. All + returned \ ``tf.Tensor``, \ ``tf.Variable``  objects (for TF) or + ``torch.Tensor`` objects (for PT) are wrapped inside + a \ ``StepOutput`` object, even when they are inside a Python + ``list``, ``tuple``, or ``dict``. + + + +.. 
class:: StepOutput + :noindex: + + + A class that encapsulates all versions of a ``tf.Tensor`` + or \ ``torch.Tensor`` across all microbatches. + + When a particular ``tf.Tensor`` or ``torch.Tensor`` is computed inside + ``smp.step``, different versions of the tensor are computed for each + microbatch. + + When this tensor is returned from ``smp.step`` and is accessed outside + of the decorated function, it appears as a ``StepOutput`` object, which + contains all such versions. For example, + + - In the case of Tensorflow, the gradient for a particular + ``tf.Variable`` is computed on each microbatch individually, and if + this gradient is returned from ``smp.step``, all gradients for this + ``tf.Variable`` become part of the same ``StepOutput`` object. The + ``StepOutput`` class offers the following API for commonly-used + post-processing operations on such tensors. + - In the case of PyTorch, the loss for each microbatch is computed + individually and all the ``torch.Tensor``\ s that represent the loss + for different microbatches become part of same ``StepOutput`` object, + if loss is returned from the ``smp.step`` function. + + + The ``StepOutput`` class offers the following API for commonly-used + post-processing operations on tensors. + + .. data:: StepOutput.outputs + :noindex: + + Returns a list of the underlying tensors, indexed by microbatch. + + .. function:: StepOutput.reduce_mean( ) + :noindex: + + Returns a ``tf.Tensor``, ``torch.Tensor`` that averages the constituent ``tf.Tensor`` s + ``torch.Tensor`` s. This is commonly used for averaging loss and gradients across microbatches. + + .. function:: StepOutput.reduce_sum( ) + :noindex: + + Returns a ``tf.Tensor`` / + ``torch.Tensor`` that sums the constituent + ``tf.Tensor``\ s/\ ``torch.Tensor``\ s. + + .. function:: StepOutput.concat( ) + :noindex: + + Returns a + ``tf.Tensor``/``torch.Tensor`` that concatenates tensors along the + batch dimension using ``tf.concat`` / ``torch.cat``. + + .. function:: StepOutput.stack( ) + :noindex: + + Applies ``tf.stack`` / ``torch.stack`` + operation to the list of constituent ``tf.Tensor``\ s / + ``torch.Tensor``\ s. + + **TensorFlow-only methods** + + .. function:: StepOutput.merge( ) + :noindex: + + Returns a ``tf.Tensor`` that + concatenates the constituent ``tf.Tensor``\ s along the batch + dimension. This is commonly used for merging the model predictions + across microbatches. + + .. function:: StepOutput.accumulate(method="variable", var=None) + :noindex: + + Functionally the same as ``StepOutput.reduce_mean()``. However, it is + more memory-efficient, especially for large numbers of microbatches, + since it does not wait for all constituent \ ``tf.Tensor``\ s to be + ready to start averaging them, thereby saving memory. + + In some cases (XLA for example) ``StepOutput.reduce_mean()`` might end + up being more memory-efficient than ``StepOutput.accumulate()``. + + **Parameters** + + - ``method`` (``"add_n"`` or ``"accumulate_n"`` or ``"variable"``): + If ``"add_n"`` or ``"accumulate_n"``, the library uses + ``tf.add_n`` and ``tf.accumulate_n``, respectively, to implement + accumulation. If ``"variable"``, the library uses an internal ``tf.Variable`` + into which to accumulate the tensors. Default is \ ``"variable"``. + Note: Memory usage behavior of these choices can depend on the model + and implementation. + + - ``var``: A ``tf.Variable`` into which, if provided, the library uses to + accumulate the tensors. If \ ``None``, the library internally creates a + variable. 
If ``method`` is not ``"variable"``, this argument is + ignored. + +.. _mpi_basics: + :noindex: + +MPI Basics +---------- + +The library exposes the following basic MPI primitives to its Python API: + +**Global** + +- ``smp.rank()`` : The global rank of the current process. +- ``smp.size()`` : The total number of processes. +- ``smp.get_world_process_group()`` : + ``torch.distributed.ProcessGroup`` that contains all processes. +- ``smp.CommGroup.WORLD``: The communication group corresponding to all processes. +- ``smp.local_rank()``: The rank among the processes on the current instance. +- ``smp.local_size()``: The total number of processes on the current instance. +- ``smp.get_mp_group()``: The list of ranks over which the current model replica is partitioned. +- ``smp.get_dp_group()``: The list of ranks that hold different replicas of the same model partition. + +**Tensor Parallelism** + +- ``smp.tp_rank()`` : The rank of the process within its + tensor-parallelism group. +- ``smp.tp_size()`` : The size of the tensor-parallelism group. +- ``smp.get_tp_process_group()`` : Equivalent to + ``torch.distributed.ProcessGroup`` that contains the processes in the + current tensor-parallelism group. +- ``smp.CommGroup.TP_GROUP`` : The communication group corresponding to + the current tensor parallelism group. + +**Pipeline Parallelism** + +- ``smp.pp_rank()`` : The rank of the process within its + pipeline-parallelism group. +- ``smp.pp_size()`` : The size of the pipeline-parallelism group. +- ``smp.get_pp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current pipeline-parallelism group. +- ``smp.CommGroup.PP_GROUP`` : The communication group corresponding to + the current pipeline parallelism group. + +**Reduced-Data Parallelism** + +- ``smp.rdp_rank()`` : The rank of the process within its + reduced-data-parallelism group. +- ``smp.rdp_size()`` : The size of the reduced-data-parallelism group. +- ``smp.get_rdp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current reduced data parallelism + group. +- ``smp.CommGroup.RDP_GROUP`` : The communication group corresponding + to the current reduced data parallelism group. + +**Model Parallelism** + +- ``smp.mp_rank()`` : The rank of the process within its model-parallelism + group. +- ``smp.mp_size()`` : The size of the model-parallelism group. +- ``smp.get_mp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current model-parallelism group. +- ``smp.CommGroup.MP_GROUP`` : The communication group corresponding to + the current model parallelism group. + +**Data Parallelism** + +- ``smp.dp_rank()`` : The rank of the process within its data-parallelism + group. +- ``smp.dp_size()`` : The size of the data-parallelism group. +- ``smp.get_dp_process_group()`` : ``torch.distributed.ProcessGroup`` + that contains the processes in the current data-parallelism group. +- ``smp.CommGroup.DP_GROUP`` : The communication group corresponding to + the current data-parallelism group. + +.. _communication_api: + :noindex: + +Communication API +----------------- + +The library provides a few communication primitives which can be helpful while +developing the training script. These primitives use the following +``enum`` s as arguments to specify which processes the communication +should involve. +​ + +**Helper structures** + +.. 
data:: smp.CommGroup + :noindex: + + An ``enum`` that takes the values + ``CommGroup.WORLD``, ``CommGroup.MP_GROUP``, and ``CommGroup.DP_GROUP``. + These values can also be accessed as ``smp.WORLD``, ``smp.MP_GROUP``, + and ``smp.DP_GROUP`` respectively. + + - ``CommGroup.WORLD``: Represents the entire group of processes used in + training + - ``CommGroup.MP_GROUP``: Represents the group of processes that hold + the same model replica as the current process. The processes in a + single ``MP_GROUP`` collectively store an entire replica of the + model. + - ``CommGroup.DP_GROUP``: Represents the group of processes that hold + the same model partition as the current process. The processes in a + single ``DP_GROUP`` perform data parallelism/allreduce among + themselves. + +.. data:: smp.RankType + :noindex: + + An ``enum`` that takes the values + ``RankType.WORLD_RANK``, ``RankType.MP_RANK``, and ``RankType.DP_RANK``. + + - ``RankType.WORLD_RANK``: The associated rank is to be interpreted as + the rank of the process across all processes used in training. + - ``RankType.MP_RANK``: The associated rank is to be interpreted as the + rank of the process within the ``MP_GROUP``. + - ``RankType.DP_RANK``: The associated rank is to be interpreted as the + rank of the process within the ``DP_GROUP``. + + +**Communication primitives:** + +.. function:: smp.broadcast(obj, group) + :noindex: + + Sends the object to all processes in the + group. The receiving process must call ``smp.recv_from`` to receive the + sent object. + + **Inputs** + + - ``obj``: An arbitrary picklable Python object that will be broadcast. + + - ``group``: A ``CommGroup`` argument that represents to which group of + processes the object will be sent. + + **Notes** + + - When you use ``broadcast`` on the sender process, there needs + to be an accompanying ``smp.recv_from()`` call on the receiver + processes. + + - This is a synchronous call; the ``broadcast`` statement + returns only after all ranks participating in the call have made a + matching ``recv_from`` call. + + **Example** + + .. code:: python + + if smp.rank() == 0: +     smp.broadcast(something, group=smp.CommGroup.WORLD) + else: +     smp.recv_from(0, rank_type=smp.RankType.WORLD_RANK) + +.. function:: smp.send(obj, dest_rank, rank_type) + :noindex: + + Sends the object ``obj`` to + ``dest_rank``, which is of a type specified by ``rank_type``. + + **Inputs** + + - ``obj``: An arbitrary picklable Python object that will be sent. + + - ``dest_rank`` (``int``): An integer denoting the rank of the receiving process. + + - ``rank_type`` (``enum``): A ``smp.RankType`` ``enum`` that determines how + ``dest_rank`` is to be interpreted. For example if ``dest_rank`` is 1 + and ``rank_type`` is ``MP_RANK``, then ``obj`` is sent to process + with ``mp_rank`` 1 in the ``MP_GROUP`` which contains the current + process. + + **Notes** + + - Note: \ This is a synchronous call; the ``send`` statement returns + only after the destination rank has made a matching + ``recv_from`` call. + +.. function:: smp.recv_from(src_rank, rank_type) + :noindex: + + Receive an object from a peer process. Can be used with a matching + ``smp.send`` or a ``smp.broadcast`` call. + + **Inputs** + + - ``src_rank`` (``int``): An integer denoting rank of the sending process. + + - ``rank_type`` (``enum``): A ``smp.RankType`` ``enum`` that determines how + ``dest_rank`` is to be interpreted. 
For example if ``src_rank`` is 1 + and ``rank_type`` is ``MP_RANK``, then the object is received from + the process with ``mp_rank`` 1 in the ``MP_GROUP`` which contains the + current process. + + **Returns** + + Returns the python object that is sent by the peer process. + + **Notes** + + - Note: This is a synchronous call; the ``recv_from`` statement returns + only after the source rank has made a matching ``send`` or + ``broadcast`` call, and the object is received. + +.. function:: smp.allgather(obj, group) + :noindex: + + A collective call that gathers all the + submitted objects across all ranks in the specified ``group``. Returns a + list whose ``i``\ th index contains the object submitted by the + ``i``\ th rank in ``group``. + + **Inputs** + + - ``obj``: An arbitrary picklable Python object that will be + allgathered. + + - ``group`` : A ``CommGroup`` argument that represents which group of + processes participate in ``allgather``. + + **Notes** + + - Note: This is a synchronous call; the ``allgather`` statement returns + only after all ranks participating in the call have made a matching + ``allgather`` call, and all the objects are received at the current + rank. + + **Examples** + + .. code:: python + + # assuming mp_size() == 2 + + if smp.mp_rank() == 0: +     out = smp.allgather(obj1, smp.CommGroup.MP_GROUP)  # returns [obj1, obj2] + else: +     out = smp.allgather(obj2, smp.CommGroup.MP_GROUP)  # returns [obj1, obj2] + +.. function:: smp.barrier(group=smp.WORLD) + :noindex: + + A statement that hangs until all + processes in the specified group reach the barrier statement, similar to + ``MPI_Barrier()``. + + **Inputs** + + - ``group``: An ``smp.CommGroup`` ``enum`` that specifies the group of + processes participating in the barrier call. Defaults to + ``smp.WORLD``. + + **Examples** + + - Assume there are 8 processes and 2 model partitions, and + therefore 4 \ ``mp_group``\ s, and 2 ``dp_group``\ s. If + the \ ``barrier`` call is passed the value ``smp.MP_GROUP`` for its + group argument, then each process only waits until the other process + of its own ``mp_group`` reaches that point. It does not wait for + processes outside that ``mp_group``. + +.. function:: smp.dp_barrier() + :noindex: + + Same as passing ``smp.DP_GROUP``\ to ``smp.barrier()``. + Waits for the processes in the same \ ``dp_group`` as + the current process to reach the same point in execution. + +.. function:: smp.mp_barrier() + :noindex: + + Same as passing ``smp.MP_GROUP`` to + ``smp.barrier()``. Waits for the processes in the same ``mp_group`` as + the current process to reach the same point in execution. diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst new file mode 100644 index 0000000000..88d1a42165 --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch.rst @@ -0,0 +1,677 @@ +PyTorch API +=========== + +To use the PyTorch-specific APIs for SageMaker distributed model parallism, +you need to add the following import statement at the top of your training script. + +.. code:: python + + import smdistributed.modelparallel.torch as smp + + +.. tip:: + + Refer to + `Modify a PyTorch Training Script + `_ + to learn how to use the following API in your PyTorch training script. + +.. class:: smp.DistributedModel + :noindex: + + A sub-class of ``torch.nn.Module`` which specifies the model to be + partitioned. 
Accepts a ``torch.nn.Module`` object ``module`` which is + the model to be partitioned. The returned ``DistributedModel`` object + internally manages model parallelism and data parallelism. Only one + model in the training script can be wrapped with + ``smp.DistributedModel``. + + **Example:** + + .. code:: python + + model = smp.DistributedModel(model) + + **Important**: The ``__call__`` and  ``backward`` method calls on the + ``smp.DistributedModel`` object (in the following example, the object + is \ ``model``) can only be made inside a ``smp.step``-decorated + function. + + Since ``DistributedModel``  is a ``torch.nn.Module``, a forward pass can + be performed by calling the \ ``DistributedModel`` object on the input + tensors. + + .. code:: python + + predictions = model(inputs)   # model is a smp.DistributedModel object + + For a backward pass, one needs to call the backward function on + the \ ``DistributedModel`` object, with tensors and gradients as + arguments, replacing the PyTorch operations \ ``torch.Tensor.backward`` + or ``torch.autograd.backward``. + + The API for ``model.backward`` is very similar to + ``torch.autograd.backward``. For example, the following + ``backward`` calls: + + .. code:: python + + torch.autograd.backward(loss) or loss.backward() + + should be replaced with: + + .. code:: python + + model.backward(loss) # loss is a tensor with only one element as its data + + Similarly, for non-scalar tensors, replace the following + ``backward`` call containing incoming gradient arguments: + + .. code:: python + + torch.autograd.backward(outputs, out_grads) + + with the following line: + + .. code:: python + + model.backward(outputs, out_grads) + + In these examples, all ``__call__``  and ``backward`` method calls on + the model objects (``model(inputs)`` and ``model.backward(loss)``) must be made inside + a ``smp.step``-decorated function. + + **Using DDP** + + If DDP is enabled with the SageMaker model parallel library, do not not place a PyTorch + ``DistributedDataParallel`` wrapper around the ``DistributedModel`` because + the ``DistributedModel`` wrapper will also handle data parallelism. + + Unlike the original DDP wrapper, when you use ``DistributedModel``, + model parameters and buffers are not immediately broadcast across + processes when the wrapper is called. Instead, the broadcast is deferred to the first call of the + ``smp.step``-decorated function when the partition is done. + + **Parameters** + + - ``module`` (``torch.nn.Module``): Module to be distributed (data parallelism and model parallelism). + + - ``trace_device`` (``"cpu"`` or ``"gpu"``) (default: ``"gpu"``) + Whether to perform the tracing step on the GPU or CPU. The tracing step gathers + information on the order of execution of modules, the shapes of + intermediate outputs, and execution times, to be used by the + partitioning algorithm. If ``trace_device`` is set to GPU, accurate + module execution times can be gathered during tracing for potentially + improved partitioning decision. However, if the model is too large to + fit in a single GPU, then ``trace_device`` should be set to CPU. + + - ``trace_execution_times`` (``bool``) (default: ``False``): If ``True``, + the library profiles the execution time of each module during tracing, and uses + it in the partitioning decision. This improves the partitioning + decision, but it might make the tracing slower. 
It may also introduce + some degree of non-determinism in partitioning results, because of the + inherent randomness in module execution times. Must be ``False`` if + ``trace_device`` is ``"cpu"``. + + - ``overlapping_allreduce`` (``bool``) (default: ``True``): This is only + applicable for hybrid data parallelism/model parallelism use cases (when + ``ddp`` is set to ``True`` while launching training). The library uses this flag + to decide whether to do overlapping allreduce whenever a parameter + gradients are ready. This leads to overlapping of communication and + computation and can improve performance. If this is set to ``False`` , + allreduce is performed at the end of the step. + + - ``backward_passes_per_step`` (``int``) (default: 1): This is only + applicable for hybrid data parallelism/model parallelism use cases (when + ``ddp`` is set to ``True`` in config). This parameter indicates the + number of backward passes to perform before calling allreduce on DDP. + This allows accumulating updates over multiple mini-batches before + reducing and applying them. + + - ``average_grads_across_microbatches`` (``bool``) (default: ``True``): + Whether or not the computed gradients should be averaged across + microbatches. If ``False``, the computed gradients will be summed across + microbatches, but not divided by the number of microbatches. In typical + use case where the computed loss is averaged over the mini-batch, this + should be left as ``True``. If you use a loss function that only sums + the per-sample loss across the batch (and not divide by the batch size), + then this must be set to ``False`` for correctness. + + - ``bucket_cap_mb`` (default: 25): \ ``DistributedDataParallel`` buckets + parameters into multiple buckets so that gradient reduction of each + bucket can potentially overlap with backward + computation. \ ``bucket_cap_mb``\ controls the bucket size in MegaBytes + (MB). + + - ``trace_memory_usage`` (default: False): When set to True, the library attempts + to measure memory usage per module during tracing. If this is disabled, + memory usage will be estimated through the sizes of tensors returned from + the module. + + - ``broadcast_buffers`` (default: True): Flag to be used with ``ddp=True``. + This parameter is forwarded to the underlying ``DistributedDataParallel`` wrapper. + Please see: `broadcast_buffer `__. + + - ``gradient_as_bucket_view`` (default: False): To be + used with ``ddp=True``. This parameter is forwarded to the underlying + ``DistributedDataParallel`` wrapper. Please see `gradient_as_bucket_view `__. + + **Properties** + + - ``partitioned``: Is ``True`` if the model is partitioned, ``False`` + otherwise. Initialized to ``False`` when ``DistributedModel`` is first + created. It becomes be ``True`` during the first call + to ``smp.step``-decorated function. Once the model is partitioned, the + local parameters or local ``state_dict`` can be fetched using the + following methods. + + **Methods** + + .. function:: backward(tensors, grad_tensors) + :noindex: + + Triggers a distributed backward + pass across model partitions. Example usage provided in the previous + section. The API is very similar + to https://pytorch.org/docs/stable/autograd.html#torch.autograd.backward. + ``retain_grad`` and ``create_graph``  flags are not supported. + + .. function:: local_buffers( ) + :noindex: + + Returns an iterator over buffers for the modules in + the partitioned model that have been assigned to the current process. + + .. 
function:: local_named_buffers( ) + :noindex: + + Returns an iterator over buffers for the + modules in the partitioned model that have been assigned to the current + process. This yields both the name of the buffer as well as the buffer + itself. + + .. function:: local_parameters( ) + :noindex: + + Returns an iterator over parameters for the + modules in the partitioned model that have been assigned to the current + process. + + .. function:: local_named_parameters( ) + :noindex: + + Returns an iterator over parameters for + the modules in the partitioned model that have been assigned to the + current process. This yields both the name of the parameter as well as + the parameter itself. + + .. function:: local_modules( ) + :noindex: + + Returns an iterator over the modules in the + partitioned model that have been assigned to the current process. + + .. function:: local_named_modules( ) + :noindex: + + Returns an iterator over the modules in the + partitioned model that have been assigned to the current process. This + yields both the name of the module as well as the module itself. + + .. function:: local_state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains local + parameters that belong to the current \ ``mp_rank``. This ``state_dict`` + contains a key \ ``_smp_is_partial`` to indicate this is a + partial \ ``state_dict``, which indicates whether the + ``state_dict`` contains elements corresponding to only the current + partition, or to the entire model. + + .. function:: state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains parameters + for the entire model. It first collects the \ ``local_state_dict``  and + gathers and merges the \ ``local_state_dict`` from all ``mp_rank``\ s to + create a full ``state_dict``. Please note that this needs to be called on all ranks with + ``dp_rank()==0`` to ensure the gather happens properly. + If it is only called on all such ranks, it can hang. + + .. function:: load_state_dict( ) + :noindex: + + Same as the ``torch.module.load_state_dict()`` , + except: It first gathers and merges the ``state_dict``\ s across + ``mp_rank``\ s, if they are partial. The actual loading happens after the + model partition so that each rank knows its local parameters. + + .. function:: register_post_partition_hook(hook) + :noindex: + + Registers a callable ``hook`` to + be executed after the model is partitioned. This is useful in situations + where an operation needs to be executed after the model partition during + the first call to ``smp.step``, but before the actual execution of the + first forward pass. Returns a ``RemovableHandle`` object ``handle``, + which can be used to remove the hook by calling ``handle.remove()``. + + .. function:: cpu( ) + :noindex: + + Allgathers parameters and buffers across all ``mp_rank``\ s and moves them + to the CPU. + + .. function:: join( ) + :noindex: + + A context manager to be used in conjunction with an instance of + ``smp.DistributedModel`` to be able to train with uneven inputs across + participating processes. This is only supported when ``ddp=True``. This will use the join with the wrapped + ``DistributedDataParallel`` instance. For more information, see: + `join `__ + in the PyTorch documentation. + + .. 
function:: register_comm_hook( state, callable )
+ :noindex:
+
+ **Available for PyTorch 1.8.1 only**
+ Registers a communication hook, which provides a flexible hook ``callable``
+ through which you can specify how gradients are aggregated across multiple
+ workers. This method will be called on the wrapped ``DistributedDataParallel`` instance.
+
+ Please note that when you register a comm hook, you have full control of how the gradients are processed.
+ When using only data parallelism with Torch DDP, you are expected to average grads across data parallel replicas within the hook.
+ Similarly, when using ``DistributedModel``, you have to average grads across data parallel replicas within the hook.
+ In addition, you also have to average grads across microbatches within the hook, unless, based on your loss function, you explicitly do not want to average them.
+ See ``average_grads_across_microbatches`` for more information about averaging grads across microbatches.
+
+ This is only supported when ``ddp=True`` and ``overlapping_allreduce=True`` (default).
+ For more information, see:
+ `register_comm_hook `__
+ in the PyTorch documentation.
+
+ **Behavior of** ``smp.DistributedModel`` **with Tensor Parallelism**
+
+ When a model is wrapped by ``smp.DistributedModel``, the library
+ immediately traverses the modules of the model object, and replaces the
+ modules that are supported for tensor parallelism with their distributed
+ counterparts. This replacement happens in place. If there are no other
+ references to the original modules in the script, they are
+ garbage-collected. The module attributes that previously referred to the
+ original submodules now refer to the distributed versions of those
+ submodules.
+
+ **Example:**
+
+ .. code:: python
+
+     # register DistributedSubmodule as the distributed version of Submodule
+     # (note this is a hypothetical example, smp.nn.DistributedSubmodule does not exist)
+     smp.tp_register_with_module(Submodule, smp.nn.DistributedSubmodule)
+
+     class MyModule(nn.Module):
+         def __init__(self):
+             ...
+
+             self.submodule = Submodule()
+             ...
+
+     # enabling tensor parallelism for the entire model
+     with smp.tensor_parallelism():
+         model = MyModule()
+
+     # here model.submodule is still a Submodule object
+     assert isinstance(model.submodule, Submodule)
+
+     model = smp.DistributedModel(model)
+
+     # now model.submodule is replaced with an equivalent instance
+     # of smp.nn.DistributedSubmodule
+     assert isinstance(model.module.submodule, smp.nn.DistributedSubmodule)
+
+ If ``pipeline_parallel_degree`` (equivalently, ``partitions``) is 1, the
+ placement of model partitions into GPUs and the initial broadcast of
+ model parameters and buffers across data-parallel ranks take place
+ immediately. This is because it does not need to wait for the model
+ partition when the ``smp.DistributedModel`` wrapper is called. For other
+ cases with ``pipeline_parallel_degree`` greater than 1, the broadcast
+ and device placement are deferred until the first call of an
+ ``smp.step``-decorated function. This is because the first
+ ``smp.step``-decorated function call is when the model partitioning
+ happens if pipeline parallelism is enabled.
+
+ Because of the module replacement during the ``smp.DistributedModel``
+ call, any ``load_state_dict`` calls on the model, as well as any direct
+ access to model parameters, such as during the optimizer creation,
+ should be done **after** the ``smp.DistributedModel`` call.
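+
+ A minimal sketch of this ordering, reusing the generic ``MyModel`` placeholder from the
+ checkpointing example later in this document; the optimizer choice and checkpoint path
+ are illustrative only:
+
+ .. code:: python
+
+     import torch
+     import smdistributed.modelparallel.torch as smp
+
+     smp.init()
+     model = MyModel()                      # plain torch.nn.Module
+     model = smp.DistributedModel(model)    # supported submodules are replaced here
+
+     # Create the optimizer only after wrapping, so that it sees the parameters
+     # of the (possibly distributed) replacement modules.
+     optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
+     optimizer = smp.DistributedOptimizer(optimizer)
+
+     # Likewise, load checkpoints only after wrapping.
+     checkpoint = smp.load("/checkpoint.pt", partial=True)
+     model.load_state_dict(checkpoint["model_state_dict"])
+     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])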
+ + Since the broadcast of the model parameters and buffers happens + immediately during ``smp.DistributedModel`` call when the degree of + pipeline parallelism is 1, using ``@smp.step`` decorators is not + required when tensor parallelism is used by itself (without pipeline + parallelism). + + For more information about the library's tensor parallelism APIs for PyTorch, + see :ref:`smdmp-pytorch-tensor-parallel`. + + **Additional Methods of** ``smp.DistributedModel`` **for Tensor Parallelism** + + The following are the new methods of ``smp.DistributedModel``, in + addition to the ones listed in the + `documentation `__. + + .. function:: distributed_modules() + :noindex: + + - An iterator that runs over the set of distributed + (tensor-parallelized) modules in the model + + .. function:: is_distributed_parameter(param) + :noindex: + + - Returns ``True`` if the given ``nn.Parameter`` is distributed over + tensor-parallel ranks. + + .. function:: is_distributed_buffer(buf) + :noindex: + + - Returns ``True`` if the given buffer is distributed over + tensor-parallel ranks. + + .. function:: is_scaled_batch_parameter(param) + :noindex: + + - Returns ``True`` if the given ``nn.Parameter`` is operates on the + scaled batch (batch over the entire ``TP_GROUP``, and not only the + local batch). + + .. function:: is_scaled_batch_buffer(buf) + :noindex: + + - Returns ``True`` if the parameter corresponding to the given + buffer operates on the scaled batch (batch over the entire + ``TP_GROUP``, and not only the local batch). + + .. function:: default_reducer_named_parameters() + :noindex: + + - Returns an iterator that runs over ``(name, param)`` tuples, for + ``param`` that is allreduced over the ``DP_GROUP``. + + .. function:: scaled_batch_reducer_named_parameters() + :noindex: + + - Returns an iterator that runs over ``(name, param)`` tuples, for + ``param`` that is allreduced over the ``RDP_GROUP``. + + + +.. class:: smp.DistributedOptimizer + :noindex: + + **Parameters** + - ``optimizer`` + + An optimizer wrapper for saving/loading optimizer states. This wrapper + returns ``optimizer`` with the following methods overridden: + + .. function:: state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains optimizer state for the entire model. + It first collects the ``local_state_dict`` and gathers and merges + the ``local_state_dict`` from all ``mp_rank``s to create a full + ``state_dict``. + + .. function:: load_state_dict( ) + :noindex: + + Same as the ``torch.optimizer.load_state_dict()`` , except: + + - It first gathers and merges the local ``state_dict``\ s if they are + partial. + - The actual loading happens after the model partition so that each + rank knows its local parameters. + + .. function:: local_state_dict( ) + :noindex: + + Returns the ``state_dict`` that contains the + local optimizer state that belongs to the current \ ``mp_rank``. This + ``state_dict`` contains a key \ ``_smp_is_partial`` to indicate this is + a partial \ ``state_dict``, which indicates whether the + ``state_dict`` contains elements corresponding to only the current + partition, or to the entire model. + + ​ +.. function:: smp.partition(index) + :noindex: + + **Inputs** + + - ``index`` (int) - The index of the partition. + + A context manager which places all modules defined inside into the + partition with ID ``index``.  The ``index`` argument must be less than + the number of partitions. + + Use ``smp.partition`` to implement manual partitioning. 
+ If ``"auto_partition"`` is ``True``, then the + ``smp.partition`` contexts are ignored. Any module that is not placed in + any ``smp.partition`` context is placed in the + ``default_partition`` defined through the SageMaker Python SDK. + + When ``smp.partition`` contexts are nested, the innermost context + overrides the rest (see the following example). In PyTorch, manual + partitioning should be done inside the module \ ``__init__``, and the + partition assignment applies to the modules that are *created* inside + the ``smp.partition`` context. + + Example: + + .. code:: python + + class Model(torch.nn.Module): +     def __init__(self): +         with smp.partition(1): +             self.child0 = Child0()            # child0 on partition 1 +             with smp.partition(2): +                 self.child1 = Child1()        # child1 on partition 2 +             self.child2 = Child2()            # child2 on partition 1 +         self.child3 = Child3()                # child3 on default_partition + +.. function:: smp.get_world_process_group( ) + :noindex: + + Returns a ``torch.distributed`` ``ProcessGroup`` that consists of all + processes, which can be used with the ``torch.distributed`` API. + Requires ``"ddp": True`` in SageMaker Python SDK parameters. + +.. function:: smp.get_mp_process_group( ) + :noindex: + + Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the + processes in the ``MP_GROUP`` which contains the current process, which + can be used with the \ ``torch.distributed`` API. Requires + ``"ddp": True`` in SageMaker Python SDK parameters. + +.. function:: smp.get_dp_process_group( ) + :noindex: + + Returns a ``torch.distributed`` ``ProcessGroup`` that consists of the + processes in the ``DP_GROUP`` which contains the current process, which + can be used with the \ ``torch.distributed`` API. Requires + ``"ddp": True`` in SageMaker Python SDK parameters. + +.. function:: smp.is_initialized( ) + :noindex: + + Returns ``True`` if ``smp.init`` has already been called for the + process, and ``False`` otherwise. + +.. function::smp.is_tracing( ) + :noindex: + :noindex: + + Returns ``True`` if the current process is running the tracing step, and + ``False`` otherwise. + +.. data:: smp.nn.FusedLayerNorm + :noindex: + + `Apex Fused Layer Norm `__ is currently not + supported by the library. ``smp.nn.FusedLayerNorm`` replaces ``apex`` + ``FusedLayerNorm`` and provides the same functionality. This requires + ``apex`` to be installed on the system. + +.. data:: smp.optimizers.FusedNovoGrad + :noindex: + + `Fused Novo Grad optimizer `__ is + currently not supported by the library. ``smp.optimizers.FusedNovoGrad`` replaces ``apex`` ``FusedNovoGrad`` + optimizer and provides the same functionality. This requires ``apex`` to + be installed on the system. + +.. data:: smp.optimizers.FusedLamb + :noindex: + + `FusedLamb optimizer `__ + currently doesn’t work with the library. ``smp.optimizers.FusedLamb`` replaces + ``apex`` ``FusedLamb`` optimizer and provides the same functionality. + This requires ``apex`` to be installed on the system. + +.. data:: smp.amp.GradScaler + :noindex: + + `Torch AMP Gradscaler `__ + currently doesn’t work with the library. ``smp.amp.GradScaler`` replaces + ``torch.amp.GradScaler`` and provides the same functionality. + +.. _pytorch_saving_loading: + :noindex: + +APIs for Saving and Loading +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. function:: smp.save( ) + :noindex: + + Saves an object. 
This operation is similar to ``torch.save()``, except
+ it has an additional keyword argument, ``partial``, and accepts only
+ string type for the argument ``f`` (file). If ``partial=True``, each
+ ``mp_rank`` saves a separate checkpoint file and the library adds an ``mp_rank``
+ index to your saved file.
+
+ **Parameters**
+
+ - ``obj`` (dict): The object to save.
+ - ``f`` (str): A string containing a file name.
+ - ``partial`` (bool, default= ``True``): When set to ``True``, each
+ ``mp_rank`` saves a separate checkpoint file and the library adds an
+ ``mp_rank`` index to the saved file. If you want to be able to load
+ and further train a model that you save with ``smp.save()``, you must
+ set ``partial=True``.
+ - ``pickle_module`` (picklemodule, default = module ``"pickle"`` from ``"/opt/conda/lib/python3.6/pickle.py"``):
+ A module used for pickling metadata and objects.
+ - ``pickle_protocol`` (int, default=2): Can be specified to
+ override the default protocol.
+
+.. function:: smp.load( )
+ :noindex:
+
+ Loads an object saved with ``smp.save()`` from a file.
+
+ Similar to `torch.load() `__,
+ except it has an additional keyword argument, ``partial``, and accepts
+ only string type for the argument ``f`` (file). If \ ``partial=True``,
+ then each ``mp_rank`` loads a separate checkpoint file.
+
+ **Parameters**
+
+ - ``f`` (string): A string containing a file name.
+ - ``map_location`` (function): A function
+ `torch.device `__,
+ a string, or a dict specifying how to remap storage locations.
+ - ``pickle_module`` (pickle module): A module used for unpickling
+ metadata and objects (has to match the ``pickle_module`` used to
+ serialize the file).
+ - ``pickle_load_args`` (Python 3 only): Optional keyword arguments
+ passed to ``pickle_module.load()`` and ``pickle_module.Unpickler()``.
+ - ``partial`` (bool, default= ``True``): When set to ``True``, each
+ ``mp_rank`` loads the checkpoint corresponding to the ``mp_rank``.
+ Should be used when loading a model trained with the library.
+
+.. _pytorch_saving_loading_instructions:
+ :noindex:
+
+General Instruction For Saving and Loading
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The library can save partial or full checkpoints.
+
+- For partial checkpoints, each ``mp_rank`` saves its own checkpoint
+ file with only the parameters that belong to that rank.
+- For full checkpoints, the library saves a single checkpoint that contains
+ the entire model parameters.
+
+When **saving** using ``smp.save()``, each rank only holds its own
+parameters. If you want to save the full model, there will be some
+communication between the ranks to create the full model. If you save
+checkpoints often, you should save partial checkpoints for best
+performance.
+
+When **loading** using ``smp.load()``, the library can load either partial
+checkpoints, full checkpoints, or full checkpoints saved by a
+non-model-parallel model. If you want to resume training with a
+non-model-parallel model or do inference, you need a full checkpoint.
+
+The following is an example of how you can save and load a checkpoint:
+
+.. code:: python
+
+    # Original model and optimizer
+    model = MyModel(...)
+    optimizer = MyOpt(...)
+ + # model parallel wrapper + model = smp.DistributedModel(model) + optimizer = smp.DistributedOptimizer(optimizer) + + # To save, always save on dp_rank 0 to avoid data racing + if partial: +     # To save the partial model on each mp rank +     # the library will create `checkpoint.pt_{mprank}` for each mp rank +     if save_partial_model: +         if smp.dp_rank() == 0: +             model_dict = model.local_state_dict() # save the partial model +             opt_dict = optimizer.local_state_dict() # save the partial optimizer state +             smp.save( +                 {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, +                 f"/checkpoint.pt", +                 partial=True, +             ) + +     # To save the full model +     if save_full_model: +         if smp.dp_rank() == 0: +             model_dict = model.state_dict() # save the full model +             opt_dict = optimizer.state_dict() # save the full optimizer state +             smp.save( +                 {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict}, +                 "/checkpoint.pt", +                 partial=False, +             ) + + # To load, load on all ranks. + # The only difference for partial/full loading is the partial flag in smp.load + # Load partial checkpoint + if partial_checkpoint: +    checkpoint = smp.load("/checkpoint.pt", partial=True) +    model.load_state_dict(checkpoint["model_state_dict"]) +    optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + # Load full checkpoint + if full_checkpoint: +    checkpoint = smp.load("/checkpoint.pt", partial=False) +    model.load_state_dict(checkpoint["model_state_dict"]) +    optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst new file mode 100644 index 0000000000..c66595ddf2 --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_pytorch_tensor_parallel.rst @@ -0,0 +1,876 @@ +.. _smdmp-pytorch-tensor-parallel: + :noindex: + +PyTorch API for Tensor Parallelism +================================== + +SageMaker distributed tensor parallelism works by replacing specific submodules +in the model with their distributed implementations. The distributed modules +have their parameters and optimizer states partitioned across tensor-parallel +ranks. This is to compute the same output as it would have been computed by +the original modules. Since tensor parallelism occurs across data-parallel +ranks, a rank might collect slices of the activations corresponding to the +data shards on other devices that are part of the same tensor parallelism group. + +You can enable or disable tensor parallelism for specific parts of the model. +Within the enabled parts, the replacements with distributed modules will take +place on a best-effort basis for those module supported for tensor parallelism. +Alternatively, you can directly import and use the library’s distributed +modules in the model definition. + +Some of the supported modules (such as ``smp.nn.Transformer``) are high-level +blocks that contain many operations. Because custom implementations +(as opposed to the built-in PyTorch modules) are typically used for these +high-level blocks, the library offers an API that you can use to register +specific distributed versions with such custom modules (provided that they +are functionally equivalent). 
This allows the library to automatically replace
+the occurrences of such PyTorch modules with their distributed counterparts
+provided by the library.
+For more information, see the following topics.
+
+.. contents:: Topics
+ :depth: 3
+ :local:
+
+.. _registering-tp-modules:
+ :noindex:
+
+Registering Tensor Parallelism Distributed Modules
+--------------------------------------------------
+
+Although PyTorch natively provides some of the commonly used (and
+tensor-parallelizable) building blocks such as Transformer, users often
+use custom implementations for such higher-level modules. To distribute
+such modules with tensor parallelism, you need to register the
+distributed module with your custom module implementation,
+so that the library knows how to distribute the custom module. When you
+register the distributed modules, make sure the custom module that you
+use is functionally equivalent to the distributed module. You can verify
+this by taking a look at the equivalent reference implementations in the
+:ref:`smdmp-tp-appendix`.
+These implementations are functionally equivalent to their distributed
+versions in the ``smp.nn`` module.
+
+.. decorator:: @smp.tp_register(dist_module, init_hook=None, forward_hook=None, return_hook=None)
+
+ - A class decorator that registers the ``dist_module`` class with
+ the module class that it is attached to. The hooks can be used to
+ adapt to different interfaces used with ``__init__`` and
+ ``forward`` methods.
+ - **Arguments:**
+
+ - ``dist_module``: A subclass of ``smp.nn.DistributedModule``
+ that implements the distributed version of the module class the
+ decorator is attached to. Any distributed module class defined
+ in the ``smp.nn`` module can be used.
+ - ``init_hook``: A callable that translates the arguments of the
+ original module ``__init__`` method to an ``(args, kwargs)``
+ tuple compatible with the arguments of the corresponding
+ distributed module ``__init__`` method. Must return a tuple,
+ whose first element is an iterable representing the positional
+ arguments, and second element is a ``dict`` representing the
+ keyword arguments. The input signature of the ``init_hook``
+ must **exactly** match the signature of the original
+ ``__init__`` method (including argument order and default
+ values), except it must exclude ``self``.
+ - ``forward_hook``: A callable that translates the arguments of
+ the original module ``forward`` method to an ``(args, kwargs)``
+ tuple compatible with the arguments of the corresponding
+ distributed module ``forward`` method. Must return a tuple,
+ whose first element is an iterable representing the positional
+ arguments, and second element is a ``dict`` representing the
+ keyword arguments. The input signature of the ``forward_hook``
+ must **exactly** match the signature of the original
+ ``forward`` method (including argument order and default
+ values), except it must exclude ``self``.
+ - ``return_hook``: A callable that translates the object returned
+ from the distributed module to the return object expected of
+ the original module.
+
+ - **Example:**
+
+ .. code:: python
+
+     init_hook = lambda config: ((), config.to_dict())
+
+     # register smp.nn.DistributedTransformer
+     # as the distributed version of MyTransformer
+     @smp.tp_register(smp.nn.DistributedTransformer, init_hook=init_hook)
+     class MyTransformer(nn.Module):
+         def __init__(self, config):
+             ...
+
+         def forward(self, hidden_states, attention_mask):
+             ...
+
+..
function:: smp.tp_register_with_module(module_cls, dist_module, init_hook=None, forward_hook=None, return_hook=None)
+ :noindex:
+
+ - When you do not have direct access to model definition code, you
+ can use this API to similarly register a distributed module with
+ an existing module class.
+
+ - **Arguments:**
+
+ - ``module_cls``: The existing module class that will be
+ distributed.
+ - ``dist_module``: A subclass of ``smp.nn.DistributedModule``
+ that implements the distributed version of ``module_cls``.
+ Any distributed module class defined
+ in the ``smp.nn`` module can be used.
+ - ``init_hook``: A callable that translates the arguments of the
+ original module ``__init__`` method to an ``(args, kwargs)``
+ tuple compatible with the arguments of the corresponding
+ distributed module ``__init__`` method. Must return a tuple,
+ whose first element is an iterable representing the positional
+ arguments, and second element is a ``dict`` representing the
+ keyword arguments. The input signature of the ``init_hook``
+ must **exactly** match the signature of the original
+ ``__init__`` method (including argument order and default
+ values), except it must exclude ``self``.
+ - ``forward_hook``: A callable that translates the arguments of
+ the original module ``forward`` method to an ``(args, kwargs)``
+ tuple compatible with the arguments of the corresponding
+ distributed module ``forward`` method. Must return a tuple,
+ whose first element is an iterable representing the positional
+ arguments, and second element is a ``dict`` representing the
+ keyword arguments. The input signature of the ``forward_hook``
+ must **exactly** match the signature of the original
+ ``forward`` method (including argument order and default
+ values), except it must exclude ``self``.
+ - ``return_hook``: A callable that translates the object returned
+ from the distributed module to the return object expected of
+ the original module.
+
+ - **Example:**
+
+ .. code:: python
+
+     from somelibrary import MyTransformer
+
+     init_hook = lambda config: ((), config.to_dict())
+
+     # register smp.nn.DistributedTransformer as the distributed version of MyTransformer
+     smp.tp_register_with_module(MyTransformer,
+                                 smp.nn.DistributedTransformer,
+                                 init_hook=init_hook)
+
+.. _smdmp-supported-modules-for-tp:
+ :noindex:
+
+Supported Modules for Tensor Parallelism
+----------------------------------------
+
+The following modules are supported for tensor
+parallelism.
+
+- ``smp.nn.DistributedLinear`` (implements ``nn.Linear``)
+- ``smp.nn.DistributedTransformerLMHead``
+- ``smp.nn.DistributedTransformer``
+- ``smp.nn.DistributedTransformerLayer``
+- ``smp.nn.DistributedAttentionLayer``
+- ``smp.nn.DistributedTransformerOutputLayer``
+- ``smp.nn.DistributedEmbedding``
+
+.. contents:: Topics
+ :depth: 3
+ :local:
+
+.. _tp-module-api:
+ :noindex:
+
+Tensor Parallelism Module APIs
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. class:: smp.nn.DistributedLinear(in_features, out_features)
+ :noindex:
+
+ - Tensor-parallel implementation of the ``nn.Linear`` class.
+ Functionally equivalent to an ``nn.Linear`` module with the same
+ ``in_features`` and ``out_features``. In other words,
+ ``in_features`` and ``out_features`` are the number of *global*
+ channels across tensor-parallel ranks.
+ - **Arguments:**
+
+ - ``in_features``: The total number of input channels for the
+ linear layer across all tensor-parallel ranks.
+ - ``out_features``: The total number of output channels for the + linear layer across all tensor-parallel ranks. + +.. class:: smp.nn.DistributedTransformerLMHead(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, vocab_size=30522, num_positions=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, num_token_types=0, causal_mask_size=None, add_cross_attention=False, add_lm_head=True, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True) + :noindex: + + - Constructs a distributed transformer model, including embeddings + and a single LM head. A word embedding of size + ``(vocab_size, hidden_size)`` is created, as well as a positional + embedding of size ``(num_positions, hidden_size)``, and the + embeddings are added together. If ``num_token_types`` is larger + than 0, a separate embedding of size + ``(num_token_types, hidden_size)`` is created, and further added + on top. + - The embeddings are fed through a ``DistributedTransformer``, and + if ``add_lm_head`` is ``True``, the output passes through a single + LM head, which is a linear module without bias whose weight is + tied to the word embeddings. + - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the rest + of the arguments. + - **Methods:** + + - ``forward(self, inputs)`` + + - If ``add_cross_attention`` is ``True``, ``inputs`` must be a + tuple + ``(input_ids, attention_mask, token_type_ids, position_ids, cross_states, cross_states, cross_mask, labels)``. + - Otherwise, ``inputs`` must be a tuple + ``(input_ids, attention_mask, token_type_ids, position_ids, labels)``. + - If ``token_type_ids`` is ``None``, token type embedding will + not be used. + - ``input_ids`` is assumed to be of shape ``[N, S]``, where + ``N`` is the batch size and ``S`` is sequence length. + - ``attention_mask`` is assumed to be a 0-1 tensor of shape + ``[N, S]``, where 1 represents a masked position. + +.. class:: smp.nn.DistributedTransformer(num_layers=12, num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) + :noindex: + + - A sequence of ``smp.nn.DistributedTransformerLayer``\ s, whose + number is given by ``num_layers`` argument. For the other + arguments and methods, refer to + ``smp.nn.DistributedTransformerLayer``. + - If both ``pre_layernorm`` and ``post_layernorm`` are ``True``, + layer normalization is applied to both the input and the output of + the ``DistributedTransformer``, in addition to the intermediate + attention and transformer-output layers. + +.. class:: smp.nn.DistributedTransformerLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, intermediate_size=4096, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, causal_mask_size=None, add_cross_attention=False, pre_layernorm=False, post_layernorm=True) + :noindex: + + - Tensor-parallel implementation of a single transformer layer. + Number of attention heads, hidden size, and intermediate size + refer to the global quantities across all tensor-parallel ranks. 
+ - **Arguments:** + + - ``num_attention_heads``: The total number of attention heads + across tensor-parallel ranks + - ``attention_head_size``: The number of channels of a single + attention head. + - ``hidden_size``: The hidden dimension of the transformer. The + input tensor ``hidden_states`` is assumed to have its last + dimension size equal to ``hidden_size``. + - ``intermediate_size``: The number of output channels in the + first linear transformation of the transformer output layer. + ``DistributedTransformerOutputLayer`` first maps + ``hidden_size`` dimensions of its input tensor into + ``intermediate_size`` dimensions, and then maps it back into + ``hidden_size`` dimensions. + - ``attention_dropout_prob``: The dropout probability applied to + the attention probabilities. + - ``hidden_dropout_prob``: The dropout probability used in + dropout layers other than the one applied to the attention + probabilities. + - ``activation``: Choice of activation function to use at the + output layer. Must be ``"gelu"`` or ``"relu"``. + - ``layernorm_epsilon``: The epsilon added to the denominator of + layer normalization for numerical stability. + - ``initializer_range``: If ``use_normal_initialization`` is + ``True``, the standard deviation of the normal random variable + to initialize the weights with. + - ``use_normal_initialization``: If ``True``, the weights are + initialized with normal distribution with standard deviation + given by ``initializer_range``. Otherwise, default PyTorch + initialization is used. + - ``causal_mask_size``: If ``None``, no causal mask is used on + attentions. Otherwise, should be set to maximum sequence length + to apply a causal mask to the attention scores. This is used, + for instance, in GPT-2. + - ``add_cross_attention``: If ``True``, a cross-attention layer + will be added after the self-attention block. The + cross-attention layer computes the attention keys and values + based on the ``cross_states`` input (instead of + ``hidden_states`` input, as in self-attention. This is used in + the decoder block of encoder-decoder architectures. For + encoder-only architectures that only use self-attention, this + should be kept ``False``. + - ``pre_layernorm``: If ``True``, inserts layer normalization at + the input. At least one of ``pre_layernorm`` and + ``post_layernorm`` must be ``True``. + - ``post_layernorm``: If ``True``, inserts layer normalization at + the output. At least one of ``pre_layernorm`` and + ``post_layernorm`` must be ``True``. + + - **Methods:** + + - ``forward(self, inputs)``: Forward pass for the transformer + layer. + + - **Arguments:** + + - If ``add_cross_attention=False``, ``inputs`` must be a + tuple ``(hidden_states, attention_mask)``, where + ``hidden_states`` is assumed to be a tensor of dimensions + ``[N, S, H]``, where ``N`` is batch size, ``S`` is + sequence length, and ``H`` is ``hidden_size``. + ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S]``, where ``N`` is the batch + size, and ``S`` is the sequence length. + - If ``add_cross_attention=True``, ``inputs`` must be a + tuple + ``(hidden_states, cross_states, attention_mask, cross_mask)``, + where ``hidden_states`` is assumed to be a tensor of + dimensions ``[N, S_1, H]``, where ``N`` is batch size, + ``S_1`` is sequence length, and ``H`` is ``hidden_size``. + ``cross_states`` is assumed to be a tensor of size + ``[N, S_2, H]``, similarly interpreted. 
+ ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S_1]``, where ``N`` is the batch + size, and ``S_1`` is the sequence length, and + ``cross_mask`` is assumed to be a tensor of size + ``[N, 1, 1, S_2]``. Keys and values for the attention + heads in the cross-attention layer (but not the + self-attention layer) are computed using + ``cross_states``, and ``cross_mask`` is applied as the + attention mask in the cross-attention layer (but not the + self-attention layer). + + - **Returns:** + + - If ``add_cross_attention=False``, a tuple + ``(hidden_states, attention_mask)``, where + ``hidden_states`` is the output of the transformer, and + ``attention_mask`` is the same the ``attention_mask`` + argument. + - If ``add_cross_attention=True``, a tuple + ``(hidden_states, cross_states, attention_mask, cross_mask)``, + where ``hidden_states`` is the output of the transformer, + and the next three tensors are the same as the input + arguments. + +.. class:: smp.nn.DistributedAttentionLayer(num_attention_heads=32, attention_head_size=32, hidden_size=1024, attention_dropout_prob=0.1, hidden_dropout_prob=0.1, layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, cross_attention=False, causal_mask_size=None, pre_layernorm=False, post_layernorm=True) + :noindex: + + - A distributed implementation for the attention block. Includes the + computation of the self- or cross-attention (context layer), + followed by a linear mapping and dropout, which is optionally + followed by the residual-connection and layer normalization. + - **Arguments:** + + - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + arguments. + - ``cross_attention``: If ``True``, it computes the attentions + with respect to the ``cross_states`` tensor of the ``forward`` + method input tuple. (Default: ``False``) + + - **Methods:** + + - ``forward(self, inputs)``: Forward pass for the attention + layer. + + - **Arguments:** + + - If ``cross_attention=False``, ``inputs`` must be a tuple + ``(hidden_states, attention_mask)``, where + ``hidden_states`` is assumed to be a tensor of dimensions + ``[N, S, H]``, where ``N`` is batch size, ``S`` is + sequence length, and ``H`` is ``hidden_size``. + ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S]``, where ``N`` is the + batch size, and ``S`` is the sequence length. + - If ``cross_attention=True``, ``inputs`` must be a tuple + ``(hidden_states, cross_states, attention_mask)``, where + ``hidden_states`` is assumed to be a tensor of dimensions + ``[N, S_1, H]``, where ``N`` is batch size, ``S_1`` is + sequence length, and ``H`` is ``hidden_size``. + ``cross_states`` is assumed to be a tensor of size + ``[N, S_2, H]``, similarly interpreted. + ``attention_mask`` is assumed to be a tensor of + dimensions ``[N, 1, 1, S_2]``, where ``N`` is the batch + size, and ``S_2`` is the sequence length. Keys and values + for the attention heads are computed using + ``cross_states``. + + - **Returns:** + + - A single tensor that is the output of the attention + layer. + +.. class:: smp.nn.DistributedTransformerOutputLayer(hidden_size=1024, intermediate_size=4096, hidden_dropout_prob=0.1, activation="gelu", layernorm_epsilon=1e-5, initializer_range=0.02, use_normal_initialization=False, pre_layernorm=False, post_layernorm=True, fp32_residual_addition=False) + :noindex: + + - Distributed implementation of a single transformer output layer. 
A + single :class:`smp.nn.DistributedTransformerLayer` with + ``add_cross_attention=False`` consists of a single + ``DistributedAttentionLayer`` immediately followed by a single + ``DistributedTransformerOutputLayer``. The latter linearly maps + the last channel of the input tensor from ``hidden_size`` to + ``intermediate_size``, and then maps it back to ``hidden_size``. + - **Arguments:** + + - See :class:`smp.nn.DistributedTransformerLayer` for descriptions of the + arguments. + - ``fp32_residual_addition``: Set to ``True`` if you want to avoid overflow + (NaN loss values) for large models with more than 100 billion parameters + when using FP16. (Default: False) + +.. class:: smp.nn.DistributedEmbedding(num_embeddings,embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, initializer_range=0.02, _skip_allgather=False,_skip_scatter_and_merge=False,) + :noindex: + + - Distributed implementation of a single Embedding Layer. Currently + only supports splitting across the embedding_dim. + - **Arguments:** + + - See :class:`smp.nn.DistributedEmbedding` for descriptions of the + arguments. + +.. _enabling-tp: + :noindex: + +Enabling Tensor Parallelism +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +There are two ways tensor parallelism can be enabled. + +First, you can use +the distributed module implementations in ``smp.nn`` module directly in +your model definition. See :ref:`smdmp-supported-modules-for-tp` +for a complete list of built-in distributed modules. Here is an example +of how this can be done: + +.. code:: python + + import torch.nn as nn + import smdistributed.modelparallel.torch as smp + + class TransformerModel: + def __init__(self): + self.embedding = nn.Embedding(vocab_size, hidden_size) + + # directly instantiate smp.nn.DistributedTransformer and use it + self.encoder = smp.nn.DistributedTransformer(num_layers, hidden_size, **kwargs) + + self.pooler = nn.Linear(hidden_size, hidden_size) + + def forward(self, hidden_states): + emb_out = self.embedding(hidden_states) + enc_out = self.encoder(emb_out) + return self.pooler(enc_out) + +Second, you can enable tensor parallelism for specific modules or blocks +of code, which will automatically enable tensor parallelism for the +supported modules within that scope. To do this, you can use the +following API: + +.. decorator:: smp.tensor_parallelism(enabled=True, **kwargs) + :noindex: + + - A context manager that enables or disables tensor parallelism for + any supported module that is created inside. If there are nested + contexts, the innermost overrides the rest. If there are + multiple supported modules created within the context, where one + is the submodule of the other, only the outermost module will be + distributed. If a supported module shares weights with another + (supported or unsupported) module, or if its hyperparameters do + not support distribution (e.g., not divisible by the tensor + parallelism degree), tensor parallelism will **not** be enabled + for this module even if this API is used. + + **Example:** + + .. code:: python + + with smp.tensor_parallelism(): + self.m0 = nn.Linear(20, 20) # will be distributed + with smp.tensor_parallelism(enabled=False): + self.m1 = nn.Linear(20, 20) # will not be distributed + + - ``kwargs`` - Keyword arguments that can be used to modify the configurations of + the distributed modules created inside the context. 
+ If a keyword argument provided through it matches any ``__init__`` method arguments + of a ``DistributedModule`` that substitutes a module created inside + the ``smp.tensor_parallelism`` context, this keyword will override + the value defined in the ``init_hook``. + + - (*For v1.7.0 and later*) Through the following additional keyword arguments, + the library supports `NVIDIA Megatron’s fused kernels + `_ + + - ``fused_softmax`` (bool) - Fusion of attention masking and softmax. + By default, it is set to ``True``. You can deactivate it by setting + ``fused_softmax=False`` in the ``smp.tensor_parallelism`` context manager. + - ``fused_bias_gelu`` (bool) - Fusion of bias addition and Gelu activation. + By default, it is set to ``False``. You can activate it by setting + ``fused_bias_gelu=True`` in the ``smp.tensor_parallelism`` context manager. + + + +.. function:: smp.set_tensor_parallelism(module, enabled=True, **kwargs) + :noindex: + + - Enables or disables tensor parallelism for the supported + submodules of ``module``. If enabling, the outermost supported + modules will be distributed. If disabling, tensor parallelism will + be disabled for the entire module subtree of ``module``. Unlike + the context manager, this API can be used after the model creation + (but before wrapping with :class:`smp.DistributedModel`), so direct + access to model definition code is not required. If a supported + module shares weights with another (supported or unsupported) + module, or if its hyperparameters do not support distribution + (e.g., not divisible by the tensor parallelism degree), tensor + parallelism will **not** be enabled for this module. + - Keyword arguments ``kwargs`` can be used to modify the + configurations of the distributed modules created inside the + context. If a keyword argument provided here matches any + ``__init__`` method arguments of a :class:`smp.DistributedModel` that + substitutes a module created inside the ``smp.tensor_parallelism`` + context, this keyword will override the value defined in the + ``init_hook``. + - **Example:** + + .. code:: python + + model = MyModel() + smp.set_tensor_parallelism(model.encoder, True) + smp.set_tensor_parallelism(model.encoder.embedding, True) + + # outermost supported submodules in model.encoder will be distributed, except for + # model.encoder.embedding + model = smp.DistributedModel(model) + optimizer = smp.DistributedOptimizer(optimizer) + +.. _activation-checkpointing-api: + :noindex: + +Activation Checkpointing APIs +----------------------------- + +``smdistributed.modelparallel`` provides three APIs to enable +activation checkpointing: one for checkpointing modules, +one for checkpointing sequential modules, and +one for checkpointing pretrained models. + +For a conceptual guide and examples, see +`Activation Checkpointing `_ +in the *SageMaker's Distributed Model Parallel developer guide*. + +.. class:: smdistributed.modelparallel.torch.patches.checkpoint.checkpoint(module, *args, preserve_rng_state=True) + :noindex: + + - Checkpoints the module passed. Throws error if, during manual + partitioning, all children of module are not on same rank as the + module itself, i.e. the module tree is split across multiple + partitions. During auto-partitioning, if the module is split + across multiple partitions, then this call is ignored(with a + warning). Note that this call applies to the module instance only, + not to the module class. + + - **Arguments:** + + - ``module (Instance of nn.Module)``: The module to be + checkpointed. 
Note that unlike native checkpointing in
+        PyTorch, activation checkpointing in
+        ``smdistributed.modelparallel`` is at the granularity of a
+        module. A generic function cannot be passed here.
+      - ``args``: Tuple containing inputs to the module.
+      - ``preserve_rng_state (bool, default=True)``: Set to ``False``
+        to omit stashing and restoring the RNG state during each
+        checkpoint.
+
+.. class:: smdistributed.modelparallel.torch.patches.checkpoint.checkpoint_sequential(sequential_module, input, strategy="each", preserve_rng_state=True, pack_args_as_tuple=False)
+   :noindex:
+
+   - Checkpoints the modules inside
+     `nn.Sequential `__.
+     This can be used even if different layers that are part of the
+     sequential container lie on different partitions. Each layer part
+     of the sequential module that is checkpointed must lie completely
+     within one partition. If this is not the case during manual
+     partitioning, then an error will be thrown. If this is not the
+     case during auto partitioning, a warning will be raised and this
+     module will be run without checkpointing.
+
+   - **Arguments**
+
+      - ``sequential_module (nn.Sequential)``: the sequential module to
+        be checkpointed.
+      - ``input (torch.Tensor or a tuple of torch.Tensors)``: input to
+        the module, which can be a tensor or a tuple of tensors. If a
+        tuple is passed, then pack_args_as_tuple should be set to True.
+      - ``strategy (string, default=“each”)`` : Strategy determines how
+        many layers part of the sequential module need to be grouped
+        together for one checkpointing call. This determines how much
+        memory can be reduced. It can take the following values
+
+        - ``each`` : The default is to checkpoint each module inside
+          the sequential separately.
+        - ``contiguous``: Groups consecutive layers on the same
+          partition together. For example, if a sequential consists of
+          [a, b, c, d] where a,b are on pp_rank0 and c,d are on
+          pp_rank 1, then this strategy would checkpoint a,b together
+          and then c,d together. This means effectively, inputs of a,
+          outputs of b, inputs of c, and outputs of d are in memory;
+          the remaining activations are recomputed.
+        - ``group_2, group_3, group_4, etc:`` More generally,
+          ``group_x`` where x is an integer. This strategy provides
+          more flexibility in how many layers to group together.
+          ``group_x`` groups x layers together on a best effort basis.
+          It can group x layers together if there are x layers
+          consecutively on the same partition. For example:
+          [a,b,c,d,e] where a,b are on pp_rank0 and c,d,e are on
+          pp_rank 1. If the strategy is ``group_3,`` then a,b are
+          checkpointed together on pp_rank0 and c,d,e are checkpointed
+          together on pp_rank1.
+
+      - ``preserve_rng_state (bool, default=True)``: Set to ``False``
+        to omit stashing and restoring the RNG state during each
+        checkpoint.
+      - ``pack_args_as_tuple (bool, default=False)``: To ensure that
+        backward works correctly, the autograd function has to unpack
+        any tuples received. If the checkpointed layer takes a tuple as
+        input, then this needs to be set to True.
+
+.. class:: smp.set_activation_checkpointing(module, preserve_rng_state=True, pack_args_as_tuple=False, strategy="each")
+   :noindex:
+
+   - This API is recommended when importing pretrained models from
+     libraries, such as PyTorch and Hugging Face Transformers. This is
+     particularly useful when you don’t have access to the model
+     definition code and are not able to replace a module call with
+     checkpoint.
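+
+     For example, you can call the API on an imported module object
+     before training starts. The following is a minimal sketch; the
+     ``nn.Sequential`` block shown is only a placeholder for whatever
+     pretrained model or submodule you actually want to checkpoint, and
+     the arguments are described below.
+
+     .. code:: python
+
+        import torch.nn as nn
+        import smdistributed.modelparallel.torch as smp
+
+        # placeholder for a block imported from a model library
+        block = nn.Sequential(
+            nn.Linear(1024, 4096),
+            nn.ReLU(),
+            nn.Linear(4096, 1024),
+        )
+
+        # checkpoint each module inside the sequential container separately
+        smp.set_activation_checkpointing(block, strategy="each")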
+ + - **Arguments**: + + - ``module (Instance of nn.Module or nn.Sequential)``: The module + to checkpoint. + - ``preserve_rng_state (bool, default=True)``: Set to ``False`` + to omit stashing and restoring the RNG state during each + checkpoint. + - ``pack_args_as_tuple (bool, default=False)``: *Can only be + passed when module is a sequential module.* To ensure that + backward works correctly, the autograd function has to unpack + any tuples received. If the layer checkpointed takes a tuple as + input, then this needs to be set to True. + - ``strategy: (string, default=“each”)``: *Can only be passed + when module is a sequential module.* Strategy determines how + many layers part of the sequential module need to be grouped + together for one checkpointing call. + - This determines how much memory can be reduced. It can take the + following values + + - ``each`` : The default is to checkpoint each module inside + the sequential separately. + - ``contiguous``: Groups consecutive layers on the same + partition together. For example if a sequential consists of + ``[a, b, c, d]`` where ``a, b`` are on ``pp_rank0`` and ``c, d`` are on + ``pp_rank 1``, then this strategy would checkpoint a,b together + and then ``c, d`` together. This means effectively, the inputs of + ``a``, outputs of ``b``, inputs of ``c``, and outputs of ``d`` are in + memory, and the rest of the activations are recomputed. + - ``group_2, group_3, group_4, etc:`` More generally, + ``group_x`` where x is an integer. This strategy provides + more flexibility in how many layers to group together. + ``group_x`` groups x number of layers together on a best + effort basis if there are x layers consecutively in the same + partition. **Example**: Assume a module with layers ``[a, b, + c, d, e]``. The layers a and b are on pp_rank0, and ``c``, ``d``, and + ``e`` are on ``pp_rank 1``. If the strategy is ``group_3,`` then ``a``, + ``b`` are checkpointed together on ``pp_rank0``, and ``c``, ``d``, ``e`` are + checkpointed together on ``pp_rank1``. + +.. _smdmp-tp-appendix: + :noindex: + +Appendix: Reference Implementations for Modules +----------------------------------------------- + +The following are reference implementations for transformer-related +modules. Note that this is not the actual ``smdistributed`` source code, +but the distributed implementations provided in the library are the +distributed versions of these reference implementations, and can be used +to determine whether the distributed modules perform the same operations +as the custom modules in your script. + +To keep the implementations simple, we only assume keyword arguments, +and assume the existence of a method ``parse_args(kwargs)``, which +parses the arguments to ``__init__`` methods and sets the relevant +attributes of the module, such as ``hidden_size`` and +``num_attention_heads``. + +``smp.nn.DistributedTransformer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + class Transformer(nn.Module): + def __init__(self, **kwargs): + super(Transformer, self).__init__() + self.parse_args(kwargs) + + self.layers = [] + for l in range(self.num_layers): + self.layers.append(TransformerLayer(**kwargs)) + + self.seq_layers = nn.Sequential(*self.layers) + + def forward(self, inp): + return self.seq_layers(inp) + +``smp.nn.DistributedTransformerLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
code:: python + + class TransformerLayer(nn.Module): + def __init__(self, **kwargs): + super(TransformerLayer, self).__init__() + self.parse_args(kwargs) + + self.attention = AttentionLayer(**kwargs) + self.output = TransformerOutputLayer(**kwargs) + + if self.add_cross_attention: + self.cross_attention = AttentionLayer(cross_attention=True, **kwargs) + + def forward(self, inp): + if self.add_cross_attention: + hidden_states, cross_states, attention_mask, cross_mask = inp + else: + hidden_states, attention_mask = inp + + attention_output = self.attention((hidden_states, attention_mask)) + if self.add_cross_attention: + attention_output = self.cross_attention((attention_output, + cross_states, + cross_mask)) + + output = self.output(attention_output) + + if self.add_cross_attention: + return output, cross_states, attention_mask, cross_mask + else: + return output, attention_mask + +``smp.nn.DistributedAttentionLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + class AttentionLayer(nn.Module): + def __init__(self, **kwargs): + super(AttentionLayer, self).__init__() + self.parse_args(kwargs) + self.attention_head_size = self.hidden_size // self.num_attention_heads + + self.query = nn.Linear(self.hidden_size, self.hidden_size) + self.key = nn.Linear(self.hidden_size, self.hidden_size) + self.value = nn.Linear(self.hidden_size, self.hidden_size) + self.dense = nn.Linear(self.hidden_size, self.hidden_size) + + self.dropout1 = nn.Dropout(self.attention_dropout_prob) + self.dropout2 = nn.Dropout(self.hidden_dropout_prob) + + if self.pre_layernorm: + self.pre_layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + if self.post_layernorm: + self.layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + def transpose(self, tensor, key=False): + shape = tensor.size()[:-1] + + (self.num_attention_heads, self.attention_head_size) + tensor = torch.reshape(tensor, shape) + if key: + return tensor.permute(0, 2, 3, 1) + else: + return tensor.permute(0, 2, 1, 3) + + def forward(self, inp): + if self.cross_attention: + hidden_states, cross_states, attention_mask = inp + else: + hidden_states, attention_mask = inp + + if self.pre_layernorm: + norm_states = self.pre_layernorm(hidden_states) + else: + norm_states = hidden_states + + query_layer = self.query(norm_states) + + if self.cross_attention: + key_layer = self.key(cross_states) + value_layer = self.value(cross_states) + else: + key_layer = self.key(norm_states) + value_layer = self.value(norm_states) + + query_layer = self.transpose(query_layer) + key_layer = self.transpose(key_layer, key=True) + value_layer = self.transpose(value_layer) + + attention_scores = torch.matmul(query_layer, key_layer) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + if not self.cross_attention and self.causal_mask is not None: + attention_scores = self.apply_causal_mask(attention_scores) + + attention_scores = attention_scores + attention_mask + + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = self.dropout1(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.local_attention_size,) + context_layer = torch.reshape(context_layer, new_context_layer_shape) + + self_attention = self.dense(context_layer) + self_attention = self.dropout2(self_attention) + + if self.post_layernorm: + return self.layernorm(self_attention + 
hidden_states) + else: + return self_attention + +``smp.nn.DistributedTransformerOutputLayer`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. code:: python + + class TransformerOutputLayer(nn.Module): + def __init__(self, **kwargs): + super(TransformerOutputLayer, self).__init__() + self.parse_args(kwargs) + + self.dense1 = nn.Linear(self.hidden_size, self.intermediate_size) + self.dense2 = nn.Linear(self.intermediate_size, self.hidden_size) + + self.dropout = nn.Dropout(self.attention_dropout_prob) + + if self.pre_layernorm: + self.pre_layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + if self.post_layernorm: + self.layernorm = nn.LayerNorm(self.hidden_size, + eps=self.layernorm_epsilon) + + def forward(self, inp): + if self.pre_layernorm: + norm_inp = self.pre_layernorm(inp) + else: + norm_inp = inp + + dense1_output = self.dense1(norm_inp) + if self.activation == "gelu": + act_output = F.gelu(dense1_output) + else: + act_output = F.relu(dense1_output) + + dense2_output = self.dense2(act_output) + output = self.dropout(dense2_output) + + if self.post_layernorm: + return self.layernorm(inp + output) + else: + return output diff --git a/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst new file mode 100644 index 0000000000..2c658b487c --- /dev/null +++ b/doc/api/training/smp_versions/v1.9.0/smd_model_parallel_tensorflow.rst @@ -0,0 +1,171 @@ +TensorFlow API +============== + +To use the TensorFlow-specific APIs for SageMaker distributed model parallism, +you need to add the following import statement at the top of your training script. + +.. code:: python + + import smdistributed.modelparallel.tensorflow as smp + +.. tip:: + + Refer to + `Modify a TensorFlow Training Script + `_ + to learn how to use the following APIs in your TensorFlow training script. + +.. class:: smp.DistributedModel + :noindex: + + A sub-class of the Keras \ ``Model`` class, which defines the model to + be partitioned. Model definition is done by sub-classing + ``smp.DistributedModel`` class, and implementing the ``call()`` method, + in the same way as the Keras model sub-classing API. Any operation that + is part of the \ ``smp.DistributedModel.call()`` method is subject to + partitioning, meaning that every operation placed inside executes in + exactly one of the devices (the operations outside run on all devices). + + + Similar to the regular Keras API, the forward pass is done by directly + calling the model object on the input tensors. For example: + + .. code:: python + + predictions = model(inputs)   # model is a smp.DistributedModel object + + However, ``model()`` calls can only be made inside a + ``smp.step``-decorated function. + + The outputs from a ``smp.DistributedModel`` are available in all ranks, + regardless of which rank computed the last operation. + + **Methods:** + + .. function:: save_model(save_path="/opt/ml/model") + :noindex: + + **Inputs** + - ``save_path`` (``string``): A path to save an unpartitioned model with latest training weights. + + Saves the entire, + unpartitioned model with the latest trained weights to ``save_path`` in + TensorFlow ``SavedModel`` format. Defaults to ``"/opt/ml/model"``, which + SageMaker monitors to upload the model artifacts to Amazon S3. + +.. function:: smp.partition(index) + :noindex: + + **Inputs** + + - ``index`` (``int``): The index of the partition. 
+ + A context manager which places all operations defined inside into the + partition whose ID is equal to ``index``. When + ``smp.partition`` contexts are nested, the innermost context overrides + the rest. The ``index`` argument must be smaller than the number of + partitions. + + ``smp.partition`` is used in the manual partitioning API; + if \ ``"auto_partition"`` parameter is set to ``True`` while launching + training, then ``smp.partition`` contexts are ignored. Any operation + that is not placed in any ``smp.partition`` context is placed in the + ``default_partition``, as shown in the following example: + + .. code:: python + + # auto_partition: False + # default_partition: 0 + smp.init() + [...] + x = tf.constant(1.2)                     # placed in partition 0 + with smp.partition(1): +     y = tf.add(x, tf.constant(2.3))      # placed in partition 1 +     with smp.partition(3): +         z = tf.reduce_sum(y)             # placed in partition 3 + + +.. function:: register_post_partition_hook(hook) + :noindex: + + Registers a callable ``hook`` to + be executed after the model is partitioned. This is useful in situations + where an operation needs to be executed after the model partition during + the first call to ``smp.step``, but before the actual execution of the + first forward pass. + + .. code:: python + + @smp.register_post_partition_hook + def test_eager(): + # All statements here will be executed right after partition but before the first forward pass + tf.print("Entered hook through eager context") + +.. class:: smp.CheckpointManager + :noindex: + + + A subclass of TensorFlow + `CheckpointManager `__, + which is used to manage checkpoints. The usage is similar to TensorFlow + ``CheckpointManager``. + + The following returns a ``CheckpointManager`` object. + + .. code:: python + + smp.CheckpointManager(checkpoint, +                       directory="/opt/ml/checkpoints", +                       max_to_keep=None, +                       checkpoint_name="ckpt") + + **Parameters** + + - ``checkpoint``: A `tf.train.Checkpoint + `__ instance + that represents a model checkpoint. + + - ``directory``: (``str``) The path to a directory in which to write + checkpoints. A file named "checkpoint" is also written to this + directory (in a human-readable text format) which contains the state + of the ``CheckpointManager``. Defaults to + ``"/opt/ml/checkpoints"``, which is the directory that SageMaker + monitors for uploading the checkpoints to Amazon S3. + - ``max_to_keep`` (``int``): The number of checkpoints to keep. If + ``None``, all checkpoints are kept. + - ``checkpoint_name`` (``str``): Custom name for the checkpoint file. + Defaults to ``"ckpt"``. + + + **Methods:** + + .. function:: save( ) + :noindex: + + Saves a new checkpoint in the specified directory. Internally uses ``tf.train.CheckpointManager.save()``. + + .. function:: restore( ) + :noindex: + + Restores the latest checkpoint in the specified directory. + Internally uses ``tf.train.CheckpointManager.restore()``. + + + **Examples:** + + .. code:: python + + checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model) + ckpt_manager = smp.CheckpointManager(checkpoint, max_to_keep=5)  # use /opt/ml/checkpoints + + for inputs in train_ds: +     loss = train_step(inputs) +     # [...] +     ckpt_manager.save()  # save a new checkpoint in /opt/ml/checkpoints + + .. 
code:: python + + for step, inputs in enumerate(train_ds): +     if step == 0: +         ckpt_manager.restore() +     loss = train_step(inputs) diff --git a/doc/api/training/smp_versions/v1_9_0.rst b/doc/api/training/smp_versions/v1_9_0.rst new file mode 100644 index 0000000000..e2e9acd83a --- /dev/null +++ b/doc/api/training/smp_versions/v1_9_0.rst @@ -0,0 +1,13 @@ + +Version 1.7.0, 1.8.0, 1.8.1, 1.9.0 +================================== + +To use the library, reference the Common API documentation alongside the framework specific API documentation. + +.. toctree:: + :maxdepth: 1 + + v1.9.0/smd_model_parallel_common_api + v1.9.0/smd_model_parallel_pytorch + v1.9.0/smd_model_parallel_pytorch_tensor_parallel + v1.9.0/smd_model_parallel_tensorflow From 8be360e19702b693387a1831aead1b8a1caac411 Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh <105655261+rahven14@users.noreply.github.com> Date: Sat, 16 Jul 2022 04:18:27 +0530 Subject: [PATCH 126/526] fix: enable model.register without 'inference' & 'transform' instances (#3228) * fix: enable model.register without 'inference_instances' & 'transform_instances' * remove failing integ test * fix: make instance_type optional for model registry model package and mandatory for marketplace model package * fix: black-check and flake8 errors --- src/sagemaker/session.py | 27 +++++- src/sagemaker/workflow/_utils.py | 9 +- .../workflow/test_pipeline_session.py | 95 ++++++++++++++++++- tests/unit/test_session.py | 92 ++++++++++++++++-- 4 files changed, 206 insertions(+), 17 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index f426724b6c..145bf41cbe 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4499,9 +4499,32 @@ def get_create_model_package_request( "Containers": containers, "SupportedContentTypes": content_types, "SupportedResponseMIMETypes": response_types, - "SupportedRealtimeInferenceInstanceTypes": inference_instances, - "SupportedTransformInstanceTypes": transform_instances, } + if model_package_group_name is not None: + if inference_instances is not None: + inference_specification.update( + { + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + } + ) + if transform_instances is not None: + inference_specification.update( + { + "SupportedTransformInstanceTypes": transform_instances, + } + ) + else: + if not all([inference_instances, transform_instances]): + raise ValueError( + "inference_instances and transform_instances " + "must be provided if model_package_group_name is not present." 
+ ) + inference_specification.update( + { + "SupportedRealtimeInferenceInstanceTypes": inference_instances, + "SupportedTransformInstanceTypes": transform_instances, + } + ) request_dict["InferenceSpecification"] = inference_specification request_dict["CertifyForMarketplace"] = marketplace_cert request_dict["ModelApprovalStatus"] = approval_status diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 7a0a399299..f8a99996a5 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -341,16 +341,11 @@ def __init__( super(_RegisterModelStep, self).__init__( name, StepTypeEnum.REGISTER_MODEL, display_name, description, depends_on, retry_policies ) - deprecated_args_missing = ( - content_types is None - or response_types is None - or inference_instances is None - or transform_instances is None - ) + deprecated_args_missing = content_types is None or response_types is None if not (step_args is None) ^ deprecated_args_missing: raise ValueError( "step_args and the set of (content_types, response_types, " - "inference_instances, transform_instances) are mutually exclusive. " + ") are mutually exclusive. " "Either of them should be provided." ) diff --git a/tests/unit/sagemaker/workflow/test_pipeline_session.py b/tests/unit/sagemaker/workflow/test_pipeline_session.py index 13af00cf6a..eca3892390 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_session.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_session.py @@ -224,8 +224,6 @@ def test_pipeline_session_context_for_model_step_without_instance_types( ], "SupportedContentTypes": ["text/csv"], "SupportedResponseMIMETypes": ["text/csv"], - "SupportedRealtimeInferenceInstanceTypes": None, - "SupportedTransformInstanceTypes": None, }, "CertifyForMarketplace": False, "ModelApprovalStatus": "PendingManualApproval", @@ -234,3 +232,96 @@ def test_pipeline_session_context_for_model_step_without_instance_types( } assert register_step_args.create_model_package_request == expected_output + + +def test_pipeline_session_context_for_model_step_with_one_instance_types( + pipeline_session_mock, +): + model = Model( + name="MyModel", + image_uri="fakeimage", + model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"), + sagemaker_session=pipeline_session_mock, + entry_point=f"{DATA_DIR}/dummy_script.py", + source_dir=f"{DATA_DIR}", + role=_ROLE, + ) + register_step_args = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + model_package_group_name="MyModelPackageGroup", + task="IMAGE_CLASSIFICATION", + sample_payload_url="s3://test-bucket/model", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', + ) + + expected_output = { + "ModelPackageGroupName": "MyModelPackageGroup", + "InferenceSpecification": { + "Containers": [ + { + "Image": "fakeimage", + "Environment": { + "SAGEMAKER_PROGRAM": "dummy_script.py", + "SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code", + "SAGEMAKER_CONTAINER_LOG_LEVEL": "20", + "SAGEMAKER_REGION": "us-west-2", + }, + "ModelDataUrl": ParameterString( + name="ModelData", + default_value="s3://my-bucket/file", + ), + "Framework": "TENSORFLOW", + "FrameworkVersion": "2.9", + "NearestModelName": "resnet50", + "ModelInput": { + "DataInputConfig": '{"input_1":[1,224,224,3]}', + }, + } + ], + "SupportedContentTypes": ["text/csv"], + "SupportedResponseMIMETypes": ["text/csv"], + 
"SupportedRealtimeInferenceInstanceTypes": ["ml.t2.medium", "ml.m5.xlarge"], + }, + "CertifyForMarketplace": False, + "ModelApprovalStatus": "PendingManualApproval", + "SamplePayloadUrl": "s3://test-bucket/model", + "Task": "IMAGE_CLASSIFICATION", + } + + assert register_step_args.create_model_package_request == expected_output + + +def test_pipeline_session_context_for_model_step_without_model_package_group_name( + pipeline_session_mock, +): + model = Model( + name="MyModel", + image_uri="fakeimage", + model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"), + sagemaker_session=pipeline_session_mock, + entry_point=f"{DATA_DIR}/dummy_script.py", + source_dir=f"{DATA_DIR}", + role=_ROLE, + ) + with pytest.raises(ValueError) as error: + model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + model_package_name="MyModelPackageName", + task="IMAGE_CLASSIFICATION", + sample_payload_url="s3://test-bucket/model", + framework="TENSORFLOW", + framework_version="2.9", + nearest_model_name="resnet50", + data_input_configuration='{"input_1":[1,224,224,3]}', + ) + assert ( + "inference_inferences and transform_instances " + "must be provided if model_package_group_name is not present." == str(error) + ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1fd58ea531..78df274b71 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2355,11 +2355,29 @@ def test_create_model_package_from_containers_incomplete_args(sagemaker_session) containers=containers, ) assert ( - "content_types, response_types, inference_inferences and transform_instances " + "content_types and response_types " "must be provided if containers is present." == str(error) ) +def test_create_model_package_from_containers_without_model_package_group_name(sagemaker_session): + model_package_name = "sagemaker-model-package" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + with pytest.raises(ValueError) as error: + sagemaker_session.create_model_package_from_containers( + model_package_name=model_package_name, + containers=containers, + content_types=content_types, + response_types=response_types, + ) + assert ( + "inference_inferences and transform_instances " + "must be provided if model_package_group_name is not present." 
== str(error) + ) + + def test_create_model_package_from_containers_all_args(sagemaker_session): model_package_name = "sagemaker-model-package" containers = ["dummy-container"] @@ -2437,7 +2455,7 @@ def test_create_model_package_from_containers_all_args(sagemaker_session): def test_create_model_package_from_containers_without_instance_types(sagemaker_session): - model_package_name = "sagemaker-model-package" + model_package_group_name = "sagemaker-model-package-group-name-1.0" containers = ["dummy-container"] content_types = ["application/json"] response_types = ["application/json"] @@ -2470,7 +2488,7 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s containers=containers, content_types=content_types, response_types=response_types, - model_package_name=model_package_name, + model_package_group_name=model_package_group_name, model_metrics=model_metrics, metadata_properties=metadata_properties, marketplace_cert=marketplace_cert, @@ -2480,13 +2498,75 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s customer_metadata_properties=customer_metadata_properties, ) expected_args = { - "ModelPackageName": model_package_name, + "ModelPackageGroupName": model_package_group_name, "InferenceSpecification": { "Containers": containers, "SupportedContentTypes": content_types, "SupportedResponseMIMETypes": response_types, - "SupportedRealtimeInferenceInstanceTypes": None, - "SupportedTransformInstanceTypes": None, + }, + "ModelPackageDescription": description, + "ModelMetrics": model_metrics, + "MetadataProperties": metadata_properties, + "CertifyForMarketplace": marketplace_cert, + "ModelApprovalStatus": approval_status, + "DriftCheckBaselines": drift_check_baselines, + "CustomerMetadataProperties": customer_metadata_properties, + } + sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) + + +def test_create_model_package_from_containers_with_one_instance_types(sagemaker_session): + model_package_group_name = "sagemaker-model-package-group-name-1.0" + containers = ["dummy-container"] + content_types = ["application/json"] + response_types = ["application/json"] + transform_instances = ["ml.m5.xlarge"] + model_metrics = { + "Bias": { + "ContentType": "content-type", + "S3Uri": "s3://...", + } + } + drift_check_baselines = { + "Bias": { + "ConfigFile": { + "ContentType": "content-type", + "S3Uri": "s3://...", + } + } + } + + metadata_properties = { + "CommitId": "test-commit-id", + "Repository": "test-repository", + "GeneratedBy": "sagemaker-python-sdk", + "ProjectId": "unit-test", + } + marketplace_cert = (True,) + approval_status = ("Approved",) + description = "description" + customer_metadata_properties = {"key1": "value1"} + sagemaker_session.create_model_package_from_containers( + containers=containers, + content_types=content_types, + response_types=response_types, + transform_instances=transform_instances, + model_package_group_name=model_package_group_name, + model_metrics=model_metrics, + metadata_properties=metadata_properties, + marketplace_cert=marketplace_cert, + approval_status=approval_status, + description=description, + drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, + ) + expected_args = { + "ModelPackageGroupName": model_package_group_name, + "InferenceSpecification": { + "Containers": containers, + "SupportedContentTypes": content_types, + "SupportedResponseMIMETypes": response_types, + "SupportedTransformInstanceTypes": 
transform_instances, }, "ModelPackageDescription": description, "ModelMetrics": model_metrics, From 3a67537dd4e673b9e0ae0bb773e9b7a980a0a773 Mon Sep 17 00:00:00 2001 From: ci Date: Mon, 18 Jul 2022 19:15:17 +0000 Subject: [PATCH 127/526] prepare release v2.100.0 --- CHANGELOG.md | 23 +++++++++++++++++++++++ VERSION | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2987f13ac0..f754b6fccb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,28 @@ # Changelog +## v2.100.0 (2022-07-18) + +### Features + + * upgrade to support python 3.10 + * Add target_model to support multi-model endpoints + * Added support for feature group schema change and feature parameters + +### Bug Fixes and Other Changes + + * enable model.register without 'inference' & 'transform' instances + * rename RegisterModel inner steps to prevent duplicate step names + * remove primitive_or_expr() from conditions + * support pipeline variables for spark processors run arguments + * make 'ModelInput' field optional for inference recommendation + * Fix processing image uri param + * fix: neo inferentia as compilation target not using framework ver + +### Documentation Changes + + * SageMaker model parallel library v1.10.0 documentation + * add detail & links to clarify docstrings + ## v2.99.0 (2022-07-08) ### Features diff --git a/VERSION b/VERSION index 60c21330a6..e24025aa94 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.99.1.dev0 +2.100.0 From ff8a613945486e76cb63350b26054ad8aca4c1bd Mon Sep 17 00:00:00 2001 From: ci Date: Mon, 18 Jul 2022 19:15:18 +0000 Subject: [PATCH 128/526] update development version to v2.100.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index e24025aa94..9d5c4490e9 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.100.0 +2.100.1.dev0 From ae4afc261d8c396c6457eb8fae4e64dd7dfacb9c Mon Sep 17 00:00:00 2001 From: Miyoung Date: Mon, 25 Jul 2022 18:46:17 -0700 Subject: [PATCH 129/526] documentation: smdmp v1.10 release note (#3244) --- .../smd_model_parallel_change_log.rst | 56 +++++++++++++++++-- 1 file changed, 50 insertions(+), 6 deletions(-) diff --git a/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst b/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst index d65efd5022..12ed10049a 100644 --- a/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst +++ b/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst @@ -5,14 +5,31 @@ Release Notes New features, bug fixes, and improvements are regularly made to the SageMaker distributed model parallel library. -SageMaker Distributed Model Parallel 1.9.0 Release Notes -======================================================== +SageMaker Distributed Model Parallel 1.10.0 Release Notes +========================================================= -*Date: May. 3. 2022* +*Date: July. 19. 2022* -**Currency Updates** +**New Features** -* Added support for PyTorch 1.11.0 +The following new features are added for PyTorch. + +* Added support for FP16 training by implementing smdistributed.modelparallel + modification of Apex FP16_Module and FP16_Optimizer. To learn more, see + `FP16 Training with Model Parallelism + `_. +* New checkpoint APIs for CPU memory usage optimization. To learn more, see + `Checkpointing Distributed Models and Optimizer States + `_. 
+ +**Improvements** + +* The SageMaker distributed model parallel library manages and optimizes CPU + memory by garbage-collecting non-local parameters in general and during checkpointing. +* Changes in the `GPT-2 translate functions + `_ + (``smdistributed.modelparallel.torch.nn.huggingface.gpt2``) + to save memory by not maintaining two copies of weights at the same time. **Migration to AWS Deep Learning Containers** @@ -28,7 +45,7 @@ Binary file of this version of the library for custom container users: .. code:: - https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.11.0/build-artifacts/2022-04-20-17-05/smdistributed_modelparallel-1.9.0-cp38-cp38-linux_x86_64.whl + https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.11.0/build-artifacts/2022-07-11-19-23/smdistributed_modelparallel-1.10.0-cp38-cp38-linux_x86_64.whl @@ -37,6 +54,33 @@ Binary file of this version of the library for custom container users: Release History =============== +SageMaker Distributed Model Parallel 1.9.0 Release Notes +-------------------------------------------------------- + +*Date: May. 3. 2022* + +**Currency Updates** + +* Added support for PyTorch 1.11.0 + +**Migration to AWS Deep Learning Containers** + +This version passed benchmark testing and is migrated to the following AWS Deep Learning Containers (DLC): + +- PyTorch 1.11.0 DLC + + .. code:: + + 763104351884.dkr.ecr..amazonaws.com/pytorch-training:1.11.0-gpu-py38-cu113-ubuntu20.04-sagemaker + +Binary file of this version of the library for custom container users: + + .. code:: + + https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.11.0/build-artifacts/2022-04-20-17-05/smdistributed_modelparallel-1.9.0-cp38-cp38-linux_x86_64.whl + + + SageMaker Distributed Model Parallel 1.8.1 Release Notes -------------------------------------------------------- From 1393930679a1599b854a19f6dc2b06ffc709bc44 Mon Sep 17 00:00:00 2001 From: Miyoung Date: Mon, 25 Jul 2022 18:49:08 -0700 Subject: [PATCH 130/526] documentation: heterogeneous cluster api doc fix (#3248) --- src/sagemaker/inputs.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 855488d33a..3481c138bd 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -67,10 +67,16 @@ def __init__( AugmentedManifestFile formats are described at `S3DataSource `_ in the `Amazon SageMaker API reference`. - instance_groups (list[str]): Optional. A list of ``instance_group_name``\ s - of a heterogeneous cluster that's configured using the + instance_groups (list[str]): Optional. A list of instance group names in string format + that you specified while configuring a heterogeneous cluster using the :class:`sagemaker.instance_group.InstanceGroup`. S3 data will be sent to all instance groups in the specified list. + For instructions on how to use InstanceGroup objects + to configure a heterogeneous cluster + through the SageMaker generic and framework estimator classes, see + `Train Using a Heterogeneous Cluster + `_ + in the *Amazon SageMaker developer guide*. (default: None) input_mode (str): Optional override for this channel's input mode (default: None). 
By default, channels will use the input mode defined on From 8a4c5d111767bc726ea3e10ba1b71fe540485076 Mon Sep 17 00:00:00 2001 From: Radhika Bhat <78102284+RadhikaB-97@users.noreply.github.com> Date: Tue, 26 Jul 2022 16:36:34 -0700 Subject: [PATCH 131/526] feature: Add CGK region to frameworks by DLC (#3232) --- src/sagemaker/image_uri_config/autogluon.json | 8 +++ .../huggingface-training-compiler.json | 2 + .../image_uri_config/huggingface.json | 31 +++++++++ src/sagemaker/image_uri_config/mxnet.json | 13 ++++ src/sagemaker/image_uri_config/pytorch.json | 28 ++++++++ .../image_uri_config/tensorflow.json | 65 +++++++++++++++++++ 6 files changed, 147 insertions(+) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 40d7e1fdfb..505f1d1f7e 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -16,6 +16,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-north-1": "763104351884", @@ -45,6 +46,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-north-1": "763104351884", @@ -74,6 +76,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-north-1": "763104351884", @@ -103,6 +106,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-north-1": "763104351884", @@ -140,6 +144,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -172,6 +177,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -204,6 +210,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -236,6 +243,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", diff --git a/src/sagemaker/image_uri_config/huggingface-training-compiler.json b/src/sagemaker/image_uri_config/huggingface-training-compiler.json index abc5e2391e..1b4c6e3e71 100644 --- a/src/sagemaker/image_uri_config/huggingface-training-compiler.json +++ b/src/sagemaker/image_uri_config/huggingface-training-compiler.json @@ -50,6 +50,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-north-1": "763104351884", @@ -78,6 +79,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": 
"763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "eu-central-1": "763104351884", "eu-north-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 48409540c6..317c17030a 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -27,6 +27,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -58,6 +59,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -95,6 +97,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -126,6 +129,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -165,6 +169,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -197,6 +202,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -229,6 +235,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -261,6 +268,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -301,6 +309,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -333,6 +342,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -365,6 +375,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -397,6 +408,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -435,6 +447,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + 
"ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -467,6 +480,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -505,6 +519,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -537,6 +552,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -575,6 +591,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -607,6 +624,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -658,6 +676,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -690,6 +709,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -722,6 +742,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -762,6 +783,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -794,6 +816,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -826,6 +849,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -858,6 +882,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -896,6 +921,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -928,6 +954,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": 
"763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -966,6 +993,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -998,6 +1026,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1036,6 +1065,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1068,6 +1098,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 205d031b66..12bc40fccf 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -234,6 +234,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -265,6 +266,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -296,6 +298,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -327,6 +330,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -358,6 +362,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -616,6 +621,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -647,6 +653,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -678,6 +685,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -709,6 +717,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": 
"727897471807", @@ -740,6 +749,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -844,6 +854,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -875,6 +886,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -906,6 +918,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index b239d9e00c..1bea5ffc30 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -19,6 +19,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -38,6 +39,7 @@ "registries": { "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", + "ap-southeast-3": "907027046896", "eu-west-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -180,6 +182,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -214,6 +217,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -247,6 +251,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -280,6 +285,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -314,6 +320,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -348,6 +355,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -382,6 +390,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -416,6 +425,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": 
"763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -449,6 +459,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -482,6 +493,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -515,6 +527,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -548,6 +561,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -581,6 +595,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -736,6 +751,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -770,6 +786,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -804,6 +821,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -837,6 +855,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -871,6 +890,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -905,6 +925,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -939,6 +960,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -973,6 +995,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1006,6 +1029,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + 
"ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1039,6 +1063,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1072,6 +1097,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1105,6 +1131,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1138,6 +1165,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 564c71ae22..6a2318ddbe 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -143,6 +143,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -173,6 +174,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -203,6 +205,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -233,6 +236,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -384,6 +388,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -414,6 +419,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -444,6 +450,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -474,6 +481,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -504,6 +512,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", 
"ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -534,6 +543,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -564,6 +574,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -786,6 +797,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -816,6 +828,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -846,6 +859,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -876,6 +890,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -906,6 +921,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -936,6 +952,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -966,6 +983,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -996,6 +1014,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1026,6 +1045,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1056,6 +1076,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1086,6 +1107,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1116,6 +1138,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": 
"727897471807", "cn-northwest-1": "727897471807", @@ -1146,6 +1169,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1176,6 +1200,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1206,6 +1231,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1236,6 +1262,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1266,6 +1293,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1296,6 +1324,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1326,6 +1355,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1356,6 +1386,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1386,6 +1417,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1416,6 +1448,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1606,6 +1639,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1641,6 +1675,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1675,6 +1710,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1710,6 +1746,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", 
"cn-northwest-1": "727897471807", @@ -1745,6 +1782,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1780,6 +1818,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -1815,6 +1854,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2041,6 +2081,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2075,6 +2116,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2109,6 +2151,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2142,6 +2185,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2175,6 +2219,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2209,6 +2254,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2243,6 +2289,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2276,6 +2323,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2309,6 +2357,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2342,6 +2391,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2375,6 +2425,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": 
"727897471807", @@ -2408,6 +2459,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2441,6 +2493,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2474,6 +2527,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2507,6 +2561,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2540,6 +2595,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2573,6 +2629,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2606,6 +2663,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2639,6 +2697,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2672,6 +2731,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2705,6 +2765,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2738,6 +2799,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2771,6 +2833,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2804,6 +2867,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", @@ -2837,6 +2901,7 @@ "ap-south-1": "763104351884", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", From 
fb5344c824c6ee5fef43919dab2e5080afeff1db Mon Sep 17 00:00:00 2001 From: Samsara Counts Date: Tue, 26 Jul 2022 17:53:39 -0700 Subject: [PATCH 132/526] feature: support clarify bias detection when facets not included (#3221) --- src/sagemaker/clarify.py | 163 ++++++++++--- tests/integ/test_clarify.py | 467 ++++++++++++++++++++++++++++++++++++ tests/unit/test_clarify.py | 106 ++++++++ 3 files changed, 700 insertions(+), 36 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 24fe1f0a48..873a87ca57 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -45,6 +45,12 @@ def __init__( dataset_type="text/csv", s3_compression_type="None", joinsource=None, + facet_dataset_uri=None, + facet_headers=None, + predicted_label_dataset_uri=None, + predicted_label_headers=None, + predicted_label=None, + excluded_columns=None, ): """Initializes a configuration of both input and output datasets. @@ -54,22 +60,57 @@ def __init__( s3_analysis_config_output_path (str): S3 prefix to store the analysis config output. If this field is None, then the ``s3_output_path`` will be used to store the ``analysis_config`` output. - label (str): Target attribute of the model **required** for bias metrics (both pre- - and post-training). Optional when running SHAP explainability. - Specified as column name or index for CSV dataset, or as JSONPath for JSONLines. - headers (list[str]): A list of column names in the input dataset. + label (str): Target attribute of the model required by bias metrics. + Specified as column name or index for CSV dataset or as JSONPath for JSONLines. + *Required parameter* except for when the input dataset does not contain the label. + Cannot be used at the same time as ``predicted_label``. features (str): JSONPath for locating the feature columns for bias metrics if the dataset format is JSONLines. dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV, ``"application/jsonlines"`` for JSONLines, and ``"application/x-parquet"`` for Parquet. s3_compression_type (str): Valid options are "None" or ``"Gzip"``. - joinsource (str): The name or index of the column in the dataset that acts as an - identifier column (for instance, while performing a join). This column is only - used as an identifier, and not used for any other computations. This is an - optional field in all cases except when the dataset contains more than one file, - and ``save_local_shap_values`` is set to True - in :class:`~sagemaker.clarify.SHAPConfig`. + joinsource (str or int): The name or index of the column in the dataset that + acts as an identifier column (for instance, while performing a join). + This column is only used as an identifier, and not used for any other computations. + This is an optional field in all cases except: + + * The dataset contains more than one file and `save_local_shap_values` + is set to true in :class:`~sagemaker.clarify.ShapConfig`, and/or + * When the dataset and/or facet dataset and/or predicted label dataset + are in separate files. + + facet_dataset_uri (str): Dataset S3 prefix/object URI that contains facet attribute(s), + used for bias analysis on datasets without facets. + + * If the dataset and the facet dataset are one single file each, then + the original dataset and facet dataset must have the same number of rows. + * If the dataset and facet dataset are in multiple files (either one), then + an index column, ``joinsource``, is required to join the two datasets. 
+ + Clarify will not use the ``joinsource`` column and columns present in the facet + dataset when calling model inference APIs. + facet_headers (list[str]): List of column names in the facet dataset. + predicted_label_dataset_uri (str): Dataset S3 prefix/object URI with predicted labels, + which are used directly for analysis instead of making model inference API calls. + + * If the dataset and the predicted label dataset are one single file each, then the + original dataset and predicted label dataset must have the same number of rows. + * If the dataset and predicted label dataset are in multiple files (either one), + then an index column, ``joinsource``, is required to join the two datasets. + + predicted_label_headers (list[str]): List of column names in the predicted label dataset + predicted_label (str or int): Predicted label of the target attribute of the model + required for running bias analysis. Specified as column name or index for CSV data. + Clarify uses the predicted labels directly instead of making model inference API + calls. Cannot be used at the same time as ``label``. + excluded_columns (list[int] or list[str]): A list of names or indices of the columns + which are to be excluded from making model inference API calls. + + Raises: + ValueError: when the ``dataset_type`` is invalid, predicted label dataset parameters + are used with un-supported ``dataset_type``, or facet dataset parameters + are used with un-supported ``dataset_type`` """ if dataset_type not in [ "text/csv", @@ -81,6 +122,32 @@ def __init__( f"Invalid dataset_type '{dataset_type}'." f" Please check the API documentation for the supported dataset types." ) + # parameters for analysis on datasets without facets are only supported for CSV datasets + if dataset_type != "text/csv": + if predicted_label: + raise ValueError( + f"The parameter 'predicted_label' is not supported" + f" for dataset_type '{dataset_type}'." + f" Please check the API documentation for the supported dataset types." + ) + if excluded_columns: + raise ValueError( + f"The parameter 'excluded_columns' is not supported" + f" for dataset_type '{dataset_type}'." + f" Please check the API documentation for the supported dataset types." + ) + if facet_dataset_uri or facet_headers: + raise ValueError( + f"The parameters 'facet_dataset_uri' and 'facet_headers'" + f" are not supported for dataset_type '{dataset_type}'." + f" Please check the API documentation for the supported dataset types." + ) + if predicted_label_dataset_uri or predicted_label_headers: + raise ValueError( + f"The parameters 'predicted_label_dataset_uri' and 'predicted_label_headers'" + f" are not supported for dataset_type '{dataset_type}'." + f" Please check the API documentation for the supported dataset types." 
+ ) self.s3_data_input_path = s3_data_input_path self.s3_output_path = s3_output_path self.s3_analysis_config_output_path = s3_analysis_config_output_path @@ -89,6 +156,12 @@ def __init__( self.label = label self.headers = headers self.features = features + self.facet_dataset_uri = facet_dataset_uri + self.facet_headers = facet_headers + self.predicted_label_dataset_uri = predicted_label_dataset_uri + self.predicted_label_headers = predicted_label_headers + self.predicted_label = predicted_label + self.excluded_columns = excluded_columns self.analysis_config = { "dataset_type": dataset_type, } @@ -96,6 +169,12 @@ def __init__( _set(headers, "headers", self.analysis_config) _set(label, "label", self.analysis_config) _set(joinsource, "joinsource_name_or_index", self.analysis_config) + _set(facet_dataset_uri, "facet_dataset_uri", self.analysis_config) + _set(facet_headers, "facet_headers", self.analysis_config) + _set(predicted_label_dataset_uri, "predicted_label_dataset_uri", self.analysis_config) + _set(predicted_label_headers, "predicted_label_headers", self.analysis_config) + _set(predicted_label, "predicted_label", self.analysis_config) + _set(excluded_columns, "excluded_columns", self.analysis_config) def get_config(self): """Returns part of an analysis config dictionary.""" @@ -205,21 +284,23 @@ def __init__( r"""Initializes a configuration of a model and the endpoint to be created for it. Args: - model_name (str): Model name (as created by 'CreateModel'). + model_name (str): Model name (as created by + `CreateModel `_. instance_count (int): The number of instances of a new endpoint for model inference. - instance_type (str): The type of EC2 instance to use for model inference, - for example, ``"ml.c5.xlarge"``. + instance_type (str): The type of + `EC2 instance `_ + to use for model inference; for example, ``"ml.c5.xlarge"``. accept_type (str): The model output format to be used for getting inferences with the - shadow endpoint. Valid values are "text/csv" for CSV and "application/jsonlines". - Default is the same as content_type. + shadow endpoint. Valid values are ``"text/csv"`` for CSV and + ``"application/jsonlines"``. Default is the same as ``content_type``. content_type (str): The model input format to be used for getting inferences with the - shadow endpoint. Valid values are "text/csv" for CSV and "application/jsonlines". - Default is the same as dataset format. + shadow endpoint. Valid values are ``"text/csv"`` for CSV and + ``"application/jsonlines"``. Default is the same as ``dataset_format``. content_template (str): A template string to be used to construct the model input from dataset instances. It is only used when ``model_content_type`` is ``"application/jsonlines"``. The template should have one and only one placeholder, - "features", which will be replaced by a features list to form the model inference - input. + ``"features"``, which will be replaced by a features list to form the model + inference input. custom_attributes (str): Provides additional information about a request for an inference submitted to a model hosted at an Amazon SageMaker endpoint. The information is an opaque value that is forwarded verbatim. You could use this @@ -509,16 +590,20 @@ def __init__( for these units. language (str): Specifies the language of the text features. 
Accepted values are one of the following: - "chinese", "danish", "dutch", "english", "french", "german", "greek", "italian", - "japanese", "lithuanian", "multi-language", "norwegian bokmål", "polish", - "portuguese", "romanian", "russian", "spanish", "afrikaans", "albanian", "arabic", - "armenian", "basque", "bengali", "bulgarian", "catalan", "croatian", "czech", - "estonian", "finnish", "gujarati", "hebrew", "hindi", "hungarian", "icelandic", - "indonesian", "irish", "kannada", "kyrgyz", "latvian", "ligurian", "luxembourgish", - "macedonian", "malayalam", "marathi", "nepali", "persian", "sanskrit", "serbian", - "setswana", "sinhala", "slovak", "slovenian", "swedish", "tagalog", "tamil", - "tatar", "telugu", "thai", "turkish", "ukrainian", "urdu", "vietnamese", "yoruba". - Use "multi-language" for a mix of multiple languages. + ``"chinese"``, ``"danish"``, ``"dutch"``, ``"english"``, ``"french"``, ``"german"``, + ``"greek"``, ``"italian"``, ``"japanese"``, ``"lithuanian"``, ``"multi-language"``, + ``"norwegian bokmål"``, ``"polish"``, ``"portuguese"``, ``"romanian"``, + ``"russian"``, ``"spanish"``, ``"afrikaans"``, ``"albanian"``, ``"arabic"``, + ``"armenian"``, ``"basque"``, ``"bengali"``, ``"bulgarian"``, ``"catalan"``, + ``"croatian"``, ``"czech"``, ``"estonian"``, ``"finnish"``, ``"gujarati"``, + ``"hebrew"``, ``"hindi"``, ``"hungarian"``, ``"icelandic"``, ``"indonesian"``, + ``"irish"``, ``"kannada"``, ``"kyrgyz"``, ``"latvian"``, ``"ligurian"``, + ``"luxembourgish"``, ``"macedonian"``, ``"malayalam"``, ``"marathi"``, ``"nepali"``, + ``"persian"``, ``"sanskrit"``, ``"serbian"``, ``"setswana"``, ``"sinhala"``, + ``"slovak"``, ``"slovenian"``, ``"swedish"``, ``"tagalog"``, ``"tamil"``, + ``"tatar"``, ``"telugu"``, ``"thai"``, ``"turkish"``, ``"ukrainian"``, ``"urdu"``, + ``"vietnamese"``, ``"yoruba"``. + Use ``"multi-language"`` for a mix of multiple languages. Raises: ValueError: when ``granularity`` is not in list of supported values @@ -742,12 +827,15 @@ def __init__( data stored in Amazon S3. instance_count (int): The number of instances to run a processing job with. - instance_type (str): The type of EC2 instance to use for - processing, for example, ``'ml.c4.xlarge'``. - volume_size_in_gb (int): Size in GB of the EBS volume - to use for storing data during processing (default: 30). - volume_kms_key (str): A KMS key for the processing - volume (default: None). + instance_type (str): The type of + `EC2 instance `_ + to use for model inference; for example, ``"ml.c5.xlarge"``. + volume_size_in_gb (int): Size in GB of the + `EBS volume `_. + to use for storing data during processing (default: 30 GB). + volume_kms_key (str): A + `KMS key `_ + for the processing volume (default: None). output_kms_key (str): The KMS key ID for processing job outputs (default: None). max_runtime_in_seconds (int): Timeout in seconds (default: None). After this amount of time, Amazon SageMaker terminates the job, @@ -769,7 +857,7 @@ def __init__( inter-container traffic, security group IDs, and subnets. job_name_prefix (str): Processing job name prefix. version (str): Clarify version to use. - """ + """ # noqa E501 # pylint: disable=c0301 container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version) self.job_name_prefix = job_name_prefix super(SageMakerClarifyProcessor, self).__init__( @@ -1163,6 +1251,7 @@ def run_explainability( Currently, only SHAP and Partial Dependence Plots (PDP) are supported as explainability methods. 
+ You can request both methods or one at a time with the ``explainability_config`` parameter. When SHAP is requested in the ``explainability_config``, the SHAP algorithm calculates the feature importance for each input example @@ -1188,6 +1277,8 @@ def run_explainability( Config of the specific explainability method or a list of :class:`~sagemaker.clarify.ExplainabilityConfig` objects. Currently, SHAP and PDP are the two methods supported. + You can request multiple methods at once by passing in a list of + `~sagemaker.clarify.ExplainabilityConfig`. model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`): Index or JSONPath to locate the predicted scores in the model output. This is not required if the model output is a single score. Alternatively, it can be an instance diff --git a/tests/integ/test_clarify.py b/tests/integ/test_clarify.py index 111c5618af..a107c00859 100644 --- a/tests/integ/test_clarify.py +++ b/tests/integ/test_clarify.py @@ -49,6 +49,60 @@ def training_set(): return features, label +@pytest.fixture(scope="module") +def training_set_5cols(): + label = (np.random.rand(100, 1) > 0.5).astype(np.int32) + features = np.random.rand(100, 5) + return features, label + + +@pytest.fixture(scope="module") +def training_set_no_label(): + features = np.random.rand(100, 2) + return features + + +@pytest.fixture(scope="module") +def training_set_label_index(): + label = (np.random.rand(100, 1) > 0.5).astype(np.int32) + features = np.random.rand(100, 2) + index = np.arange(0, 100) # to be used as joinsource + return features, label, index + + +@pytest.fixture(scope="module") +def facet_dataset_joinsource(): + features = np.random.rand(100, 2) + index = np.arange(0, 100) # to be used as joinsource + return features, index + + +@pytest.fixture(scope="module") +def facet_dataset(): + features = np.random.rand(100, 1) + return features + + +@pytest.fixture(scope="module") +def facet_dataset_joinsource_split_1(): + features = np.random.rand(50, 2) + index = np.arange(0, 50) # to be used as joinsource + return features, index + + +@pytest.fixture(scope="module") +def facet_dataset_joinsource_split_2(): + features = np.random.rand(50, 2) + index = np.arange(50, 100) # to be used as joinsource + return features, index + + +@pytest.fixture(scope="module") +def pred_label_dataset(): + pred_label = (np.random.rand(100, 1) > 0.5).astype(np.int32) + return pred_label + + @pytest.yield_fixture(scope="module") def data_path(training_set): features, label = training_set @@ -59,6 +113,90 @@ def data_path(training_set): yield filename +@pytest.yield_fixture(scope="module") +def data_path_excl_cols(training_set_5cols): + features, label = training_set_5cols + data = pd.concat([pd.DataFrame(label), pd.DataFrame(features)], axis=1, sort=False) + with tempfile.TemporaryDirectory() as tmpdirname: + filename = os.path.join(tmpdirname, "train.csv") + data.to_csv(filename, index=False, header=False) + yield filename + + +# training data with no label column and joinsource +@pytest.yield_fixture(scope="module") +def data_path_no_label_index(training_set_no_label): + data = pd.DataFrame(training_set_no_label) + with tempfile.TemporaryDirectory() as tmpdirname: + filename = os.path.join(tmpdirname, "train_no_label_index.csv") + data.to_csv(filename, index=False, header=False) + yield filename + + +# training data with label column & joinsource (index) +@pytest.yield_fixture(scope="module") +def data_path_label_index(training_set_label_index): + features, label, index = 
training_set_label_index + data = pd.concat( + [pd.DataFrame(label), pd.DataFrame(features), pd.DataFrame(index)], axis=1, sort=False + ) + with tempfile.TemporaryDirectory() as tmpdirname: + filename = os.path.join(tmpdirname, "train_label_index.csv") + data.to_csv(filename, index=False, header=False) + yield filename + + +# training data with label column & joinsource (index) +@pytest.yield_fixture(scope="module") +def data_path_label_index_6col(training_set_label_index): + features, label, index = training_set_label_index + data = pd.concat( + [pd.DataFrame(label), pd.DataFrame(features), pd.DataFrame(features), pd.DataFrame(index)], + axis=1, + sort=False, + ) + with tempfile.TemporaryDirectory() as tmpdirname: + filename = os.path.join(tmpdirname, "train_label_index_6col.csv") + data.to_csv(filename, index=False, header=False) + yield filename + + +@pytest.yield_fixture(scope="module") +def facet_data_path(facet_dataset_joinsource): + features, index = facet_dataset_joinsource + data = pd.concat([pd.DataFrame(index), pd.DataFrame(features)], axis=1, sort=False) + with tempfile.TemporaryDirectory() as tmpdirname: + filename = os.path.join(tmpdirname, "facet_with_joinsource.csv") + data.to_csv(filename, index=False, header=False) + yield filename + + +# split facet dataset across 2 files +@pytest.yield_fixture(scope="module") +def facet_data_path_multiple_files( + facet_dataset_joinsource_split_1, facet_dataset_joinsource_split_2 +): + features_1, index_1 = facet_dataset_joinsource_split_1 + data_1 = pd.concat([pd.DataFrame(index_1), pd.DataFrame(features_1)], axis=1, sort=False) + features_2, index_2 = facet_dataset_joinsource_split_2 + data_2 = pd.concat([pd.DataFrame(index_2), pd.DataFrame(features_2)], axis=1, sort=False) + with tempfile.TemporaryDirectory() as tmpdirname: + filename1 = os.path.join(tmpdirname, "facet1.csv") + data_1.to_csv(filename1, index=False, header=False) + filename2 = os.path.join(tmpdirname, "facet2.csv") + data_2.to_csv(filename2, index=False, header=False) + yield filename1, filename2 + + +@pytest.yield_fixture(scope="module") +def pred_data_path(pred_label_dataset, pred_label_headers): + data = pd.DataFrame(pred_label_dataset, columns=pred_label_headers) + with tempfile.TemporaryDirectory() as tmpdirname: + filename = os.path.join(tmpdirname, "predicted_label.csv") + data.to_csv(filename, index=False, header=pred_label_headers) + yield filename + + @pytest.fixture(scope="module") def headers(): return [ @@ -70,6 +208,71 @@ def headers(): ] +@pytest.fixture(scope="module") +def headers_excl_cols(): + return [ + "Label", + "F1", + "F2", + "F3", + "F4", + "F5", + ] + + +@pytest.fixture(scope="module") +def headers_no_label_joinsource(): + return [ + "F3", + "F4", + "Index", + ] + + +@pytest.fixture(scope="module") +def headers_label_joinsource(): + return [ + "Label", + "F3", + "F4", + "Index", + ] + + +@pytest.fixture(scope="module") +def headers_label_joinsource_6col(): + return [ + "Label", + "F3", + "F4", + "F5", + "F6", + "Index", + ] + + +@pytest.fixture(scope="module") +def facet_headers(): + return [ + "F1", + "F2", + ] + + +@pytest.fixture(scope="module") +def facet_headers_joinsource(): + return [ + "Index", + "F1", + "F2", + ] + + +@pytest.fixture(scope="module") +def pred_label_headers(): + return ["PredictedLabel"] + + @pytest.yield_fixture(scope="module") def model_name(sagemaker_session, cpu_instance_type, training_set): job_name = utils.unique_name_from_base("clarify-xgb") @@ -127,6 +330,143 @@ def data_config(sagemaker_session, data_path, 
headers): ) +# for testing posttraining bias with excluded columns +@pytest.fixture +def data_config_excluded_columns(sagemaker_session, data_path_excl_cols, headers_excl_cols): + test_run = utils.unique_name_from_base("test_run") + output_path = "s3://{}/{}/{}".format( + sagemaker_session.default_bucket(), "linear_learner_analysis_result", test_run + ) + return DataConfig( + s3_data_input_path=data_path_excl_cols, + s3_output_path=output_path, + label="Label", + headers=headers_excl_cols, + dataset_type="text/csv", + excluded_columns=["F2"], + ) + + +# dataset config for running analysis with facets not included in input dataset +# (with facets in multiple files), excluded columns, and no predicted_labels (so run inference) +@pytest.fixture +def data_config_facets_not_included_multiple_files( + sagemaker_session, + data_path_label_index_6col, + facet_data_path_multiple_files, + headers_label_joinsource_6col, + facet_headers_joinsource, +): + test_run = utils.unique_name_from_base("test_run") + output_path = "s3://{}/{}/{}".format( + sagemaker_session.default_bucket(), "linear_learner_analysis_result", test_run + ) + # upload facet datasets + facet_data_folder_s3_uri = "s3://{}/{}/{}/{}".format( + sagemaker_session.default_bucket(), + "linear_learner_analysis_resources", + test_run, + "facets_folder", + ) + facet_data1_s3_uri = facet_data_folder_s3_uri + "/facet1.csv" + facet_data2_s3_uri = facet_data_folder_s3_uri + "/facet2.csv" + facet1, facet2 = facet_data_path_multiple_files + _upload_dataset(facet1, facet_data1_s3_uri, sagemaker_session) + _upload_dataset(facet2, facet_data2_s3_uri, sagemaker_session) + + return DataConfig( + s3_data_input_path=data_path_label_index_6col, + s3_output_path=output_path, + label="Label", + headers=headers_label_joinsource_6col, + dataset_type="text/csv", + joinsource="Index", + facet_dataset_uri=facet_data_folder_s3_uri, + facet_headers=facet_headers_joinsource, + excluded_columns=["F4"], + ) + + +# for testing pretraining bias with facets not included +@pytest.fixture +def data_config_facets_not_included( + sagemaker_session, + data_path_label_index, + facet_data_path, + headers_label_joinsource, + facet_headers_joinsource, +): + test_run = utils.unique_name_from_base("test_run") + output_path = "s3://{}/{}/{}".format( + sagemaker_session.default_bucket(), "linear_learner_analysis_result", test_run + ) + # upload facet dataset + facet_data_s3_uri = "s3://{}/{}/{}/{}".format( + sagemaker_session.default_bucket(), + "linear_learner_analysis_resources", + test_run, + "facet_with_joinsource.csv", + ) + _upload_dataset(facet_data_path, facet_data_s3_uri, sagemaker_session) + return DataConfig( + s3_data_input_path=data_path_label_index, + s3_output_path=output_path, + label="Label", + headers=headers_label_joinsource, + dataset_type="text/csv", + joinsource="Index", + facet_dataset_uri=facet_data_s3_uri, + facet_headers=facet_headers_joinsource, + ) + + +# for testing posttraining bias with facets not included +# and separate predicted label dataset +# no excluded_columns (does not make calls to model inference API) +@pytest.fixture +def data_config_facets_not_included_pred_labels( + sagemaker_session, + data_path_no_label_index, + facet_data_path, + pred_data_path, + headers_no_label_joinsource, + facet_headers, + pred_label_headers, +): + test_run = utils.unique_name_from_base("test_run") + output_path = "s3://{}/{}/{}".format( + sagemaker_session.default_bucket(), "linear_learner_analysis_result", test_run + ) + # upload facet dataset for testing + 
facet_data_s3_uri = "s3://{}/{}/{}/{}".format( + sagemaker_session.default_bucket(), + "linear_learner_analysis_resources", + test_run, + "facet_with_joinsource.csv", + ) + _upload_dataset(facet_data_path, facet_data_s3_uri, sagemaker_session) + # upload predicted_labels dataset for testing + pred_label_data_s3_uri = "s3://{}/{}/{}/{}".format( + sagemaker_session.default_bucket(), + "linear_learner_analysis_resources", + test_run, + "predicted_labels_with_joinsource.csv", + ) + _upload_dataset(pred_data_path, pred_label_data_s3_uri, sagemaker_session) + return DataConfig( + s3_data_input_path=data_path_no_label_index, + s3_output_path=output_path, + headers=headers_no_label_joinsource, + dataset_type="text/csv", + joinsource="Index", + facet_dataset_uri=facet_data_s3_uri, + facet_headers=facet_headers, + predicted_label_dataset_uri=pred_label_data_s3_uri, + predicted_label_headers=pred_label_headers, + predicted_label=0, + ) + + @pytest.fixture(scope="module") def data_bias_config(): return BiasConfig( @@ -137,6 +477,15 @@ def data_bias_config(): ) +@pytest.fixture(scope="module") +def data_bias_config_excluded_columns(): + return BiasConfig( + label_values_or_threshold=[1], + facet_name="F1", + facet_values_or_threshold=[0.5], + ) + + @pytest.fixture(scope="module") def model_config(model_name): return ModelConfig( @@ -201,6 +550,34 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config, sag check_analysis_config(data_config, sagemaker_session, "pre_training_bias") +def test_pre_training_bias_facets_not_included( + clarify_processor, data_config_facets_not_included, data_bias_config, sagemaker_session +): + with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES): + clarify_processor.run_pre_training_bias( + data_config_facets_not_included, + data_bias_config, + job_name=utils.unique_name_from_base("clarify-pretraining-bias-facets-not-included"), + wait=True, + ) + analysis_result_json = s3.S3Downloader.read_file( + data_config_facets_not_included.s3_output_path + "/analysis.json", + sagemaker_session, + ) + analysis_result = json.loads(analysis_result_json) + assert ( + math.fabs( + analysis_result["pre_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0][ + "value" + ] + ) + <= 1.0 + ) + check_analysis_config( + data_config_facets_not_included, sagemaker_session, "pre_training_bias" + ) + + def test_post_training_bias( clarify_processor, data_config, @@ -234,6 +611,75 @@ def test_post_training_bias( check_analysis_config(data_config, sagemaker_session, "post_training_bias") +# run posttraining bias with no predicted labels provided, so make calls to model inference API +def test_post_training_bias_facets_not_included_excluded_columns( + clarify_processor, + data_config_facets_not_included_multiple_files, + data_bias_config, + model_config, + model_predicted_label_config, + sagemaker_session, +): + with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES): + clarify_processor.run_post_training_bias( + data_config_facets_not_included_multiple_files, + data_bias_config, + model_config, + model_predicted_label_config, + job_name=utils.unique_name_from_base("clarify-posttraining-bias-excl-cols-facets-sep"), + wait=True, + ) + analysis_result_json = s3.S3Downloader.read_file( + data_config_facets_not_included_multiple_files.s3_output_path + "/analysis.json", + sagemaker_session, + ) + analysis_result = json.loads(analysis_result_json) + assert ( + math.fabs( + analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0][ + "value" + ] + ) 
+ <= 1.0 + ) + check_analysis_config( + data_config_facets_not_included_multiple_files, sagemaker_session, "post_training_bias" + ) + + +def test_post_training_bias_excluded_columns( + clarify_processor, + data_config_excluded_columns, + data_bias_config_excluded_columns, + model_config, + model_predicted_label_config, + sagemaker_session, +): + with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES): + clarify_processor.run_post_training_bias( + data_config_excluded_columns, + data_bias_config_excluded_columns, + model_config, + model_predicted_label_config, + job_name=utils.unique_name_from_base("clarify-posttraining-bias-excl-cols"), + wait=True, + ) + analysis_result_json = s3.S3Downloader.read_file( + data_config_excluded_columns.s3_output_path + "/analysis.json", + sagemaker_session, + ) + analysis_result = json.loads(analysis_result_json) + assert ( + math.fabs( + analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0][ + "value" + ] + ) + <= 1.0 + ) + check_analysis_config(data_config_excluded_columns, sagemaker_session, "post_training_bias") + + def test_shap(clarify_processor, data_config, model_config, shap_config, sagemaker_session): with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES): clarify_processor.run_explainability( @@ -265,3 +711,24 @@ def check_analysis_config(data_config, sagemaker_session, method): ) analysis_config = json.loads(analysis_config_json) assert method in analysis_config["methods"] + + +def _upload_dataset(dataset_local_path, s3_dataset_path, sagemaker_session): + """Upload dataset (intended for facet or predicted labels dataset, not training dataset) to S3 + + Args: + dataset_local_path (str): File path to the local analysis config file. + s3_dataset_path (str): S3 prefix to store the analysis config file. + sagemaker_session (:class:`~sagemaker.session.Session`): + Session object which manages interactions with Amazon SageMaker and + any other AWS services needed. If not specified, the processor creates + one using the default AWS configuration chain. + + Returns: + The S3 uri of the uploaded dataset. 
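The excluded-columns test above drops a feature from the payload that Clarify sends to the model while still computing bias metrics over the joined dataset. A rough usage sketch of the same shape of configuration; the role ARN, bucket, model name, and instance types are placeholders, and `sagemaker_session` is assumed to be an existing `sagemaker.Session`:

```python
from sagemaker.clarify import (
    BiasConfig,
    DataConfig,
    ModelConfig,
    ModelPredictedLabelConfig,
    SageMakerClarifyProcessor,
)

processor = SageMakerClarifyProcessor(
    role="arn:aws:iam::111122223333:role/ClarifyRole",  # placeholder role
    instance_count=1,
    instance_type="ml.m5.xlarge",
    sagemaker_session=sagemaker_session,
)

data_config = DataConfig(
    s3_data_input_path="s3://my-bucket/clarify/dataset.csv",
    s3_output_path="s3://my-bucket/clarify/output",
    label="Label",
    headers=["Index", "Label", "F1", "F2", "F3", "F4"],
    dataset_type="text/csv",
    joinsource="Index",
    excluded_columns="F4",  # not sent to the model during the bias analysis
)

bias_config = BiasConfig(
    label_values_or_threshold=[1],
    facet_name="F1",
    facet_values_or_threshold=[0.5],
)

model_config = ModelConfig(
    model_name="my-registered-model",  # placeholder model
    instance_type="ml.m5.xlarge",
    instance_count=1,
    accept_type="text/csv",
)
predicted_label_config = ModelPredictedLabelConfig(probability_threshold=0.8)

processor.run_post_training_bias(
    data_config,
    bias_config,
    model_config,
    predicted_label_config,
    job_name="clarify-posttraining-bias-excluded-columns",
    wait=True,
)
```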
+ """ + return s3.S3Uploader.upload( + local_path=dataset_local_path, + desired_s3_uri=s3_dataset_path, + sagemaker_session=sagemaker_session, + ) diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 1e3ae47f63..fa437573f0 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -42,6 +42,7 @@ def test_uri(): def test_data_config(): + # facets in input dataset s3_data_input_path = "s3://path/to/input.csv" s3_output_path = "s3://path/to/output" label_name = "Label" @@ -66,20 +67,125 @@ def test_data_config(): "headers": headers, "label": "Label", } + assert expected_config == data_config.get_config() assert s3_data_input_path == data_config.s3_data_input_path assert s3_output_path == data_config.s3_output_path assert "None" == data_config.s3_compression_type assert "FullyReplicated" == data_config.s3_data_distribution_type + # facets NOT in input dataset + joinsource = 5 + facet_dataset_uri = "s3://path/to/facet.csv" + facet_headers = ["Age"] + predicted_label_dataset_uri = "s3://path/to/pred.csv" + predicted_label_headers = ["Label", "F1", "F2", "F3", "F4", "Age"] + predicted_label = "predicted_label" + excluded_columns = "F4" + + data_config_no_facet = DataConfig( + s3_data_input_path=s3_data_input_path, + s3_output_path=s3_output_path, + label=label_name, + headers=headers, + dataset_type=dataset_type, + joinsource=joinsource, + facet_dataset_uri=facet_dataset_uri, + facet_headers=facet_headers, + predicted_label_dataset_uri=predicted_label_dataset_uri, + predicted_label_headers=predicted_label_headers, + predicted_label=predicted_label, + excluded_columns=excluded_columns, + ) + + expected_config_no_facet = { + "dataset_type": "text/csv", + "headers": headers, + "label": label_name, + "joinsource_name_or_index": joinsource, + "facet_dataset_uri": facet_dataset_uri, + "facet_headers": facet_headers, + "predicted_label_dataset_uri": predicted_label_dataset_uri, + "predicted_label_headers": predicted_label_headers, + "predicted_label": predicted_label, + "excluded_columns": excluded_columns, + } + + assert expected_config_no_facet == data_config_no_facet.get_config() + assert joinsource == data_config_no_facet.analysis_config["joinsource_name_or_index"] + assert facet_dataset_uri == data_config_no_facet.facet_dataset_uri + assert facet_headers == data_config_no_facet.facet_headers + assert predicted_label_dataset_uri == data_config_no_facet.predicted_label_dataset_uri + assert predicted_label_headers == data_config_no_facet.predicted_label_headers + assert predicted_label == data_config_no_facet.predicted_label + + excluded_columns = "F4" + data_config_excluded_cols = DataConfig( + s3_data_input_path=s3_data_input_path, + s3_output_path=s3_output_path, + label=label_name, + headers=headers, + dataset_type=dataset_type, + joinsource=joinsource, + excluded_columns=excluded_columns, + ) + + expected_config_excluded_cols = { + "dataset_type": "text/csv", + "headers": headers, + "label": label_name, + "joinsource_name_or_index": joinsource, + "excluded_columns": excluded_columns, + } + + assert expected_config_excluded_cols == data_config_excluded_cols.get_config() + assert joinsource == data_config_excluded_cols.analysis_config["joinsource_name_or_index"] + assert excluded_columns == data_config_excluded_cols.excluded_columns + def test_invalid_data_config(): + # facets included in input dataset with pytest.raises(ValueError, match=r"^Invalid dataset_type"): DataConfig( s3_data_input_path="s3://bucket/inputpath", s3_output_path="s3://bucket/outputpath", 
dataset_type="whatnot_type", ) + # facets NOT included in input dataset + error_msg = r"^The parameter 'predicted_label' is not supported for dataset_type" + with pytest.raises(ValueError, match=error_msg): + DataConfig( + s3_data_input_path="s3://bucket/inputpath", + s3_output_path="s3://bucket/outputpath", + dataset_type="application/x-parquet", + predicted_label="label", + ) + error_msg = r"^The parameter 'excluded_columns' is not supported for dataset_type" + with pytest.raises(ValueError, match=error_msg): + DataConfig( + s3_data_input_path="s3://bucket/inputpath", + s3_output_path="s3://bucket/outputpath", + dataset_type="application/x-image", + excluded_columns="excluded", + ) + error_msg = r"^The parameters 'facet_dataset_uri' and 'facet_headers' are not supported for dataset_type" # noqa E501 # pylint: disable=c0301 + with pytest.raises(ValueError, match=error_msg): + DataConfig( + s3_data_input_path="s3://bucket/inputpath", + s3_output_path="s3://bucket/outputpath", + dataset_type="application/x-image", + facet_dataset_uri="facet_dataset/URI", + facet_headers="facet", + ) + error_msg = r"^The parameters 'predicted_label_dataset_uri' and 'predicted_label_headers' are not supported for dataset_type" # noqa E501 # pylint: disable=c0301 + with pytest.raises(ValueError, match=error_msg): + DataConfig( + s3_data_input_path="s3://bucket/inputpath", + s3_output_path="s3://bucket/outputpath", + dataset_type="application/jsonlines", + predicted_label_dataset_uri="pred_dataset/URI", + predicted_label_headers="prediction", + ) def test_s3_data_distribution_type_ignorance(): From 1f5468429f50ad0782a9c6cec06698f502900e8c Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh <105655261+rahven14@users.noreply.github.com> Date: Wed, 27 Jul 2022 10:33:31 +0530 Subject: [PATCH 133/526] feat: infer framework and version (#3247) --- src/sagemaker/chainer/model.py | 102 ++++++++++++++++++ src/sagemaker/huggingface/model.py | 20 +++- src/sagemaker/model.py | 4 +- src/sagemaker/mxnet/model.py | 4 +- src/sagemaker/pipeline.py | 4 +- src/sagemaker/pytorch/model.py | 4 +- src/sagemaker/sklearn/model.py | 4 +- src/sagemaker/tensorflow/model.py | 4 +- src/sagemaker/utils.py | 62 +++++++---- src/sagemaker/workflow/step_collections.py | 2 +- src/sagemaker/xgboost/model.py | 102 ++++++++++++++++++ tests/unit/sagemaker/tensorflow/test_tfs.py | 50 ++++++++- .../test_huggingface_pytorch_compiler.py | 59 +++++++++- .../test_huggingface_tensorflow_compiler.py | 59 +++++++++- tests/unit/test_chainer.py | 56 +++++++++- tests/unit/test_mxnet.py | 53 +++++++++ tests/unit/test_pytorch.py | 54 ++++++++++ tests/unit/test_sklearn.py | 48 ++++++++- tests/unit/test_xgboost.py | 52 +++++++++ 19 files changed, 700 insertions(+), 43 deletions(-) diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 41e3a6e838..3f22e22d5d 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -140,6 +140,108 @@ def __init__( self.model_server_workers = model_server_workers + def register( + self, + content_types, + response_types, + inference_instances, + transform_instances, + model_package_name=None, + model_package_group_name=None, + image_uri=None, + model_metrics=None, + metadata_properties=None, + marketplace_cert=False, + approval_status=None, + description=None, + drift_check_baselines=None, + customer_metadata_properties=None, + domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, + ): 
+ """Creates a model package for creating SageMaker models or listing on Marketplace. + + Args: + content_types (list): The supported MIME types for the input data. + response_types (list): The supported MIME types for the output data. + inference_instances (list): A list of the instance types that are used to + generate inferences in real-time. + transform_instances (list): A list of the instance types on which a transformation + job can be run or on which an endpoint can be deployed. + model_package_name (str): Model Package name, exclusive to `model_package_group_name`, + using `model_package_name` makes the Model Package un-versioned (default: None). + model_package_group_name (str): Model Package Group name, exclusive to + `model_package_name`, using `model_package_group_name` makes the Model Package + versioned (default: None). + image_uri (str): Inference image uri for the container. Model class' self.image will + be used if it is None (default: None). + model_metrics (ModelMetrics): ModelMetrics object (default: None). + metadata_properties (MetadataProperties): MetadataProperties (default: None). + marketplace_cert (bool): A boolean value indicating if the Model Package is certified + for AWS Marketplace (default: False). + approval_status (str): Model Approval Status, values can be "Approved", "Rejected", + or "PendingManualApproval" (default: "PendingManualApproval"). + description (str): Model Package description (default: None). + drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", + "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). + + Returns: + str: A string of SageMaker Model Package ARN. 
+ """ + instance_type = inference_instances[0] + self._init_sagemaker_session_if_does_not_exist(instance_type) + + if image_uri: + self.image_uri = image_uri + if not self.image_uri: + self.image_uri = self.serving_image_uri( + region_name=self.sagemaker_session.boto_session.region_name, + instance_type=instance_type, + ) + return super(ChainerModel, self).register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_name, + model_package_group_name, + image_uri, + model_metrics, + metadata_properties, + marketplace_cert, + approval_status, + description, + drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, + domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=(framework or self._framework_name).upper(), + framework_version=framework_version or self.framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + ) + def prepare_container_def( self, instance_type=None, accelerator_type=None, serverless_inference_config=None ): diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 8814b72175..04af57b566 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -85,6 +85,14 @@ def _validate_pt_tf_versions(pytorch_version, tensorflow_version, image_uri): ) +def fetch_framework_and_framework_version(tensorflow_version, pytorch_version): + """Function to check the framework used in HuggingFace class""" + + if tensorflow_version is not None: # pylint: disable=no-member + return ("tensorflow", tensorflow_version) # pylint: disable=no-member + return ("pytorch", pytorch_version) # pylint: disable=no-member + + class HuggingFaceModel(FrameworkModel): """A Hugging Face SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``.""" @@ -387,8 +395,16 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=framework, - framework_version=framework_version, + framework=( + framework + or fetch_framework_and_framework_version( + self.tensorflow_version, self.pytorch_version + )[0] + ).upper(), + framework_version=framework_version + or fetch_framework_and_framework_version(self.tensorflow_version, self.pytorch_version)[ + 1 + ], nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, ) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 8f128fe3f4..704e3385fd 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -374,12 +374,12 @@ def register( if model_package_group_name is not None: container_def = self.prepare_container_def() - update_container_with_inference_params( + container_def = update_container_with_inference_params( framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, - container_obj=container_def, + container_def=container_def, ) else: container_def = { diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 60fc1d60d2..4aaf6a8acc 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -238,8 +238,8 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=framework, - framework_version=framework_version, + framework=(framework or self._framework_name).upper(), + framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, 
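With this change, `register()` on the framework models falls back to the model's own framework name and version, so callers no longer have to repeat them. A minimal sketch with a `ChainerModel`; the model data, role, image URI, and version are placeholder values, and `sagemaker_session` is assumed to be an existing session:

```python
from sagemaker.chainer import ChainerModel

chainer_model = ChainerModel(
    model_data="s3://my-bucket/model/model.tar.gz",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    entry_point="inference.py",
    framework_version="5.0.0",
    py_version="py3",
    sagemaker_session=sagemaker_session,
)

# framework/framework_version are omitted, so the registered container definition
# carries Framework="CHAINER" and FrameworkVersion="5.0.0" inferred from the model.
chainer_model.register(
    content_types=["application/json"],
    response_types=["application/json"],
    inference_instances=["ml.m5.xlarge"],
    transform_instances=["ml.m5.xlarge"],
    model_package_group_name="my-model-package-group",
    image_uri="111122223333.dkr.ecr.us-east-1.amazonaws.com/my-image:latest",
)
```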
data_input_configuration=data_input_configuration, ) diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 8cdb82ffe7..5047e6351a 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -340,12 +340,12 @@ def register( container_def = self.pipeline_container_def( inference_instances[0] if inference_instances else None ) - update_container_with_inference_params( + container_def = update_container_with_inference_params( framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, - container_list=container_def, + container_def=container_def, ) else: container_def = [ diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index b5e019f492..fcbfd1da84 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -239,8 +239,8 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=framework, - framework_version=framework_version, + framework=(framework or self._framework_name).upper(), + framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, ) diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 67f9d60175..70ea22908e 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -233,8 +233,8 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=framework, - framework_version=framework_version, + framework=(framework or self._framework_name).upper(), + framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, ) diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index e5e6798a63..c910e85f20 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -285,8 +285,8 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=framework, - framework_version=framework_version, + framework=(framework or self._framework_name).upper(), + framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, ) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 65de8981aa..1998525a98 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -739,7 +739,7 @@ def update_container_with_inference_params( framework_version=None, nearest_model_name=None, data_input_configuration=None, - container_obj=None, + container_def=None, container_list=None, ): """Function to check if inference recommender parameters exist and update container. @@ -752,28 +752,30 @@ def update_container_with_inference_params( nearest_model_name (str): Name of a pre-trained machine learning benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str): Input object for the model (default: None). - container_obj (dict): object to be updated. + container_def (dict): object to be updated. container_list (list): list to be updated. 
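`update_container_with_inference_params` (refactored here, with the body just below) now returns the updated object and only writes the fields that were actually supplied, instead of requiring framework, framework version, and nearest model name all at once. A small illustration against a bare container-definition dict, with a placeholder image URI:

```python
from sagemaker.utils import update_container_with_inference_params

container_def = {"Image": "111122223333.dkr.ecr.us-east-1.amazonaws.com/my-image:latest"}

# Only two of the inference recommender fields are provided; the others are left out.
container_def = update_container_with_inference_params(
    framework="XGBOOST",
    framework_version="1.5-1",
    container_def=container_def,
)

assert container_def == {
    "Image": "111122223333.dkr.ecr.us-east-1.amazonaws.com/my-image:latest",
    "Framework": "XGBOOST",
    "FrameworkVersion": "1.5-1",
}
```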
Returns: dict: dict with inference recommender params """ - if framework is not None and framework_version is not None and nearest_model_name is not None: - if container_list is not None: - for obj in container_list: - construct_container_object( - obj, data_input_configuration, framework, framework_version, nearest_model_name - ) - if container_obj is not None: + if container_list is not None: + for obj in container_list: construct_container_object( - container_obj, - data_input_configuration, - framework, - framework_version, - nearest_model_name, + obj, data_input_configuration, framework, framework_version, nearest_model_name ) + if container_def is not None: + construct_container_object( + container_def, + data_input_configuration, + framework, + framework_version, + nearest_model_name, + ) + + return container_list or container_def + def construct_container_object( obj, data_input_configuration, framework, framework_version, nearest_model_name @@ -788,20 +790,32 @@ def construct_container_object( nearest_model_name (str): Name of a pre-trained machine learning benchmarked by Amazon SageMaker Inference Recommender (default: None). data_input_configuration (str): Input object for the model (default: None). - container_obj (dict): object to be updated. - container_list (list): list to be updated. + obj (dict): object to be updated. Returns: dict: container object """ - obj.update( - { - "Framework": framework, - "FrameworkVersion": framework_version, - "NearestModelName": nearest_model_name, - } - ) + if framework is not None: + obj.update( + { + "Framework": framework, + } + ) + + if framework_version is not None: + obj.update( + { + "FrameworkVersion": framework_version, + } + ) + + if nearest_model_name is not None: + obj.update( + { + "NearestModelName": nearest_model_name, + } + ) if data_input_configuration is not None: obj.update( @@ -811,3 +825,5 @@ def construct_container_object( }, } ) + + return obj diff --git a/src/sagemaker/workflow/step_collections.py b/src/sagemaker/workflow/step_collections.py index 270b838164..1f85d56442 100644 --- a/src/sagemaker/workflow/step_collections.py +++ b/src/sagemaker/workflow/step_collections.py @@ -250,7 +250,7 @@ def __init__( ) ] - update_container_with_inference_params( + self.container_def_list = update_container_with_inference_params( framework=framework, framework_version=framework_version, nearest_model_name=nearest_model_name, diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 2b90eea0f2..6e56230234 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -124,6 +124,108 @@ def __init__( validate_py_version(py_version) validate_framework_version(framework_version) + def register( + self, + content_types, + response_types, + inference_instances=None, + transform_instances=None, + model_package_name=None, + model_package_group_name=None, + image_uri=None, + model_metrics=None, + metadata_properties=None, + marketplace_cert=False, + approval_status=None, + description=None, + drift_check_baselines=None, + customer_metadata_properties=None, + domain=None, + sample_payload_url=None, + task=None, + framework=None, + framework_version=None, + nearest_model_name=None, + data_input_configuration=None, + ): + """Creates a model package for creating SageMaker models or listing on Marketplace. + + Args: + content_types (list): The supported MIME types for the input data. + response_types (list): The supported MIME types for the output data. 
+ inference_instances (list): A list of the instance types that are used to + generate inferences in real-time. + transform_instances (list): A list of the instance types on which a transformation + job can be run or on which an endpoint can be deployed. + model_package_name (str): Model Package name, exclusive to `model_package_group_name`, + using `model_package_name` makes the Model Package un-versioned (default: None). + model_package_group_name (str): Model Package Group name, exclusive to + `model_package_name`, using `model_package_group_name` makes the Model Package + versioned (default: None). + image_uri (str): Inference image uri for the container. Model class' self.image will + be used if it is None (default: None). + model_metrics (ModelMetrics): ModelMetrics object (default: None). + metadata_properties (MetadataProperties): MetadataProperties (default: None). + marketplace_cert (bool): A boolean value indicating if the Model Package is certified + for AWS Marketplace (default: False). + approval_status (str): Model Approval Status, values can be "Approved", "Rejected", + or "PendingManualApproval" (default: "PendingManualApproval"). + description (str): Model Package description (default: None). + drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). + customer_metadata_properties (dict[str, str]): A dictionary of key-value paired + metadata properties (default: None). + domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", + "MACHINE_LEARNING" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). + task (str): Task values which are supported by Inference Recommender are "FILL_MASK", + "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", + "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + framework (str): Machine learning framework of the model package container image + (default: None). + framework_version (str): Framework version of the Model Package Container Image + (default: None). + nearest_model_name (str): Name of a pre-trained machine learning benchmarked by + Amazon SageMaker Inference Recommender (default: None). + data_input_configuration (str): Input object for the model (default: None). + + Returns: + str: A string of SageMaker Model Package ARN. 
+ """ + instance_type = inference_instances[0] + self._init_sagemaker_session_if_does_not_exist(instance_type) + + if image_uri: + self.image_uri = image_uri + if not self.image_uri: + self.image_uri = self.serving_image_uri( + region_name=self.sagemaker_session.boto_session.region_name, + instance_type=instance_type, + ) + return super(XGBoostModel, self).register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_name, + model_package_group_name, + image_uri, + model_metrics, + metadata_properties, + marketplace_cert, + approval_status, + description, + drift_check_baselines=drift_check_baselines, + customer_metadata_properties=customer_metadata_properties, + domain=domain, + sample_payload_url=sample_payload_url, + task=task, + framework=(framework or self._framework_name).upper(), + framework_version=framework_version or self.framework_version, + nearest_model_name=nearest_model_name, + data_input_configuration=data_input_configuration, + ) + def prepare_container_def( self, instance_type=None, accelerator_type=None, serverless_inference_config=None ): diff --git a/tests/unit/sagemaker/tensorflow/test_tfs.py b/tests/unit/sagemaker/tensorflow/test_tfs.py index 322f2e4379..67b69efc44 100644 --- a/tests/unit/sagemaker/tensorflow/test_tfs.py +++ b/tests/unit/sagemaker/tensorflow/test_tfs.py @@ -18,7 +18,7 @@ import mock import pytest -from mock import Mock, patch +from mock import Mock, patch, ANY from sagemaker.serializers import CSVSerializer, IdentitySerializer from sagemaker.tensorflow import TensorFlow, TensorFlowModel, TensorFlowPredictor @@ -454,3 +454,51 @@ def mock_response(expected_response, sagemaker_session, content_type=JSON_CONTEN "ContentType": content_type, "Body": io.BytesIO(expected_response), } + + +def test_register_tfs_model_auto_infer_framework(sagemaker_session, tensorflow_inference_version): + model_package_group_name = "test-tfs-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarge"] + image_uri = "fakeimage" + + tfs_model = TensorFlowModel( + "s3://some/data.tar.gz", + role=ROLE, + framework_version=tensorflow_inference_version, + sagemaker_session=sagemaker_session, + ) + + tfs_model.register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": ANY, + "Framework": "TENSORFLOW", + "FrameworkVersion": tensorflow_inference_version, + }, + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": inference_instances, + "transform_instances": transform_instances, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 87cdd943b0..9bcf0559c5 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -17,10 +17,11 @@ import os import pytest -from mock import MagicMock, 
Mock, patch +from mock import MagicMock, Mock, patch, ANY from sagemaker import image_uris from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig +from sagemaker.huggingface.model import HuggingFaceModel from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES @@ -506,3 +507,59 @@ def test_attach( ) assert estimator.source_dir == "s3://some/sourcedir.tar.gz" assert estimator.entry_point == "iris-dnn-classifier.py" + + +def test_register_hf_pytorch_model_auto_infer_framework( + sagemaker_session, + huggingface_training_compiler_version, + huggingface_training_compiler_pytorch_version, + huggingface_training_compiler_py_version, +): + + model_package_group_name = "test-hf-tfs-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarge"] + image_uri = "fakeimage" + + hf_model = HuggingFaceModel( + model_data="s3://some/data.tar.gz", + role=ROLE, + transformers_version=huggingface_training_compiler_version, + pytorch_version=huggingface_training_compiler_pytorch_version, + py_version=huggingface_training_compiler_py_version, + sagemaker_session=sagemaker_session, + ) + + hf_model.register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": ANY, + "Framework": "PYTORCH", + "FrameworkVersion": huggingface_training_compiler_pytorch_version, + }, + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": inference_instances, + "transform_instances": transform_instances, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index e80c5c395a..32dc3c5634 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -17,10 +17,11 @@ import os import pytest -from mock import MagicMock, Mock, patch +from mock import MagicMock, Mock, patch, ANY from sagemaker import image_uris from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig +from sagemaker.huggingface.model import HuggingFaceModel from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES @@ -517,3 +518,59 @@ def test_attach( ) assert estimator.source_dir == "s3://some/sourcedir.tar.gz" assert estimator.entry_point == "iris-dnn-classifier.py" + + +def test_register_hf_tfs_model_auto_infer_framework( + sagemaker_session, + huggingface_training_compiler_version, + huggingface_training_compiler_tensorflow_version, + huggingface_training_compiler_py_version, +): + + model_package_group_name = "test-hf-tfs-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarge"] + image_uri = "fakeimage" + + hf_model = HuggingFaceModel( + model_data="s3://some/data.tar.gz", + role=ROLE, + 
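Because a Hugging Face model can sit on either backend, the new `fetch_framework_and_framework_version` helper (added in `sagemaker/huggingface/model.py` above) reports TensorFlow whenever `tensorflow_version` is set and PyTorch otherwise. A short sketch with example version numbers:

```python
from sagemaker.huggingface.model import fetch_framework_and_framework_version

# tensorflow_version takes precedence when present ...
assert fetch_framework_and_framework_version("2.6.3", None) == ("tensorflow", "2.6.3")

# ... otherwise the PyTorch version is reported.
assert fetch_framework_and_framework_version(None, "1.10.2") == ("pytorch", "1.10.2")
```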
transformers_version=huggingface_training_compiler_version, + tensorflow_version=huggingface_training_compiler_tensorflow_version, + py_version=huggingface_training_compiler_py_version, + sagemaker_session=sagemaker_session, + ) + + hf_model.register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": ANY, + "Framework": "TENSORFLOW", + "FrameworkVersion": huggingface_training_compiler_tensorflow_version, + }, + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": inference_instances, + "transform_instances": transform_instances, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index cc60b3b0ca..7cc973440f 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -18,7 +18,7 @@ from distutils.util import strtobool import pytest -from mock import MagicMock, Mock +from mock import MagicMock, Mock, ANY from mock import patch from sagemaker.chainer import defaults @@ -614,3 +614,57 @@ def test_model_py2_warning(warning, sagemaker_session, chainer_version): ) assert model.py_version == "py2" warning.assert_called_with(model._framework_name, defaults.LATEST_PY2_VERSION) + + +@patch("sagemaker.utils.create_tar_file", MagicMock()) +def test_register_chainer_model_auto_infer_framework( + sagemaker_session, chainer_version, chainer_py_version +): + + model_package_group_name = "test-chainer-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarge"] + image_uri = "fakeimage" + + chainer_model = ChainerModel( + "s3://some/data.tar.gz", + role=ROLE, + entry_point=SCRIPT_PATH, + sagemaker_session=sagemaker_session, + framework_version=chainer_version, + py_version=chainer_py_version, + ) + + chainer_model.register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": ANY, + "Framework": "CHAINER", + "FrameworkVersion": chainer_version, + }, + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": inference_instances, + "transform_instances": transform_instances, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index b6a9135e5b..99b0e839b7 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -1101,3 +1101,56 @@ def test_custom_image_estimator_deploy( mx.fit(inputs="s3://mybucket/train", job_name="new_name") model = mx.create_model(image_uri=custom_image) assert model.image_uri == custom_image + + +@patch("sagemaker.utils.create_tar_file", MagicMock()) +def 
test_register_mxnet_model_auto_infer_framework( + sagemaker_session, mxnet_inference_version, mxnet_inference_py_version, skip_if_mms_version +): + + model_package_group_name = "test-mxnet-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarge"] + image_uri = "fakeimage" + + mxnet_model = MXNetModel( + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=mxnet_inference_version, + py_version=mxnet_inference_py_version, + sagemaker_session=sagemaker_session, + ) + + mxnet_model.register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": MODEL_DATA, + "Framework": FRAMEWORK.upper(), + "FrameworkVersion": mxnet_inference_version, + }, + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": inference_instances, + "transform_instances": transform_instances, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index e39abf01fd..8b8541e816 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -709,3 +709,57 @@ def test_pt_heterogeneous_cluster_distribution_config( }, ) assert pytorch.distribution == expected_return + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +def test_register_pytorch_model_auto_infer_framework( + sagemaker_session, pytorch_inference_version, pytorch_inference_py_version +): + + model_package_group_name = "test-pytorch-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarge"] + image_uri = "fakeimage" + + pytorch_model = PyTorchModel( + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=pytorch_inference_version, + py_version=pytorch_inference_py_version, + sagemaker_session=sagemaker_session, + ) + + pytorch_model.register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": ANY, + "Framework": "PYTORCH", + "FrameworkVersion": pytorch_inference_version, + }, + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": inference_instances, + "transform_instances": transform_instances, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index da89e77360..3cba43a4b7 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -17,7 +17,7 @@ import os import pytest -from mock import Mock +from mock import Mock, ANY from mock import patch from 
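These register tests only pin the fields the change is about; `ANY` from the mock library compares equal to anything, so values such as the environment map or the repacked model data URL do not have to be hard-coded into the expected request. For example:

```python
from unittest.mock import ANY, MagicMock

session = MagicMock()
session.create_model_package_from_containers(
    containers=[{"Image": "fakeimage", "Environment": {"SAGEMAKER_REGION": "us-east-1"}}],
)

# ANY matches the Environment value regardless of its contents.
session.create_model_package_from_containers.assert_called_with(
    containers=[{"Image": "fakeimage", "Environment": ANY}],
)
```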
sagemaker.fw_utils import UploadedCode @@ -587,3 +587,49 @@ def test_model_py2_raises(sagemaker_session, sklearn_version): framework_version=sklearn_version, py_version="py2", ) + + +def test_register_sklearn_model_auto_infer_framework(sagemaker_session, sklearn_version): + source_dir = "s3://mybucket/source" + + model_package_group_name = "test-sklearn-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + image_uri = "fakeimage" + + sklearn_model = SKLearnModel( + model_data=source_dir, + role=ROLE, + sagemaker_session=sagemaker_session, + entry_point=SCRIPT_PATH, + framework_version=sklearn_version, + ) + + sklearn_model.register( + content_types, + response_types, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": source_dir, + "Framework": "SKLEARN", + "FrameworkVersion": sklearn_version, + }, + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": None, + "transform_instances": None, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 82f27c19ae..bcf3a9c9da 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -19,6 +19,7 @@ from mock import Mock from mock import patch +from mock import ANY from packaging.version import Version @@ -672,3 +673,54 @@ def test_unsupported_xgboost_version_error(sagemaker_session): error_message = "XGBoost 1.1 is not supported" assert error_message in str(error1) assert error_message in str(error2) + + +def test_register_xgboost_model_auto_infer_framework(sagemaker_session, xgboost_framework_version): + source_dir = "s3://mybucket/source" + + model_package_group_name = "test-pytorch-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarge"] + image_uri = "fakeimage" + + xgboost_model = XGBoostModel( + model_data=source_dir, + role=ROLE, + sagemaker_session=sagemaker_session, + entry_point=SCRIPT_PATH, + framework_version=xgboost_framework_version, + ) + + xgboost_model.register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": ANY, + "Framework": "XGBOOST", + "FrameworkVersion": xgboost_framework_version, + }, + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": inference_instances, + "transform_instances": transform_instances, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) From 03bdae5b98cad4c37cb4179244566f712db37cab Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Tue, 26 Jul 2022 22:10:05 -0700 Subject: [PATCH 134/526] fix: Support parameterized source code input for TrainingStep (#3202) 
Co-authored-by: Dewen Qi --- src/sagemaker/amazon/hyperparameter.py | 18 +- src/sagemaker/chainer/estimator.py | 6 +- src/sagemaker/estimator.py | 35 ++- src/sagemaker/fw_utils.py | 57 +++- src/sagemaker/huggingface/estimator.py | 22 +- src/sagemaker/job.py | 24 +- src/sagemaker/jumpstart/utils.py | 43 ++- src/sagemaker/mxnet/estimator.py | 16 +- src/sagemaker/pytorch/estimator.py | 6 +- src/sagemaker/rl/estimator.py | 6 +- src/sagemaker/sklearn/estimator.py | 6 +- src/sagemaker/workflow/steps.py | 5 +- src/sagemaker/xgboost/estimator.py | 6 +- .../pytorch_mnist_source_code.tar.gz | Bin 0 -> 3557 bytes .../tensorflow_mnist_source_code.tar.gz | Bin 0 -> 3282 bytes .../tensorflow_mnist_source_code_dummy.tar.gz | Bin 0 -> 1707 bytes .../estimator_source_code_dummy1.tar.gz | Bin 0 -> 1664 bytes .../estimator_source_code_dummy2.tar.gz | Bin 0 -> 1486 bytes .../sagemaker/workflow/test_training_steps.py | 280 +++++++++++++----- .../sagemaker/tensorflow/test_estimator.py | 72 +++++ .../sagemaker/workflow/test_training_step.py | 93 +++++- tests/unit/test_estimator.py | 72 +++++ 22 files changed, 605 insertions(+), 162 deletions(-) create mode 100644 tests/data/pytorch_mnist/pytorch_mnist_source_code.tar.gz create mode 100644 tests/data/tensorflow_mnist/tensorflow_mnist_source_code.tar.gz create mode 100644 tests/data/tensorflow_mnist/tensorflow_mnist_source_code_dummy.tar.gz create mode 100644 tests/data/xgboost_abalone/estimator_source_code/estimator_source_code_dummy1.tar.gz create mode 100644 tests/data/xgboost_abalone/estimator_source_code/estimator_source_code_dummy2.tar.gz diff --git a/src/sagemaker/amazon/hyperparameter.py b/src/sagemaker/amazon/hyperparameter.py index 58aef71379..856927cb13 100644 --- a/src/sagemaker/amazon/hyperparameter.py +++ b/src/sagemaker/amazon/hyperparameter.py @@ -15,6 +15,8 @@ import json +from sagemaker.workflow import is_pipeline_variable + class Hyperparameter(object): """An algorithm hyperparameter with optional validation. 
@@ -98,8 +100,14 @@ def serialize_all(obj): """ if "_hyperparameters" not in dir(obj): return {} - return { - k: json.dumps(v) if isinstance(v, list) else str(v) - for k, v in obj._hyperparameters.items() - if v is not None - } + hps = {} + for k, v in obj._hyperparameters.items(): + if v is not None: + if isinstance(v, list): + v = json.dumps(v) + elif is_pipeline_variable(v): + v = v.to_string() + else: + v = str(v) + hps[k] = v + return hps diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py index 899ef62f63..12c22eae91 100644 --- a/src/sagemaker/chainer/estimator.py +++ b/src/sagemaker/chainer/estimator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging +from typing import Union, Optional from sagemaker.estimator import Framework, EstimatorBase from sagemaker.fw_utils import ( @@ -25,6 +26,7 @@ from sagemaker.chainer import defaults from sagemaker.chainer.model import ChainerModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -42,12 +44,12 @@ class Chainer(Framework): def __init__( self, - entry_point, + entry_point: Union[str, PipelineVariable], use_mpi=None, num_processes=None, process_slots_per_host=None, additional_mpi_options=None, - source_dir=None, + source_dir: Optional[Union[str, PipelineVariable]] = None, hyperparameters=None, framework_version=None, py_version=None, diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index b6fd68f472..1ab122b2e0 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -47,6 +47,7 @@ get_mp_parameters, tar_and_upload_dir, validate_source_dir, + validate_source_code_input_against_pipeline_variables, ) from sagemaker.inputs import TrainingInput, FileSystemInput from sagemaker.job import _Job @@ -140,12 +141,12 @@ def __init__( disable_profiler: bool = False, environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, - source_dir: Optional[str] = None, + source_dir: Optional[Union[str, PipelineVariable]] = None, git_config: Optional[Dict[str, str]] = None, hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, container_log_level: Union[int, PipelineVariable] = logging.INFO, code_location: Optional[str] = None, - entry_point: Optional[str] = None, + entry_point: Optional[Union[str, PipelineVariable]] = None, dependencies: Optional[List[Union[str]]] = None, instance_groups: Optional[Dict[str, Union[str, int]]] = None, **kwargs, @@ -461,6 +462,13 @@ def __init__( "train_volume_kms_key", "volume_kms_key", volume_kms_key, kwargs ) + validate_source_code_input_against_pipeline_variables( + entry_point=entry_point, + source_dir=source_dir, + git_config=git_config, + enable_network_isolation=enable_network_isolation, + ) + self.role = role self.instance_count = instance_count self.instance_type = instance_type @@ -663,7 +671,11 @@ def _prepare_for_training(self, job_name=None): # validate source dir will raise a ValueError if there is something wrong with # the source directory. We are intentionally not handling it because this is a # critical error. - if self.source_dir and not self.source_dir.lower().startswith("s3://"): + if ( + self.source_dir + and not is_pipeline_variable(self.source_dir) + and not self.source_dir.lower().startswith("s3://") + ): validate_source_dir(self.entry_point, self.source_dir) # if we are in local mode with local_code=True. 
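`serialize_all` now distinguishes three cases: lists are JSON-encoded as before, pipeline variables are rendered with `to_string()` so they remain resolvable at execution time, and everything else falls back to `str()`. A rough standalone sketch of that branching, assuming a `ParameterString` hyperparameter:

```python
import json

from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.parameters import ParameterString


def serialize_value(value):
    """Mirrors the branching added to Hyperparameter.serialize_all."""
    if isinstance(value, list):
        return json.dumps(value)
    if is_pipeline_variable(value):
        # Kept as a pipeline expression; the real value is substituted at execution time.
        return value.to_string()
    return str(value)


num_topics = ParameterString(name="NumTopics", default_value="10")
serialized = serialize_value(num_topics)  # a pipeline expression, not the literal "10"
```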
We want the container to just @@ -2151,11 +2163,11 @@ def __init__( disable_profiler: bool = False, environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, - source_dir: Optional[str] = None, + source_dir: Optional[Union[str, PipelineVariable]] = None, git_config: Optional[Dict[str, str]] = None, container_log_level: Union[int, PipelineVariable] = logging.INFO, code_location: Optional[str] = None, - entry_point: Optional[str] = None, + entry_point: Optional[Union[str, PipelineVariable]] = None, dependencies: Optional[List[str]] = None, instance_groups: Optional[Dict[str, Union[str, int]]] = None, **kwargs, @@ -2603,8 +2615,8 @@ class Framework(EstimatorBase): def __init__( self, - entry_point: str, - source_dir: Optional[str] = None, + entry_point: Union[str, PipelineVariable], + source_dir: Optional[Union[str, PipelineVariable]] = None, hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, container_log_level: Union[int, PipelineVariable] = logging.INFO, code_location: Optional[str] = None, @@ -2783,7 +2795,14 @@ def __init__( """ super(Framework, self).__init__(enable_network_isolation=enable_network_isolation, **kwargs) image_uri = renamed_kwargs("image_name", "image_uri", image_uri, kwargs) - if entry_point.startswith("s3://"): + + validate_source_code_input_against_pipeline_variables( + entry_point=entry_point, + source_dir=source_dir, + git_config=git_config, + enable_network_isolation=enable_network_isolation, + ) + if not is_pipeline_variable(entry_point) and entry_point.startswith("s3://"): raise ValueError( "Invalid entry point script: {}. Must be a path to a local file.".format( entry_point diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 2fcb5d19f7..40787d4440 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -20,7 +20,7 @@ import shutil import tempfile from collections import namedtuple -from typing import Optional +from typing import Optional, Union, Dict import sagemaker.image_uris from sagemaker.session_settings import SessionSettings @@ -28,6 +28,7 @@ from sagemaker.workflow import is_pipeline_variable from sagemaker.deprecations import renamed_warning, renamed_kwargs +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger(__name__) @@ -124,6 +125,58 @@ def validate_source_dir(script, directory): return True +def validate_source_code_input_against_pipeline_variables( + entry_point: Optional[Union[str, PipelineVariable]] = None, + source_dir: Optional[Union[str, PipelineVariable]] = None, + git_config: Optional[Dict[str, str]] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, +): + """Validate source code input against pipeline variables + + Args: + entry_point (str, PipelineVariable): The path to the local Python source file that + should be executed as the entry point to training (default: None). + source_dir (str, PipelineVariable): The Path to a directory with any other + training source code dependencies aside from the entry point file (default: None). + git_config (Dict[str, str]): Git configurations used for cloning files (default: None). + enable_network_isolation (bool, PipelineVariable): Specifies whether container will run + in network isolation mode (default: False). 
+ """ + if is_pipeline_variable(enable_network_isolation) or enable_network_isolation is True: + if is_pipeline_variable(entry_point) or is_pipeline_variable(source_dir): + raise TypeError( + "entry_point, source_dir should not be pipeline variables " + "when enable_network_isolation is a pipeline variable or it is set to True." + ) + if git_config: + if is_pipeline_variable(entry_point) or is_pipeline_variable(source_dir): + raise TypeError( + "entry_point, source_dir should not be pipeline variables when git_config is given." + ) + if is_pipeline_variable(entry_point): + if not source_dir: + raise TypeError( + "The entry_point should not be a pipeline variable when source_dir is missing." + ) + if not is_pipeline_variable(source_dir) and not source_dir.lower().startswith("s3://"): + raise TypeError( + "The entry_point should not be a pipeline variable when source_dir is a local path." + ) + logger.warning( + "The entry_point is a pipeline variable: %s. During pipeline execution, " + "the interpreted value of entry_point has to be a local path in the container " + "pointing to a Python source file which is located at the root of source_dir.", + type(entry_point), + ) + if is_pipeline_variable(source_dir): + logger.warning( + "The source_dir is a pipeline variable: %s. During pipeline execution, " + "the interpreted value of source_dir has to be an S3 URI and " + "must point to a tar.gz file", + type(source_dir), + ) + + def get_mp_parameters(distribution): """Get the model parallelism parameters provided by the user. @@ -265,7 +318,7 @@ def tar_and_upload_dir( sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name. """ - if directory and directory.lower().startswith("s3://"): + if directory and (is_pipeline_variable(directory) or directory.lower().startswith("s3://")): return UploadedCode(s3_prefix=directory, script_name=script) script_name = script if directory else os.path.basename(script) diff --git a/src/sagemaker/huggingface/estimator.py b/src/sagemaker/huggingface/estimator.py index bb43890ce4..628c14dc8e 100644 --- a/src/sagemaker/huggingface/estimator.py +++ b/src/sagemaker/huggingface/estimator.py @@ -15,6 +15,7 @@ import logging import re +from typing import Optional, Union, Dict from sagemaker.deprecations import renamed_kwargs from sagemaker.estimator import Framework, EstimatorBase @@ -27,6 +28,7 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.huggingface.training_compiler.config import TrainingCompilerConfig +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -38,16 +40,16 @@ class HuggingFace(Framework): def __init__( self, - py_version, - entry_point, - transformers_version=None, - tensorflow_version=None, - pytorch_version=None, - source_dir=None, - hyperparameters=None, - image_uri=None, - distribution=None, - compiler_config=None, + py_version: str, + entry_point: Union[str, PipelineVariable], + transformers_version: Optional[str] = None, + tensorflow_version: Optional[str] = None, + pytorch_version: Optional[str] = None, + source_dir: Optional[Union[str, PipelineVariable]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + distribution: Optional[Dict] = None, + compiler_config: Optional[TrainingCompilerConfig] = None, **kwargs, ): """This estimator runs a Hugging Face training script in a SageMaker training environment. 
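The validator added above encodes which combinations can still be resolved once the pipeline actually runs: a parameterized `entry_point` only makes sense relative to an S3 (or parameterized) `source_dir`, and neither field may be a pipeline variable when `git_config` or network isolation is in play. A sketch of one accepted and one rejected call, using hypothetical parameters:

```python
from sagemaker.fw_utils import validate_source_code_input_against_pipeline_variables
from sagemaker.workflow.parameters import ParameterString

entry_point = ParameterString(name="EntryPoint", default_value="train.py")
source_dir = ParameterString(
    name="SourceDir", default_value="s3://my-bucket/code/sourcedir.tar.gz"
)

# Accepted: both values resolve at execution time to an S3 tarball plus a script inside it.
validate_source_code_input_against_pipeline_variables(
    entry_point=entry_point,
    source_dir=source_dir,
)

# Rejected: code pulled through git_config cannot be combined with parameterized inputs.
try:
    validate_source_code_input_against_pipeline_variables(
        entry_point=entry_point,
        source_dir=source_dir,
        git_config={"repo": "https://github.com/user/repo.git", "branch": "main"},
    )
except TypeError as err:
    print(err)
```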
diff --git a/src/sagemaker/job.py b/src/sagemaker/job.py index 1b9d46cd15..f389333c15 100644 --- a/src/sagemaker/job.py +++ b/src/sagemaker/job.py @@ -18,6 +18,7 @@ from sagemaker.inputs import FileSystemInput, TrainingInput from sagemaker.local import file_input +from sagemaker.workflow import is_pipeline_variable class _Job(object): @@ -168,14 +169,14 @@ def _format_string_uri_input( target_attribute_name=None, ): """Placeholder docstring""" + s3_input_result = TrainingInput( + uri_input, + content_type=content_type, + input_mode=input_mode, + compression=compression, + target_attribute_name=target_attribute_name, + ) if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"): - s3_input_result = TrainingInput( - uri_input, - content_type=content_type, - input_mode=input_mode, - compression=compression, - target_attribute_name=target_attribute_name, - ) return s3_input_result if isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"): return file_input(uri_input) @@ -185,16 +186,11 @@ def _format_string_uri_input( '"file://"'.format(uri_input) ) if isinstance(uri_input, str): - s3_input_result = TrainingInput( - uri_input, - content_type=content_type, - input_mode=input_mode, - compression=compression, - target_attribute_name=target_attribute_name, - ) return s3_input_result if isinstance(uri_input, (TrainingInput, file_input, FileSystemInput)): return uri_input + if is_pipeline_variable(uri_input): + return s3_input_result raise ValueError( "Cannot format input {}. Expecting one of str, TrainingInput, file_input or " diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index fcde252b71..2e86c3350a 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -30,7 +30,7 @@ JumpStartModelSpecs, JumpStartVersionedModelId, ) - +from sagemaker.workflow import is_pipeline_variable LOGGER = logging.getLogger(__name__) @@ -271,26 +271,41 @@ def add_jumpstart_tags( training_script_uri (Optional[str]): S3 URI for training script tarball. (Default: None). """ - + warn_msg = ( + "The URI (%s) is a pipeline variable which is only interpreted at execution time. " + "As a result, the JumpStart resources will not be tagged." 
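With the `_format_string_uri_input` change above, a channel whose URI is a pipeline variable is wrapped in a `TrainingInput` the same way a plain `s3://` string is, so training inputs can be parameterized as well. The fallback is roughly equivalent to writing:

```python
from sagemaker.inputs import TrainingInput
from sagemaker.workflow.parameters import ParameterString

train_data = ParameterString(
    name="TrainDataUri", default_value="s3://my-bucket/data/train"
)

# What the job-argument formatting now does for a pipeline-variable channel URI.
train_input = TrainingInput(train_data, content_type="text/csv")
```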
+ ) if inference_model_uri: - tags = add_single_jumpstart_tag( - inference_model_uri, enums.JumpStartTag.INFERENCE_MODEL_URI, tags - ) + if is_pipeline_variable(inference_model_uri): + logging.warning(warn_msg, "inference_model_uri") + else: + tags = add_single_jumpstart_tag( + inference_model_uri, enums.JumpStartTag.INFERENCE_MODEL_URI, tags + ) if inference_script_uri: - tags = add_single_jumpstart_tag( - inference_script_uri, enums.JumpStartTag.INFERENCE_SCRIPT_URI, tags - ) + if is_pipeline_variable(inference_script_uri): + logging.warning(warn_msg, "inference_script_uri") + else: + tags = add_single_jumpstart_tag( + inference_script_uri, enums.JumpStartTag.INFERENCE_SCRIPT_URI, tags + ) if training_model_uri: - tags = add_single_jumpstart_tag( - training_model_uri, enums.JumpStartTag.TRAINING_MODEL_URI, tags - ) + if is_pipeline_variable(training_model_uri): + logging.warning(warn_msg, "training_model_uri") + else: + tags = add_single_jumpstart_tag( + training_model_uri, enums.JumpStartTag.TRAINING_MODEL_URI, tags + ) if training_script_uri: - tags = add_single_jumpstart_tag( - training_script_uri, enums.JumpStartTag.TRAINING_SCRIPT_URI, tags - ) + if is_pipeline_variable(training_script_uri): + logging.warning(warn_msg, "training_script_uri") + else: + tags = add_single_jumpstart_tag( + training_script_uri, enums.JumpStartTag.TRAINING_SCRIPT_URI, tags + ) return tags diff --git a/src/sagemaker/mxnet/estimator.py b/src/sagemaker/mxnet/estimator.py index dddb797e18..3f0c054929 100644 --- a/src/sagemaker/mxnet/estimator.py +++ b/src/sagemaker/mxnet/estimator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging +from typing import Union, Optional, Dict from packaging.version import Version @@ -29,6 +30,7 @@ from sagemaker.mxnet import defaults from sagemaker.mxnet.model import MXNetModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -41,13 +43,13 @@ class MXNet(Framework): def __init__( self, - entry_point, - framework_version=None, - py_version=None, - source_dir=None, - hyperparameters=None, - image_uri=None, - distribution=None, + entry_point: Union[str, PipelineVariable], + framework_version: Optional[str] = None, + py_version: Optional[str] = None, + source_dir: Optional[Union[str, PipelineVariable]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + distribution: Optional[Dict[str, str]] = None, **kwargs ): """This ``Estimator`` executes an MXNet script in a managed MXNet execution environment. 
diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 2cd5a0c798..07554ca798 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging +from typing import Union, Optional from packaging.version import Version @@ -28,6 +29,7 @@ from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -39,10 +41,10 @@ class PyTorch(Framework): def __init__( self, - entry_point, + entry_point: Union[str, PipelineVariable], framework_version=None, py_version=None, - source_dir=None, + source_dir: Optional[Union[str, PipelineVariable]] = None, hyperparameters=None, image_uri=None, distribution=None, diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index 60307a7868..8d6a00b68e 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -16,6 +16,7 @@ import enum import logging import re +from typing import Union, Optional from sagemaker import image_uris, fw_utils from sagemaker.estimator import Framework, EstimatorBase @@ -23,6 +24,7 @@ from sagemaker.mxnet.model import MXNetModel from sagemaker.tensorflow.model import TensorFlowModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -74,11 +76,11 @@ class RLEstimator(Framework): def __init__( self, - entry_point, + entry_point: Union[str, PipelineVariable], toolkit=None, toolkit_version=None, framework=None, - source_dir=None, + source_dir: Optional[Union[str, PipelineVariable]] = None, hyperparameters=None, image_uri=None, metric_definitions=None, diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index 9174b98ade..e13fbb764c 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging +from typing import Union, Optional from sagemaker import image_uris from sagemaker.deprecations import renamed_kwargs @@ -26,6 +27,7 @@ from sagemaker.sklearn import defaults from sagemaker.sklearn.model import SKLearnModel from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -37,10 +39,10 @@ class SKLearn(Framework): def __init__( self, - entry_point, + entry_point: Union[str, PipelineVariable], framework_version=None, py_version="py3", - source_dir=None, + source_dir: Optional[Union[str, PipelineVariable]] = None, hyperparameters=None, image_uri=None, image_uri_region=None, diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index 45d38fe26d..d73a899084 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -407,7 +407,10 @@ def __init__( # To avoid this, hash the contents of the training script and include it # in the `job_name` passed to the `Estimator`, which will be used # instead of the timestamped path. 
- self.job_name = self._generate_code_upload_path() + if not is_pipeline_variable(estimator.source_dir) and not is_pipeline_variable( + estimator.entry_point + ): + self.job_name = self._generate_code_upload_path() @property def arguments(self) -> RequestType: diff --git a/src/sagemaker/xgboost/estimator.py b/src/sagemaker/xgboost/estimator.py index 948f32cdfe..f6f0005f1f 100644 --- a/src/sagemaker/xgboost/estimator.py +++ b/src/sagemaker/xgboost/estimator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging +from typing import Union, Optional from sagemaker import image_uris from sagemaker.deprecations import renamed_kwargs @@ -25,6 +26,7 @@ ) from sagemaker.session import Session from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable from sagemaker.xgboost import defaults from sagemaker.xgboost.model import XGBoostModel from sagemaker.xgboost.utils import validate_py_version, validate_framework_version @@ -42,9 +44,9 @@ class XGBoost(Framework): def __init__( self, - entry_point, + entry_point: Union[str, PipelineVariable], framework_version, - source_dir=None, + source_dir: Optional[Union[str, PipelineVariable]] = None, hyperparameters=None, py_version="py3", image_uri=None, diff --git a/tests/data/pytorch_mnist/pytorch_mnist_source_code.tar.gz b/tests/data/pytorch_mnist/pytorch_mnist_source_code.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..6597dc7add1c5ca8b7286442ad1bd74e680e2522 GIT binary patch literal 3557 zcmV=oOk6Kk-ZPk{xj7gRWs^hEni4G2B38dU81y!K8`dT1v82>l=gC_oe2>1^1zkjD|Z1o0R*rxtw6aRnOU|gHm7m)e?8)b3)w>Hh1|8H3R z4TYH3#7_l&;r_4u>JPtGZYau;GbWElhcjd}`eHLa?lZa1sp2-fbUuB9mGMYj~u-uQSi>Mi-3gB!B_71HB062(wv%H@F{jo&}Z z4O|1@-rZ6HB~(0xmG|$|k4_a~82^T!@jniH&Tj!fQe1^9KHw*cqr}iom2a!_);_Jx z`yJ4Bl(F&xX1a=fMJoTQQMm=UF<^U~=TMo}@c&!G|5)OOAm0R5jFq4|cV+uVv}+RYhHP)j z_Lgivku8z!UD>`7tqfXg^6;@HS6-7Rlr?!TutsEi_nky;pn+cud?cDrF#qo_KmRw- zKm$KxSi8A5+PkOZ2ZA}kTub2h&uoHHEd*rWO^J)C_>ZN2BHO#NeIweY5NM!*1{(O# zQ2zcOgdR(b+2#9Q<7$8ZXIZ_j)$eY0`S<@`^Y?!r1m5lYf6wgo2Ys_^SeDf{t-)Zc z`N;m$2IJbizJS%=|GVa(zghGD!REkhzW;v+@Pb(sC&Y=TGbd)WCEvehQJ6jZ(RAvC z)9f{3*%Q0WACo8^pH*HA7tAX5PSb?C#9;(qs|aC;{b5+|H702|PP`~|d_=fk9gk)Z zi?E+RRL9c9^O@l~iL*V3)MhlGR*X0 zOlfk+`!w4Qoq*c5E|CnFeMrx#uVuJ<`}dwc$cG)*edvU)4^uOQQ4&)p$a+Du220#1 zO?180YK?t|F>*{34YYF0=jX&+Iw7|0gbkJ$M2=!|WR3WGp#4^g`;yS`R4?%sH%X1#4g; z<5CBxP_UktPV91hw#h|Fr57;qz7bPD)$SXCb79Y-$nOeFp%Lvu2Z40`>QjVdrFS5) zfX0P!?$MXpI%pZhT1Y?B0TQT$bSb$pOb(@ewanVW4GhsU0|Wwmwvc}`C+7;vMPAxbyA z@AwRq0x<(QBq$bfz(!{@bZI!g#315?gkC^&z}U+?@Nrm|L}^UMPUyrQWd^4!)Fifs z@z{cav80|AVb|Czo=0M?Ce^GHZ&KS(LV;;=9EEglVX3;i9#=h zp$tnlQ+W4TI+^%%SQXM^oR<=s7^tEZ!1(v%(;R?d!>H}BwhxhFhHNUtn->5B9GBXPb!+8U%hKriPNK0_eY3xj?J&hgL zqhZ1{NVe3^i+~5Rp9n-ODEJ623l1^?850e?3X2vSvbFDWmNXgLQ4*;CoPa<)#0YoGG){rr>+Ama2~@nu&1+>0T<>eK3~6-sg; zVJEQ6D-}nu1cS{|w+SS5bq2vCpg|O0LI+ydYYeS9Z#<~JdM%5Tb8k%ZFe+X(74g*h zN$Pl(W+mQSurBQ6HaUPo&*Oeck~2ywAw4|z;wTKjV2zTWj9-AW0Uzqrk0Lc6h?pTK zJX8Pr=*i)pJv#WuzWN+1izF^>NMT&Kp*BQ_7uhE89X9Yk9k>v>l>f5>8wd2Nb@K{! 
z7Dr>sn9c20le`UO+a^eRn6J_<8QX9XKjpAoi9c$gC3lsYw#nzOUVonJ2!he8*CYf# zWd_NOWG;B->@M{DYFw!VqQv}CfZoj3vNFZ68)C9t|r-%Xa$k~z6BWhQhUI-pr4%5t_DUBN9x zzUo$t--Yqc+}uS1yffb%n_t>NxWahEpN-LjJ#D_?Y!g2oLOI_C+L!>-0Nw_Ld)i(I zi02Nj3hs~>v2+%V&#;Jva6hGZi^Ss_?`DwKtWqtzqI^N9L~p#&vpE%w|67w?+v>oQ~ss#s}02d`QYjR+3dfwlxm zpX)7poH5k)4xHGBJSB@AZkJ1;HYvLb1Ea8o@k>-5G59lX<~ISEvx=)3UD^Xrc49R^Q3Iu<(1va%Jl^86tK3*jtr-M7E$_>*KM*>GT}<-`6^r- zZgOGT8Eft3HD@l+;^I?YM&?MuU4#QpLap$`l7}oUHh2ZA7k&}VfUAf>agm%EUq@bv zUYn@0@JP;7;8KlGRUN-9P70?G2q=)v63EstZ=7giY3$V|dgVry<4XzVBTHSB%609? zIma4{TOnUn7FftR8uS%pkjG%yjwlH5=F%lE9mtdTofAI~Zg;Xzf<_ciJ8m>)+1Efw z!<2OxjrnbagP&iGvn23KXJeE{mTDzkfZ~uZjB*}+J``UrOc{z_=QupgbyW8u?Bux=eEPN^3~Sp$bR@} zbTU$(>nkOYH&=XxYyqgK3i>soyHB4y**`wv#DVrI(tK~%!K$u8_vq2y{-M2h@Z>E- zF$!Fj=I+Ct<71$D@?_`W_~7`#TS!9*jmN9SMo}Yu{Pf6v@c8NI3M9wxKyp`-Ebel?lye*Jh&jot#y)DpF|3y!|6a{+dy+5M;=D~+#S@v#X?<4#G zVW^q!`{tYPJ(QEDczQ_ifCL+=sYeZ&US<;&*hMC%CbF;oxuWyx%Ei{hoF=^?IdPA#e)rO%~t6FobQQIti+wvFidT@JD zpFtA;ui*LppZ;P*$1mUc>hDSXSF5#!_^;J8O+wnOOD4Vb@h|=8XFrziNzyU4&`A&d z3JIWd(!206;PWT=(C`1g8OwIRe<~g+#UJ2vxPY>j&Hu#m#wzxGhp1z0IWT!&`tQ5{ z{PXdDp8V!7|M=Sn|N8wOKY#sW7O)(iruE;7gMhfvzz*K@9@o|XTD`Fl|IMbM7xn)x z;C1W2Uag{|&Tjj3@38aHfjY%e6sYqN-P&#M2YcXReX@Hv{P3a|^(LRb_;kx^zac2; zS{?O3fTK^A3-G2A$`7#HQ8#XXkL&8ccB%ib8P&Rk>bEc2TOa@G#?$-_aKC5vA}=6s zbPv!+qgGpW|I<_Z*BcF*{~L9!$p3c%H^~20Z6W{bnhF z|M~YnqW;%v)uR621-y3s->k2?|2?){@T`tI^je+0_K1IW+XvTN=%656ai7aEye<(H zt;r2Qy|t*%V0Hc1Ynrwg|Me!l{}uKBPGD^}9_W%pqv(OeN9lLVQ6GrUCAgOx*BFmT zGNp;6ORD7GP!3A_o-~$R$(F)7`dP>EtTV?u6XN|9KI;D%Fv>+LV;R-j`wLVB6i`3` z1r+eM!x|qRy!Y*jQ9uC&dQptTmk=hu0nE1H0N{AJd)Tj9yj~@KfPINRIAl`6=?<#6o3DBC-C)t|F2c6#+FgvQa3jndb7Tz z75@P2dk*TYMSTX>{r+EX)aK(~Z&aJb@BeoKkIoDl+uOhlz>JiWT2KFl^&rZ+ahjA1}LU~s)M3) zg4#Z|U=S-6^bsh~_FQDB8d4~NESBW5iazmT;4;QD+zgV-G!UhR`RLJ+v+0<6w6b;}Jn4Z$bi>%I=e3Y=_h?U}$%M9D8g> zIAY&3Qzkkt!ERm@Fs#x&IsP8lh(Dwj+45R4Kj31J!i5*btz%KU7r1=I9M z$u#8+$Su?;ZG zZXSejiDQQL1!-01mwY@Jz)4&31hhMl=ND;s4M^Bhj0!T!^J}*2$B`MHW1qB?uIGX^ zj7|R07f=T{!sd_~Flx!xId(yq&;)%&3;Qtx;WLM{`aw*xwQ^}*2WKwTz-7N!rYqyk z>1vfaZhXd-3cp$b3+Fd(bp9a^55N_HatXnLH5N;Xu2mMc7h0!&YhlM>K3hT4oZC)$ zCV^y0-9VXEkg~4l{h&ODn)$CXU9tjJU7QEzt_%E~cqs+5?-jf*&0oVy=xX_zm7fks z6bCMvmP(}^U^fF6~#1Yh7s|pHf9y%(unqaX!gOOAo)687@A=`9^(L3q0q$v zwr0tU&@Ig{On_}3X6(ys8@R^=%FG$08#)8Lw?!iJ&YlyCWoPXau!i5vB)JCs?rho{V8EOK zLLsYwCT%Q6USNX*npQxU+$exPWImQ>Ce=&}6go4ZbPrk=6+|nH>3AdPv3M{cJhj8t zy1pc+G8uq4wum_<*j4B(#~Hz~7y=wd9f>otb!K2SjT>`<1Diew!qf=tPpxNWLK`X< zwy!X@S91{Y_XvCdGEY1Py39;@g9=G6W6#-Y(bp7E*98}od-O7YCXl5rx$}*D_z{l< zn(7{7hc}9z>p%xe$N(BP=rd@u1qC#qKFtb0bEAPh91{A^<^`KGa*ii9c-|lmSU)J( zBHF&h&@H$O9hz1JjH9A`@DDGJ%4%uB<4W_U{$}p>^lYA%zM%_QZ+Jzp$NNj z#!=N^MAcC>AzEc6<=KI3XG;dwk9I)tf8E30R;X;SLCk#;xKLN}9@ zt1H<@5VG-`$~wn`i>7P&=oY7_a1;^iJXgUjF;qm5vzY)mC09qa9gJ8Fc*<5p-^bBZ z=s+noTUX@ub>e&0c_>5w2c~X;o#}Q%96M1fk>(0YCR{=YWhWzq03`0j4m|fWxp!>t z9k#pOgCnzl+U|5a-6!%FEBUgfhK6j8_i6$>U8iMg`6&!-^)cuba6P-~y#}Y~Jn*8( zA?qQrmUC#VW^nF>QA+5*O*NyWH+`1oLS;d_hbO&$4^+C6v-rQ>FdfUq|JvSht=(k4LNtmnbS(uJUw{QIqAw@K&;7exc-ua(Yhfo#qK{g z_ntmAk2_uS=;Vocbnwx^5hv7_69D^Xy#w>G-+$Um5tB|IzVep~SH6zb_L5Sbd;ntnyFOy5pa8*hy=ewM_iUJa3og>?UtqLED z$0n~0yt06+vTtmW-UV&?{x-nw2pWVGHy1k;hjk@(QpmEk%wbmCmNOTtEO=M$mC)QS z^icDgEWN^{u2Pxzk{ny!oz1^9k`L$+(O32d`fnFr-3;YEYdnsC={EW*lbp5zkh7HX zRTPQSrBsP9t0+O}bgSen*Wu?k${%_uT}~58#uOFAq!om|0A`b{V11^2F|YKzm#ehA z=H5Ip*MjuLrsWMV_r}=SVBWf+(gkqo)At+WT9dN0s>?`^NW_XqUUOHrVQ-uolkA-S?-$ClmNa$frZ zi{WGE+u6s=aOL$qvb#ZvL!)`wlBLN&Q4~W_Kr&JcrNW^qDySt**R`TnES5k~%cvm1 z?u6_)fL_cq`x&Uq0si+H!WkoL< z%0gZ>podRN3&w)JsOyXQq5@s!c4t&^+8ERQJ%7&&Zi$x>SY^KD+eQuLJ*2eEy*Smx@MF0%hWsNpHRXOYeN}t~4b{cZm&~4fqr|l4sIY 
zEE<*{uq5|?o{rLNHtYN{(fAI_#-S>?;q&+GpevIw^r+k=wucWdOMhSe{l~k1-TL~c z-+un&k8i(!a`tWt9F2SX`QPqE5%rnlM(5qeN%_B|8;AVgFjQUS|BJwx^S@dw!ky}} zSzpYy%%k%d}cYv!6tHFYZo@*O89_3PL-2QJk z|5{m3;(xiMiuivKI7R%|77xb%vSG+7exjF*Mg7eSh;JF>iFy4H1OMNXj*kDTaoGRM zxV*&v??P~z{olHj?f=q%JgBjB{W!{z{on3DOW@VSf&4$AIB-(_FX9&BQ2xj4BJ%%5 z;H>$7K_9pKTX%h2SXI9jRH`dxJ1fiP+6mh^%m>Hp=0*rlibLYlS|Nl}A9TZI| zm&N(N3&H<5|6SHpS=V)}TP%qKfUh~^iFy4HC!GH)szvRf|Eq>t7U%yi1UFzM2>X%S z?l904bzx4!=nH0-JP-Wj+|4XheOrbg0++==S}oUe38S$LrsqLD(Tkx@aYVMMBj<1A zZ@?YbrhZHv=p|=$fOROCA+b^LH?!~n2ha_C&}0RslL`g?sxTwNqo9W@UD5|XV9<*x zQn)c-%M|FYO+yAa320l0rua669hY^|W;hAjhOfgu1T7>Qc>ze{1vc960AZ=I2|A30 zH|OWcPArrCj?a_SX0&iVt(c5}pOO_kZfj-|u@h5>-*^4?&{7$i`$oGl>O1IVk-ODP z+Qyy}9|~pzs*M6Hn~iE?79Lic8=Lo=@X)N+aj{j`8n9W1mCf2}wOQS);oUmYn6*c6 zyINbF1&Y1K;ibD_MB_NMFBuV>1!z!upxdn=n}#@~w%c-T=9uD*-|KcCMR3Wa<+7O@?< z$=o43q~T`pWT)fW9f%Y2`w&L1k4q}(v9vkY4|{1Ny)N~coU-Nf4sAi&s)cXj>6x4P zWFFHe5mfeVl1;96FMTsTGg3_^j$;jYrwem)Gz{!cT$lyc4`~IBm<4Qx^gLEkm9d&E z!l*D%#91CH2C+>2Z8r-1&kBvZ*2;!itF7I!nsu{UtJZE6J|D~54IJuOjvI||k5?eO zb1|dg(Y?@=@!g2YHl4#IZ$AwKR6GL{>qG%#9-WJ+eKbRPJV7Uj*}kHOo-@=;R+a1p z#EGX5jqb)~quIbA9m~xB%Q0=?R8w}pUSF#jJ?a|VTav^zoje)%O5$eIh14-eJQ5>2 zE6j}4NL7%-ZGmO^q)ROeDitVLR>Ejmg^VtxA8mF9G{2nf8=_!lt6?yUew007bC BW;Xx; literal 0 HcmV?d00001 diff --git a/tests/data/xgboost_abalone/estimator_source_code/estimator_source_code_dummy1.tar.gz b/tests/data/xgboost_abalone/estimator_source_code/estimator_source_code_dummy1.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..c0b86ee614c622c84ec49809350c146c4bf41b79 GIT binary patch literal 1664 zcmV-`27mb7+ugCf)ZTS=*Nq*J$OX89)H_Is3m|dgDGvMrNJv}|5(mV!+&RMR?)VWm zshuWvADLe?`r_T0Z@+nb^Uc_^>sn(OYX){^xB`cU1JCmbo(Gf&=t!O-@B#=@TvjB- z&k4Ms@CgAyV@USwq3(K`LrmE+4ZMx51M{h?`CAKiRWmKvaIxd|-Vdt+ZTXZSB%x#) zvqCbZ@UocTlTjf7b#g^YCX#YWmQzuYhpJ~*@vNZ8ikK9nl)}aNxSSG2QHnn93NVhu z@Vvf5%l~)D|F51sSMuf7sds*9`9B#(J^n9=2?+r|bj#3*_kZ-8chNL)2Uy$T~{&O&uYCfN(|N15Fk+j^CCDOV6f^JtityZgGuBz!3a&a8pKlA&S zSN?kA-LHQ8=Hee8e{yRa{qSPp^Y1@g`}dxn{q)?O{Nsb`G!l$?EgV@s`QLUu(xBvFBvAx8{Y+1%fN_#v1cU>&b6pDiO+;*IG63AryuQQ3Q8`ZK1eKFi0+n-A?uL?1S`oVWh|niT=nf@9zZZxAm2-!xGJ|nqI2DRXKmISG z1NZ+72ICnfGQN;oNTc?K-~s$t3naZo^ZeW@c)&tnEd0XoWz@+BRL)Vk8%nkqU@#br zQL^{^pHZ(?n*vvB9_up>bN)}_<>36c%zpnL36JpmKgr7pJh_)jh)MP%`_m0KG_UW_ z@BE(>mwNnPAs^W9|07|L^M8M%U?CqlJDAGO`M;pVPbA&w$udT-lU4aEox?p9&GqUo zHVn0|0tM(dArth%u?BdI4>k$kMnUJ)T*{)TmHbDhK!|iU^`%XE@;J~Sux4rbU9Ef z70CMKT*zY5B6qRDMW>=uuwv@iB7+#}{!v}<${5s|rV~N$Ov5$OfTnFhoa5oVpCLwX z#V&BLY1fICs@4R{_JFtpQDM43o{RxEbX@a*bON!hMoh7E4BMtx4jx055ZlP@S`+Ld zQPgzAoqN z)ikYA$5IaRZKF+>ls8P*bIjXy|JkJGngua62bo+9(rPY~n}+L|{PNmH9 z#T=|MLX;U zw}y4IXzD~yMYBqEt%RXu-^GsQ+g`&?)pY%SAsuZQP%*2fr+LBWPFJ?tI2w(`V$qg$^%+5lrQ1FWv3bC+!sEh4T84LzvKmG?} K=5xpZOaK7wv0wrK literal 0 HcmV?d00001 diff --git a/tests/data/xgboost_abalone/estimator_source_code/estimator_source_code_dummy2.tar.gz b/tests/data/xgboost_abalone/estimator_source_code/estimator_source_code_dummy2.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..6182c500448f47e5ca536772182ad5f715dac50c GIT binary patch literal 1486 zcmV;<1u^;`iwFSuV7g)e1MQmKZsSB0fSqm;#az&f6+$4Sqf|({s5t(2vQnfvZCABe z+HTVQDMDmjdy-g+f7W)k$#Ow&;F6c%0k}j$T<`!Z@ebU=BOt+yJ!#r(n}*#cyG!S* zMyIi7&YYR?IcJ>AuHD~4wu|DD(~Dyiq^=ipmy9TihNb~$WQ?Rpkwgh3Rn~M>7ZpPS zQ8Elk=b(SthMYw&N$nV$^4N3HAqvi|XKwF(wGsNZ7r?y)#mV^nd?;`%-w>oSv^}>f zl`FcaDTY|)B?Ef+idr_xT1C?;ydr|1dOoU3x~3~-NiSD~vZhIzZm0(Tst3S$DTLDO z`xyHF5&r+<>z}lVe0=SjXX1!4|5s(z;6R(bX;>1`_MGC1`#<;A`)_j#9QTFoz)k}m0)<@T-oZu2c-j14*5Hm+GwbVH*8NSPZ>MQ2IH51t 
zQPf4kZ7;wNmKC(ZsYcP*+~Cjse~|2wcPstnRdSaAe2fay`z+ ze<#|9&OtXkpr)C)2>)bEJ3t15!C){LXHv+;(z~zL3j6qu2FfEOB%Ea^yQ;KdoIv7lmh)?pa4{^{ih)_ z7*~dCxmYIof0H}+{hz^L{EsOtZZtM(-0=&+bC6sMxIChDej7GCXCttXe_{9mH>w9z zu2FfEOSTzcFc^$^a{BzA+w=W_EJTA#?c?I}f5Y?NnxX31@BfC*&i~DY7dZbdD~6zn zvZ9KrTw%Ymzu5>g^ZF4No&Qs1bDS=ZR_|8rrQ^M8ME*+#x_eL9t+^M6^3UvYUj zR%L-+$Gh?mN2hzbc9QlIR(2b-?lWArIPSesJ>Pc7p@Od;{X$L!a6hd4{SlFi z+^4m_lHY$ezy0yg?|!#Bg z)oRT&AK&fupFTC*T1(%JJ9ef1`va-Ol{3dSSctQsAy>=%B zS(M73$)YTR>Dy1kfIOJ1UifjYu7@&vg8zrwG; z9nV1leu$w*&gw$iL10F$4`hHpm$2wPaui5q~`bVCQ*?}D9X#u7Sd z8r>?D?88K`j|X0|Gn=u6rK}>Efgn>>>{45ElWd2WVg~~+XpbxvVDH#yFTw3yFG*wX zphx<~Bew%6nhmfTMW~q#tFa0Xti7$BdwcM}tk?0wW^FcLrw;2o+Z)!NwX=;s-NtIn z?S1&%+TK_Ngh!317xkkUB}rxujM&dPuC9GJ-dyd9X>&E zKx~gt?0X4W7rfAc3ti9mQajB)kCt*c$2?yw7Wr(Di4%KdQE)TJqRH^eVaIbikPy2D z5XD}ApOvte;%Y6gkLSV1hGQFqOcl?&s0GBqZrr@ZlXb~j+$d&!oKTLRzEm#W#0|U_ oG@F6#qh=GTRVX%_q*c&Q?!C;)lzn?4oKLA7k00jp0N&o-= literal 0 HcmV?d00001 diff --git a/tests/integ/sagemaker/workflow/test_training_steps.py b/tests/integ/sagemaker/workflow/test_training_steps.py index 100d234d95..20a11957c9 100644 --- a/tests/integ/sagemaker/workflow/test_training_steps.py +++ b/tests/integ/sagemaker/workflow/test_training_steps.py @@ -13,7 +13,6 @@ from __future__ import absolute_import import os -import re import uuid import logging @@ -28,6 +27,8 @@ ) from sagemaker.estimator import Estimator from sagemaker.pytorch.estimator import PyTorch +from sagemaker.tensorflow import TensorFlow +from sagemaker.utils import sagemaker_timestamp from sagemaker.workflow.functions import Join from sagemaker.workflow.parameters import ParameterInteger, ParameterString from sagemaker.workflow.pipeline import Pipeline @@ -67,11 +68,17 @@ def test_training_job_with_debugger_and_profiler( Rule.sagemaker(rule_configs.loss_not_decreasing()), ] debugger_hook_config = DebuggerHookConfig( - s3_output_path=(f"s3://{sagemaker_session.default_bucket()}/{uuid.uuid4()}/tensors") + s3_output_path=f"s3://{sagemaker_session.default_bucket()}/{uuid.uuid4()}/tensors" ) base_dir = os.path.join(DATA_DIR, "pytorch_mnist") - script_path = os.path.join(base_dir, "mnist.py") + entry_point = "mnist.py" + source_dir = sagemaker_session.upload_data( + path=os.path.join(base_dir, "pytorch_mnist_source_code.tar.gz"), + key_prefix="integ-test-data/pytorch_mnist/training", + ) + entry_point_param = ParameterString(name="EntryPoint") + source_dir_param = ParameterString(name="SourceDir") input_path = sagemaker_session.upload_data( path=os.path.join(base_dir, "training"), key_prefix="integ-test-data/pytorch_mnist/training", @@ -81,7 +88,8 @@ def test_training_job_with_debugger_and_profiler( # If image_uri is not provided, the instance_type should not be a pipeline variable # since instance_type is used to retrieve image_uri in compile time (PySDK) pytorch_estimator = PyTorch( - entry_point=script_path, + entry_point=entry_point_param, + source_dir=source_dir_param, role="SageMakerRole", framework_version=pytorch_training_latest_version, py_version=pytorch_training_latest_py_version, @@ -90,6 +98,9 @@ def test_training_job_with_debugger_and_profiler( sagemaker_session=sagemaker_session, rules=rules, debugger_hook_config=debugger_hook_config, + # TODO: remove base_job_name once we merge + # https://github.com/aws/sagemaker-python-sdk/pull/3158/files + base_job_name="TestJob", ) step_train = TrainingStep( @@ -100,80 +111,74 @@ def test_training_job_with_debugger_and_profiler( pipeline = Pipeline( name=pipeline_name, - parameters=[instance_count], + 
parameters=[instance_count, entry_point_param, source_dir_param], steps=[step_train], sagemaker_session=sagemaker_session, ) - for _ in retries( - max_retry_count=5, - exception_message_prefix="Waiting for a successful execution of pipeline", - seconds_to_sleep=10, - ): - try: - response = pipeline.create(role) - create_arn = response["PipelineArn"] - - execution = pipeline.start() - response = execution.describe() - assert response["PipelineArn"] == create_arn - - try: - execution.wait(delay=10, max_attempts=60) - except WaiterError: - pass - execution_steps = execution.list_steps() - - assert len(execution_steps) == 1 - failure_reason = execution_steps[0].get("FailureReason", "") - if failure_reason != "": - logging.error(f"Pipeline execution failed with error: {failure_reason}.Retrying..") - continue - assert execution_steps[0]["StepName"] == "pytorch-train" - assert execution_steps[0]["StepStatus"] == "Succeeded" - - training_job_arn = execution_steps[0]["Metadata"]["TrainingJob"]["Arn"] - job_description = sagemaker_session.sagemaker_client.describe_training_job( - TrainingJobName=training_job_arn.split("/")[1] - ) + try: + pipeline.create(role) + execution_steps = _start_and_verify_execution_with_retry( + pipeline=pipeline, + parameters={"EntryPoint": entry_point, "SourceDir": source_dir}, + ) + training_job_arn = execution_steps[0]["Metadata"]["TrainingJob"]["Arn"] + job_description = sagemaker_session.sagemaker_client.describe_training_job( + TrainingJobName=training_job_arn.split("/")[1] + ) - for index, rule in enumerate(rules): - config = job_description["DebugRuleConfigurations"][index] - assert config["RuleConfigurationName"] == rule.name - assert config["RuleEvaluatorImage"] == rule.image_uri - assert config["VolumeSizeInGB"] == 0 - assert ( - config["RuleParameters"]["rule_to_invoke"] - == rule.rule_parameters["rule_to_invoke"] - ) - assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict() - - assert job_description["ProfilingStatus"] == "Enabled" - assert job_description["ProfilerConfig"]["ProfilingIntervalInMilliseconds"] == 500 - break - finally: - try: - pipeline.delete() - except Exception: - pass + for index, rule in enumerate(rules): + config = job_description["DebugRuleConfigurations"][index] + assert config["RuleConfigurationName"] == rule.name + assert config["RuleEvaluatorImage"] == rule.image_uri + assert config["VolumeSizeInGB"] == 0 + assert ( + config["RuleParameters"]["rule_to_invoke"] == rule.rule_parameters["rule_to_invoke"] + ) + assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict() + assert job_description["ProfilingStatus"] == "Enabled" + assert job_description["ProfilerConfig"]["ProfilingIntervalInMilliseconds"] == 500 + finally: + try: + pipeline.delete() + except Exception as error: + logging.error(error) def test_training_step_with_output_path_as_join( sagemaker_session, role, tf_full_version, tf_full_py_version, pipeline_name, region_name ): - base_dir = os.path.join(DATA_DIR, "dummy_tensor") input_path = sagemaker_session.upload_data( - path=base_dir, key_prefix="integ-test-data/estimator/training" + path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"), + key_prefix="integ-test-data/xgboost_abalone/abalone", ) - inputs = TrainingInput(s3_data=input_path) + inputs = {"train": TrainingInput(s3_data=input_path)} instance_count = ParameterInteger(name="InstanceCount", default_value=1) instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge") + entry_point1 = 
"dummy1" + entry_point2 = "dummy2" + src_base_dir = os.path.join(DATA_DIR, "xgboost_abalone/estimator_source_code") + source_dir1 = sagemaker_session.upload_data( + path=os.path.join(src_base_dir, "estimator_source_code_dummy1.tar.gz"), + key_prefix="integ-test-data/estimator/training", + ) + source_dir2 = sagemaker_session.upload_data( + path=os.path.join(src_base_dir, "estimator_source_code_dummy2.tar.gz"), + key_prefix="integ-test-data/estimator/training", + ) + entry_point_param = ParameterString(name="EntryPoint") + source_dir_param = ParameterString(name="SourceDir") output_path = Join( on="/", values=["s3:/", f"{sagemaker_session.default_bucket()}", f"{pipeline_name}Train"] ) - - image_uri = image_uris.retrieve("factorization-machines", sagemaker_session.boto_region_name) + image_uri = image_uris.retrieve( + framework="xgboost", + region=sagemaker_session.boto_session.region_name, + version="1.0-1", + py_version="py3", + instance_type="ml.m5.xlarge", + ) estimator = Estimator( image_uri=image_uri, role=role, @@ -181,47 +186,162 @@ def test_training_step_with_output_path_as_join( instance_type=instance_type, sagemaker_session=sagemaker_session, output_path=output_path, + source_dir=source_dir_param, + entry_point=entry_point_param, + # TODO: remove base_job_name once we merge + # https://github.com/aws/sagemaker-python-sdk/pull/3158/files + base_job_name="TestJob", ) estimator.set_hyperparameters( - num_factors=10, feature_dim=784, mini_batch_size=100, predictor_type="binary_classifier" + objective="reg:linear", + num_round=50, + max_depth=5, + eta=0.2, + gamma=4, + min_child_weight=6, + subsample=0.7, ) step_train = TrainingStep( name="MyTrain", estimator=estimator, inputs=inputs, ) - pipeline = Pipeline( name=pipeline_name, - parameters=[instance_count, instance_type], + parameters=[instance_count, instance_type, source_dir_param, entry_point_param], steps=[step_train], sagemaker_session=sagemaker_session, ) try: - response = pipeline.create(role) - create_arn = response["PipelineArn"] - - assert re.match( - rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", - create_arn, + pipeline.create(role) + # execution1 + _start_and_verify_execution_with_retry( + pipeline=pipeline, + parameters={"EntryPoint": entry_point1, "SourceDir": source_dir1}, ) + # execution2 updates parameters to different values + _start_and_verify_execution_with_retry( + pipeline=pipeline, + parameters={"EntryPoint": entry_point2, "SourceDir": source_dir2}, + ) + finally: + try: + pipeline.delete() + except Exception as error: + logging.error(error) - execution = pipeline.start(parameters={}) - assert re.match( - rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/", - execution.arn, + +def test_tensorflow_training_step_with_parameterized_code_input( + pipeline_session, role, tf_full_version, tf_full_py_version, pipeline_name +): + base_dir = os.path.join(DATA_DIR, "tensorflow_mnist") + entry_point1 = "mnist_v2.py" + entry_point2 = "mnist_dummy.py" + source_dir1 = pipeline_session.upload_data( + path=os.path.join(base_dir, "tensorflow_mnist_source_code.tar.gz"), + key_prefix="integ-test-data/tf-scriptmode/mnist/training", + ) + source_dir2 = pipeline_session.upload_data( + path=os.path.join(base_dir, "tensorflow_mnist_source_code_dummy.tar.gz"), + key_prefix="integ-test-data/tf-scriptmode/mnist/training", + ) + entry_point_param = ParameterString(name="EntryPoint") + source_dir_param = ParameterString(name="SourceDir") + input_path = 
pipeline_session.upload_data( + path=os.path.join(base_dir, "data"), + key_prefix="integ-test-data/tf-scriptmode/mnist/training", + ) + inputs = TrainingInput(s3_data=input_path) + instance_count = ParameterInteger(name="InstanceCount", default_value=1) + output_path = ParameterString( + name="OutputPath", default_value=f"s3://{pipeline_session.default_bucket()}" + ) + checkpoint_s3_uri1 = "s3://{}/checkpoints/tf1-{}".format( + pipeline_session.default_bucket(), sagemaker_timestamp() + ) + checkpoint_s3_uri2 = "s3://{}/checkpoints/tf2-{}".format( + pipeline_session.default_bucket(), sagemaker_timestamp() + ) + checkpoint_s3_param = ParameterString(name="CheckpointS3Uri") + + # If image_uri is not provided, the instance_type should not be a pipeline variable + # since instance_type is used to retrieve image_uri in compile time (PySDK) + tensorflow_estimator = TensorFlow( + entry_point=entry_point_param, + source_dir=source_dir_param, + role=role, + instance_count=instance_count, + instance_type="ml.m5.xlarge", + framework_version=tf_full_version, + py_version=tf_full_py_version, + sagemaker_session=pipeline_session, + output_path=output_path, + checkpoint_s3_uri=checkpoint_s3_param, + ) + # TODO: remove job_name once we merge + # https://github.com/aws/sagemaker-python-sdk/pull/3158/files + train_step_args = tensorflow_estimator.fit(inputs=inputs, job_name="TestJob") + step_train = TrainingStep( + name="MyTrain", + step_args=train_step_args, + ) + pipeline = Pipeline( + name=pipeline_name, + parameters=[ + instance_count, + output_path, + entry_point_param, + source_dir_param, + checkpoint_s3_param, + ], + steps=[step_train], + sagemaker_session=pipeline_session, + ) + + try: + pipeline.create(role) + # execution1 + _start_and_verify_execution_with_retry( + pipeline=pipeline, + parameters={ + "EntryPoint": entry_point1, + "SourceDir": source_dir1, + "CheckpointS3Uri": checkpoint_s3_uri1, + }, + ) + # execution2 updates parameters to different values + _start_and_verify_execution_with_retry( + pipeline=pipeline, + parameters={ + "EntryPoint": entry_point2, + "SourceDir": source_dir2, + "CheckpointS3Uri": checkpoint_s3_uri2, + }, ) + finally: + try: + pipeline.delete() + except Exception as error: + logging.error(error) + + +def _start_and_verify_execution_with_retry(pipeline: Pipeline, parameters: dict) -> list: + for _ in retries( + max_retry_count=5, + exception_message_prefix="Waiting for a successful execution of pipeline", + seconds_to_sleep=10, + ): + execution = pipeline.start(parameters=parameters) try: execution.wait(delay=30, max_attempts=60) except WaiterError: pass execution_steps = execution.list_steps() - assert len(execution_steps) == 1 - assert execution_steps[0]["StepName"] == "MyTrain" - finally: - try: - pipeline.delete() - except Exception: - pass + failure_reason = execution_steps[0].get("FailureReason", "") + if failure_reason != "": + logging.error(f"Pipeline execution failed with error: {failure_reason}." 
" Retrying..") + continue + assert execution_steps[0]["StepStatus"] == "Succeeded" + return execution_steps diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index d27359f010..2e7576421f 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -23,6 +23,7 @@ from sagemaker.estimator import _TrainingJob from sagemaker.tensorflow import TensorFlow from sagemaker.instance_group import InstanceGroup +from sagemaker.workflow.parameters import ParameterString, ParameterBoolean from tests.unit import DATA_DIR SCRIPT_FILE = "dummy_script.py" @@ -560,3 +561,74 @@ def test_tf_heterogeneous_cluster_distribution_config( }, ) assert tf.distribution == expected_return + + +def test_insert_invalid_source_code_args(): + with pytest.raises(TypeError) as err: + TensorFlow( + image_uri="IMAGE_URI", + role=ROLE, + entry_point=ParameterString(name="EntryPoint"), + instance_type="ml.m5.xlarge", + instance_count=1, + enable_network_isolation=True, + ) + assert ( + "entry_point, source_dir should not be pipeline variables " + "when enable_network_isolation is a pipeline variable or it is set to True." + ) in str(err.value) + + with pytest.raises(TypeError) as err: + TensorFlow( + image_uri="IMAGE_URI", + role=ROLE, + entry_point="dummy.py", + source_dir=ParameterString(name="SourceDir"), + instance_type="ml.m5.xlarge", + instance_count=1, + enable_network_isolation=ParameterBoolean(name="EnableNetworkIsolation"), + ) + assert ( + "entry_point, source_dir should not be pipeline variables " + "when enable_network_isolation is a pipeline variable or it is set to True." + ) in str(err.value) + + with pytest.raises(TypeError) as err: + TensorFlow( + image_uri="IMAGE_URI", + role=ROLE, + git_config={"repo": "REPO", "branch": "BRANCH", "commit": "COMMIT"}, + source_dir=ParameterString(name="SourceDir"), + entry_point=ParameterString(name="EntryPoint"), + instance_type="ml.m5.xlarge", + instance_count=1, + ) + assert ( + "entry_point, source_dir should not be pipeline variables when git_config is given" + in str(err.value) + ) + + with pytest.raises(TypeError) as err: + TensorFlow( + image_uri="IMAGE_URI", + role=ROLE, + entry_point=ParameterString(name="EntryPoint"), + instance_type="ml.m5.xlarge", + instance_count=1, + ) + assert ( + "The entry_point should not be a pipeline variable " "when source_dir is missing" + ) in str(err.value) + + with pytest.raises(TypeError) as err: + TensorFlow( + image_uri="IMAGE_URI", + role=ROLE, + entry_point=ParameterString(name="EntryPoint"), + source_dir="file://my-file/", + instance_type="ml.m5.xlarge", + instance_count=1, + ) + assert ( + "The entry_point should not be a pipeline variable " "when source_dir is a local path" + ) in str(err.value) diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index 397e65f867..f043048095 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -64,7 +64,6 @@ MODEL_NAME = "gisele" DUMMY_LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py" -DUMMY_S3_SOURCE_DIR = "s3://dummy-s3-source-dir/" INSTANCE_TYPE = "ml.m4.xlarge" ESTIMATOR_LISTS = [ @@ -74,7 +73,8 @@ instance_type=INSTANCE_TYPE, instance_count=1, role=ROLE, - entry_point=DUMMY_LOCAL_SCRIPT_PATH, + entry_point=ParameterString(name="EntryPoint"), + 
source_dir=ParameterString(name="SourceDir"), ), PyTorch( role=ROLE, @@ -82,11 +82,13 @@ instance_count=1, framework_version="1.8.0", py_version="py36", - entry_point=DUMMY_LOCAL_SCRIPT_PATH, + entry_point=ParameterString(name="EntryPoint"), + source_dir=ParameterString(name="SourceDir"), ), TensorFlow( role=ROLE, - entry_point=DUMMY_LOCAL_SCRIPT_PATH, + entry_point=ParameterString(name="EntryPoint"), + source_dir=ParameterString(name="SourceDir"), instance_type=INSTANCE_TYPE, instance_count=1, framework_version="2.0", @@ -99,7 +101,8 @@ instance_type="ml.p3.2xlarge", instance_count=1, py_version="py36", - entry_point=DUMMY_LOCAL_SCRIPT_PATH, + entry_point=ParameterString(name="EntryPoint"), + source_dir=ParameterString(name="SourceDir"), ), XGBoost( framework_version="1.3-1", @@ -107,7 +110,8 @@ role=ROLE, instance_type=INSTANCE_TYPE, instance_count=1, - entry_point=DUMMY_LOCAL_SCRIPT_PATH, + entry_point=ParameterString(name="EntryPoint"), + source_dir=ParameterString(name="SourceDir"), ), MXNet( framework_version="1.4.1", @@ -115,13 +119,15 @@ role=ROLE, instance_type=INSTANCE_TYPE, instance_count=1, - entry_point=DUMMY_LOCAL_SCRIPT_PATH, + entry_point=ParameterString(name="EntryPoint"), + source_dir=ParameterString(name="SourceDir"), toolkit=RLToolkit.RAY, framework=RLFramework.TENSORFLOW, toolkit_version="0.8.5", ), RLEstimator( - entry_point="cartpole.py", + entry_point=ParameterString(name="EntryPoint"), + source_dir=ParameterString(name="SourceDir"), toolkit=RLToolkit.RAY, framework=RLFramework.TENSORFLOW, toolkit_version="0.8.5", @@ -131,7 +137,8 @@ ), Chainer( role=ROLE, - entry_point=DUMMY_LOCAL_SCRIPT_PATH, + entry_point=ParameterString(name="EntryPoint"), + source_dir=ParameterString(name="SourceDir"), use_mpi=True, num_processes=4, framework_version="5.0.0", @@ -217,7 +224,9 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar ) with warnings.catch_warnings(record=True) as w: - step_args = estimator.fit(inputs=training_input) + # TODO: remove job_name once we merge + # https://github.com/aws/sagemaker-python-sdk/pull/3158/files + step_args = estimator.fit(inputs=training_input, job_name="TestJob") assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) assert "Running within a PipelineSession" in str(w[-1].message) @@ -257,6 +266,57 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar ) +def test_training_step_estimator_with_param_code_input( + pipeline_session, training_input, hyperparameters +): + entry_point = ParameterString(name="EntryPoint") + source_dir = ParameterString(name="SourceDir") + estimator = Estimator( + role=ROLE, + instance_count=1, + instance_type=INSTANCE_TYPE, + sagemaker_session=pipeline_session, + image_uri=IMAGE_URI, + hyperparameters=hyperparameters, + entry_point=entry_point, + source_dir=source_dir, + ) + + with warnings.catch_warnings(record=True) as w: + # TODO: remove job_name once we merge + # https://github.com/aws/sagemaker-python-sdk/pull/3158/files + step_args = estimator.fit(inputs=training_input, job_name="TestJob") + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Running within a PipelineSession" in str(w[-1].message) + + with warnings.catch_warnings(record=True) as w: + step = TrainingStep( + name="MyTrainingStep", + step_args=step_args, + description="TrainingStep description", + display_name="MyTrainingStep", + ) + assert len(w) == 0 + + pipeline = Pipeline( + name="MyPipeline", + steps=[step], + 
sagemaker_session=pipeline_session, + ) + step_args.args["HyperParameters"]["sagemaker_program"] = {"Get": "Parameters.EntryPoint"} + step_args.args["HyperParameters"]["sagemaker_submit_directory"] = { + "Get": "Parameters.SourceDir" + } + assert json.loads(pipeline.definition())["Steps"][0] == { + "Name": "MyTrainingStep", + "Description": "TrainingStep description", + "DisplayName": "MyTrainingStep", + "Type": "Training", + "Arguments": step_args.args, + } + + @pytest.mark.parametrize("estimator", ESTIMATOR_LISTS) @pytest.mark.parametrize("training_input", INPUT_PARAM_LISTS) @pytest.mark.parametrize( @@ -265,12 +325,14 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar def test_training_step_with_framework_estimator( estimator, pipeline_session, training_input, output_path, hyperparameters ): - estimator.source_dir = DUMMY_S3_SOURCE_DIR estimator.set_hyperparameters(**hyperparameters) estimator.volume_kms_key = "volume-kms-key" estimator.output_kms_key = "output-kms-key" estimator.dependencies = ["dep-1", "dep-2"] estimator.output_path = output_path + # TODO: remove job_name once we merge + # https://github.com/aws/sagemaker-python-sdk/pull/3158/files + estimator.base_job_name = "TestJob" estimator.sagemaker_session = pipeline_session step_args = estimator.fit(inputs=TrainingInput(s3_data=training_input)) @@ -290,6 +352,8 @@ def test_training_step_with_framework_estimator( assert step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == training_input assert step_args["OutputDataConfig"]["S3OutputPath"] == output_path + step_args["HyperParameters"]["sagemaker_program"] = {"Get": "Parameters.EntryPoint"} + step_args["HyperParameters"]["sagemaker_submit_directory"] = {"Get": "Parameters.SourceDir"} del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] @@ -333,6 +397,11 @@ def test_training_step_with_algorithm_base(algo_estimator, training_input, pipel instance_type=INSTANCE_TYPE, instance_count=1, sagemaker_session=pipeline_session, + entry_point=ParameterString(name="EntryPoint"), + source_dir=ParameterString(name="SourceDir"), + # TODO: remove job_name once we merge + # https://github.com/aws/sagemaker-python-sdk/pull/3158/files + base_job_name="TestJob", ) data = RecordSet( s3_data=training_input, @@ -367,6 +436,8 @@ def test_training_step_with_algorithm_base(algo_estimator, training_input, pipel step_def = json.loads(pipeline.definition())["Steps"][0] assert step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == training_input + step_args["HyperParameters"]["sagemaker_program"] = {"Get": "Parameters.EntryPoint"} + step_args["HyperParameters"]["sagemaker_submit_directory"] = {"Get": "Parameters.SourceDir"} del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 859cdb941f..5906dcb0a1 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -52,6 +52,7 @@ from sagemaker.tensorflow.estimator import TensorFlow from sagemaker.predictor_async import AsyncPredictor from sagemaker.transformer import Transformer +from sagemaker.workflow.parameters import ParameterString, ParameterBoolean from sagemaker.workflow.pipeline_context import PipelineSession from sagemaker.xgboost.estimator import XGBoost 
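
The unit tests above pin down how parameterized code inputs surface in a pipeline definition: the `sagemaker_program` and `sagemaker_submit_directory` hyperparameters become `{"Get": "Parameters.EntryPoint"}` and `{"Get": "Parameters.SourceDir"}`. A short caller-side sketch of the same pattern follows; it is illustrative only (not part of the patch), and the image URI, role, bucket paths, and names are placeholders.

    from sagemaker.estimator import Estimator
    from sagemaker.inputs import TrainingInput
    from sagemaker.workflow.parameters import ParameterString
    from sagemaker.workflow.pipeline import Pipeline
    from sagemaker.workflow.pipeline_context import PipelineSession
    from sagemaker.workflow.steps import TrainingStep

    pipeline_session = PipelineSession()
    entry_point = ParameterString(name="EntryPoint")
    source_dir = ParameterString(name="SourceDir")

    estimator = Estimator(
        image_uri="<training-image-uri>",    # placeholder
        role="<execution-role-arn>",         # placeholder
        instance_count=1,
        instance_type="ml.m5.xlarge",
        entry_point=entry_point,             # resolved per execution
        source_dir=source_dir,               # resolved per execution
        sagemaker_session=pipeline_session,
    )
    step_args = estimator.fit(inputs=TrainingInput(s3_data="s3://<bucket>/train"))

    pipeline = Pipeline(
        name="ParameterizedCodePipeline",
        parameters=[entry_point, source_dir],
        steps=[TrainingStep(name="Train", step_args=step_args)],
        sagemaker_session=pipeline_session,
    )
    # After pipeline.create(role), each execution can point at different code:
    # pipeline.start(parameters={"EntryPoint": "train_v1.py", "SourceDir": "s3://<bucket>/code_v1.tar.gz"})
    # pipeline.start(parameters={"EntryPoint": "train_v2.py", "SourceDir": "s3://<bucket>/code_v2.tar.gz"})
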
@@ -4214,3 +4215,74 @@ def test_all_framework_estimators_add_jumpstart_base_name( sagemaker_session.endpoint_from_production_variants.reset_mock() sagemaker_session.create_model.reset_mock() sagemaker_session.train.reset_mock() + + +def test_insert_invalid_source_code_args(): + with pytest.raises(TypeError) as err: + Estimator( + image_uri="IMAGE_URI", + role=ROLE, + entry_point=ParameterString(name="EntryPoint"), + instance_type="ml.m5.xlarge", + instance_count=1, + enable_network_isolation=True, + ) + assert ( + "entry_point, source_dir should not be pipeline variables " + "when enable_network_isolation is a pipeline variable or it is set to True." + ) in str(err.value) + + with pytest.raises(TypeError) as err: + Estimator( + image_uri="IMAGE_URI", + role=ROLE, + entry_point="dummy.py", + source_dir=ParameterString(name="SourceDir"), + instance_type="ml.m5.xlarge", + instance_count=1, + enable_network_isolation=ParameterBoolean(name="EnableNetworkIsolation"), + ) + assert ( + "entry_point, source_dir should not be pipeline variables " + "when enable_network_isolation is a pipeline variable or it is set to True." + ) in str(err.value) + + with pytest.raises(TypeError) as err: + Estimator( + image_uri=IMAGE_URI, + role=ROLE, + git_config={"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT}, + source_dir=ParameterString(name="SourceDir"), + entry_point=ParameterString(name="EntryPoint"), + instance_type="ml.m5.xlarge", + instance_count=1, + ) + assert ( + "entry_point, source_dir should not be pipeline variables when git_config is given" + in str(err.value) + ) + + with pytest.raises(TypeError) as err: + Estimator( + image_uri=IMAGE_URI, + role=ROLE, + entry_point=ParameterString(name="EntryPoint"), + instance_type="ml.m5.xlarge", + instance_count=1, + ) + assert "The entry_point should not be a pipeline variable when source_dir is missing" in str( + err.value + ) + + with pytest.raises(TypeError) as err: + Estimator( + image_uri="IMAGE_URI", + role=ROLE, + entry_point=ParameterString(name="EntryPoint"), + source_dir="file://my-file/", + instance_type="ml.m5.xlarge", + instance_count=1, + ) + assert ( + "The entry_point should not be a pipeline variable " "when source_dir is a local path" + ) in str(err.value) From 8f78094e23e93706b7aec5333c7ad773fcbfdd11 Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Tue, 26 Jul 2022 22:11:02 -0700 Subject: [PATCH 135/526] change: Make repack step output path align with model repack path (#3257) Co-authored-by: Dewen Qi --- src/sagemaker/image_uris.py | 10 ++++- src/sagemaker/model.py | 6 ++- src/sagemaker/rl/estimator.py | 5 +++ src/sagemaker/tensorflow/model.py | 16 ++++--- src/sagemaker/workflow/model_step.py | 4 +- src/sagemaker/workflow/pipeline_context.py | 7 +-- src/sagemaker/workflow/utilities.py | 12 +++--- .../sagemaker/workflow/test_model_steps.py | 1 + .../sagemaker/image_uris/test_retrieve.py | 2 +- .../sagemaker/workflow/test_model_step.py | 43 ++++++++++++++++--- 10 files changed, 79 insertions(+), 27 deletions(-) diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index ec1fec2d20..01ed5f1d99 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -32,7 +32,6 @@ HUGGING_FACE_FRAMEWORK = "huggingface" -# TODO: we should remove this decorator later @override_pipeline_parameter_var def retrieve( framework, @@ -117,7 +116,11 @@ def retrieve( args = dict(locals()) for name, val in args.items(): if is_pipeline_variable(val): - raise ValueError("%s should 
not be a pipeline variable (%s)" % (name, type(val))) + raise ValueError( + "When retrieving the image_uri, the argument %s should not be a pipeline variable " + "(%s) since pipeline variables are only interpreted in the pipeline execution time." + % (name, type(val)) + ) if is_jumpstart_model_input(model_id, model_version): return artifacts._retrieve_image_uri( @@ -487,6 +490,9 @@ def get_training_image_uri( if image_uri: return image_uri + logger.info( + "image_uri is not presented, retrieving image_uri based on instance_type, framework etc." + ) base_framework_version: Optional[str] = None if tensorflow_version is not None or pytorch_version is not None: diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 704e3385fd..a2c6da4bb7 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -527,10 +527,10 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: artifact should be repackaged into a new S3 object. (default: False). """ local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config) + bucket = self.bucket or self.sagemaker_session.default_bucket() if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None: self.uploaded_code = None elif not repack: - bucket = self.bucket or self.sagemaker_session.default_bucket() self.uploaded_code = fw_utils.tar_and_upload_dir( session=self.sagemaker_session.boto_session, bucket=bucket, @@ -557,6 +557,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: ) return self.sagemaker_session.context.need_runtime_repack.add(id(self)) + self.sagemaker_session.context.runtime_repack_output_prefix = "s3://{}/{}".format( + bucket, key_prefix + ) # Add the uploaded_code and repacked_model_data to update the container env self.repacked_model_data = self.model_data self.uploaded_code = fw_utils.UploadedCode( @@ -567,7 +570,6 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: if local_code and self.model_data.startswith("file://"): repacked_model_data = self.model_data else: - bucket = self.bucket or self.sagemaker_session.default_bucket() repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"]) self.uploaded_code = fw_utils.UploadedCode( s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point) diff --git a/src/sagemaker/rl/estimator.py b/src/sagemaker/rl/estimator.py index 8d6a00b68e..b004dd87b8 100644 --- a/src/sagemaker/rl/estimator.py +++ b/src/sagemaker/rl/estimator.py @@ -282,6 +282,11 @@ def training_image_uri(self): """ if self.image_uri: return self.image_uri + + logger.info( + "image_uri is not presented, retrieving image_uri based on instance_type, " + "framework etc." 
+ ) return image_uris.retrieve( self._image_framework(), self.sagemaker_session.boto_region_name, diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index c910e85f20..401ae04b23 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -24,6 +24,8 @@ from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.pipeline_context import PipelineSession +logger = logging.getLogger(__name__) + class TensorFlowPredictor(Predictor): """A ``Predictor`` implementation for inference against TensorFlow Serving endpoints.""" @@ -363,13 +365,10 @@ def prepare_container_def( instance_type, accelerator_type, serverless_inference_config=serverless_inference_config ) env = self._get_container_env() + key_prefix = sagemaker.fw_utils.model_code_key_prefix(self.key_prefix, self.name, image_uri) + bucket = self.bucket or self.sagemaker_session.default_bucket() if self.entry_point and not is_pipeline_variable(self.model_data): - key_prefix = sagemaker.fw_utils.model_code_key_prefix( - self.key_prefix, self.name, image_uri - ) - - bucket = self.bucket or self.sagemaker_session.default_bucket() model_data = s3.s3_path_join("s3://", bucket, key_prefix, "model.tar.gz") sagemaker.utils.repack_model( @@ -385,6 +384,9 @@ def prepare_container_def( # model is not yet there, defer repacking to later during pipeline execution if isinstance(self.sagemaker_session, PipelineSession): self.sagemaker_session.context.need_runtime_repack.add(id(self)) + self.sagemaker_session.context.runtime_repack_output_prefix = "s3://{}/{}".format( + bucket, key_prefix + ) else: logging.warning( "The model_data is a Pipeline variable of type %s, " @@ -426,6 +428,10 @@ def _get_image_uri( if self.image_uri: return self.image_uri + logger.info( + "image_uri is not presented, retrieving image_uri based on instance_type, " + "framework etc." 
+ ) return image_uris.retrieve( self._framework_name, region_name or self.sagemaker_session.boto_region_name, diff --git a/src/sagemaker/workflow/model_step.py b/src/sagemaker/workflow/model_step.py index e46fd71a84..6c261d1bdc 100644 --- a/src/sagemaker/workflow/model_step.py +++ b/src/sagemaker/workflow/model_step.py @@ -23,8 +23,6 @@ from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.steps import Step, CreateModelStep -NEED_RUNTIME_REPACK = "need_runtime_repack" - _CREATE_MODEL_RETRY_POLICIES = "create_model_retry_policies" _REGISTER_MODEL_RETRY_POLICIES = "register_model_retry_policies" _REPACK_MODEL_RETRY_POLICIES = "repack_model_retry_policies" @@ -155,6 +153,7 @@ def __init__( self._create_model_args = self.step_args.create_model_request self._register_model_args = self.step_args.create_model_package_request self._need_runtime_repack = self.step_args.need_runtime_repack + self._runtime_repack_output_prefix = self.step_args.runtime_repack_output_prefix self._assign_and_validate_retry_policies(retry_policies) if self._need_runtime_repack: @@ -268,6 +267,7 @@ def _append_repack_model_step(self): ), depends_on=self.depends_on, retry_policies=self._repack_model_retry_policies, + output_path=self._runtime_repack_output_prefix, ) self.steps.append(repack_model_step) diff --git a/src/sagemaker/workflow/pipeline_context.py b/src/sagemaker/workflow/pipeline_context.py index 95fbd9371c..341e123be0 100644 --- a/src/sagemaker/workflow/pipeline_context.py +++ b/src/sagemaker/workflow/pipeline_context.py @@ -67,6 +67,7 @@ def __init__(self, model): self.create_model_package_request = None self.create_model_request = None self.need_runtime_repack = set() + self.runtime_repack_output_prefix = None class PipelineSession(Session): @@ -139,14 +140,14 @@ def _intercept_create_request(self, request: Dict, create, func_name: str = None else: self.context = _JobStepArguments(func_name, request) - def init_step_arguments(self, model): + def init_model_step_arguments(self, model): """Create a `_ModelStepArguments` (if not exist) as pipeline context Args: model (Model or PipelineModel): A `sagemaker.model.Model` or `sagemaker.pipeline.PipelineModel` instance """ - if not self._context or not isinstance(self._context, _ModelStepArguments): + if not isinstance(self._context, _ModelStepArguments): self._context = _ModelStepArguments(model) @@ -197,7 +198,7 @@ def wrapper(*args, **kwargs): UserWarning, ) if run_func.__name__ in ["register", "create"]: - self_instance.sagemaker_session.init_step_arguments(self_instance) + self_instance.sagemaker_session.init_model_step_arguments(self_instance) run_func(*args, **kwargs) context = self_instance.sagemaker_session.context self_instance.sagemaker_session.context = None diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index afe1e4eae1..a30ddd4dee 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -29,6 +29,8 @@ RequestType, ) +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from sagemaker.workflow.step_collections import StepCollection @@ -173,26 +175,26 @@ def override_pipeline_parameter_var(func): We should remove this decorator after the grace period. """ warning_msg_template = ( - "%s should not be a pipeline variable (%s). " - "The default_value of this Parameter object will be used to override it. " - "Please remove this pipeline variable and use python primitives instead." 
+ "The input argument %s of function (%s) is a pipeline variable (%s), which is not allowed. " + "The default_value of this Parameter object will be used to override it." ) @wraps(func) def wrapper(*args, **kwargs): + func_name = "{}.{}".format(func.__module__, func.__name__) params = inspect.signature(func).parameters args = list(args) for i, (arg_name, _) in enumerate(params.items()): if i >= len(args): break if isinstance(args[i], Parameter): - logging.warning(warning_msg_template, arg_name, type(args[i])) + logger.warning(warning_msg_template, arg_name, func_name, type(args[i])) args[i] = args[i].default_value args = tuple(args) for arg_name, value in kwargs.items(): if isinstance(value, Parameter): - logging.warning(warning_msg_template, arg_name, type(value)) + logger.warning(warning_msg_template, arg_name, func_name, type(value)) kwargs[arg_name] = value.default_value return func(*args, **kwargs) diff --git a/tests/integ/sagemaker/workflow/test_model_steps.py b/tests/integ/sagemaker/workflow/test_model_steps.py index d5e21be1bf..31c518b100 100644 --- a/tests/integ/sagemaker/workflow/test_model_steps.py +++ b/tests/integ/sagemaker/workflow/test_model_steps.py @@ -836,6 +836,7 @@ def test_tensorflow_model_register_and_deploy_with_runtime_repack( sagemaker_session=pipeline_session, entry_point=os.path.join(_TENSORFLOW_PATH, "inference.py"), dependencies=[os.path.join(_TENSORFLOW_PATH, "dependency.py")], + code_location=f"s3://{pipeline_session.default_bucket()}/model-code", ) step_args = tf_model.register( content_types=["application/json"], diff --git a/tests/unit/sagemaker/image_uris/test_retrieve.py b/tests/unit/sagemaker/image_uris/test_retrieve.py index c167da6f47..ae37395b92 100644 --- a/tests/unit/sagemaker/image_uris/test_retrieve.py +++ b/tests/unit/sagemaker/image_uris/test_retrieve.py @@ -754,7 +754,7 @@ def test_retrieve_with_pipeline_variable(): kwargs["instance_type"] = Join(on="", values=["a", "b"]) with pytest.raises(Exception) as error: image_uris.retrieve(**kwargs) - assert "instance_type should not be a pipeline variable" in str(error.value) + assert "the argument instance_type should not be a pipeline variable" in str(error.value) # instance_type (ParameterString) is given as args rather than kwargs # which should not break anything diff --git a/tests/unit/sagemaker/workflow/test_model_step.py b/tests/unit/sagemaker/workflow/test_model_step.py index 68961b355c..cfeb8d5a03 100644 --- a/tests/unit/sagemaker/workflow/test_model_step.py +++ b/tests/unit/sagemaker/workflow/test_model_step.py @@ -67,6 +67,8 @@ _DIR_NAME = "/opt/ml/model/code" _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone") _TENSORFLOW_PATH = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies") +_REPACK_OUTPUT_KEY_PREFIX = "code-output" +_MODEL_CODE_LOCATION = f"s3://{_BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}" @pytest.fixture @@ -688,6 +690,7 @@ def test_conditional_model_create_and_regis( entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", role=_ROLE, enable_network_isolation=True, + code_location=_MODEL_CODE_LOCATION, ), 2, ), @@ -711,6 +714,7 @@ def test_conditional_model_create_and_regis( entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", role=_ROLE, framework_version="1.5.0", + code_location=_MODEL_CODE_LOCATION, ), 2, ), @@ -742,6 +746,7 @@ def test_conditional_model_create_and_regis( image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", role=_ROLE, + code_location=_MODEL_CODE_LOCATION, ), 2, ), @@ -758,21 +763,45 @@ def test_conditional_model_create_and_regis( ], ) def 
test_create_model_among_different_model_types(test_input, pipeline_session, model_data_param): + def assert_test_result(steps: list): + # If expected_step_num is 2, it means a runtime repack step is appended + # If expected_step_num is 1, it means no runtime repack is needed + assert len(steps) == expected_step_num + if expected_step_num == 2: + assert steps[0]["Type"] == "Training" + if model.key_prefix == _REPACK_OUTPUT_KEY_PREFIX: + assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == ( + f"{_MODEL_CODE_LOCATION}/{model.name}" + ) + else: + assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == ( + f"s3://{_BUCKET}/{model.name}" + ) + model, expected_step_num = test_input model.sagemaker_session = pipeline_session model.model_data = model_data_param - step_args = model.create( + create_model_step_args = model.create( instance_type="c4.4xlarge", ) - model_steps = ModelStep( + create_model_steps = ModelStep( name="MyModelStep", - step_args=step_args, + step_args=create_model_step_args, ) - steps = model_steps.request_dicts() + assert_test_result(create_model_steps.request_dicts()) - # If expected_step_num is 2, it means a runtime repack step is appended - # If expected_step_num is 1, it means no runtime repack is needed - assert len(steps) == expected_step_num + register_model_step_args = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_group_name="MyModelPackageGroup", + ) + register_model_steps = ModelStep( + name="MyModelStep", + step_args=register_model_step_args, + ) + assert_test_result(register_model_steps.request_dicts()) @pytest.mark.parametrize( From 75c485494247d0e024f5c0a887672c78aabf03ad Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Wed, 27 Jul 2022 14:04:47 -0400 Subject: [PATCH 136/526] feature: enhance-bucket-override-support (#3235) --- src/sagemaker/jumpstart/accessors.py | 19 +++ src/sagemaker/jumpstart/artifacts.py | 15 +- src/sagemaker/jumpstart/cache.py | 77 ++++++++-- src/sagemaker/jumpstart/constants.py | 6 + src/sagemaker/jumpstart/notebook_utils.py | 2 +- src/sagemaker/jumpstart/types.py | 4 +- .../sagemaker/jumpstart/test_accessors.py | 12 +- tests/unit/sagemaker/jumpstart/test_cache.py | 137 +++++++++++++++++- .../jumpstart/test_notebook_utils.py | 30 ++-- tests/unit/sagemaker/jumpstart/utils.py | 2 +- .../model_uris/jumpstart/test_common.py | 23 +++ .../script_uris/jumpstart/test_common.py | 26 ++++ 12 files changed, 313 insertions(+), 40 deletions(-) diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index fbdc0f5b56..e07564d362 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -13,6 +13,8 @@ """This module contains accessors related to SageMaker JumpStart.""" from __future__ import absolute_import from typing import Any, Dict, List, Optional + +from sagemaker.deprecations import deprecated from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs from sagemaker.jumpstart import cache from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME @@ -78,6 +80,22 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None: ) JumpStartModelsAccessor._curr_region = region + @staticmethod + def _get_manifest(region: str = JUMPSTART_DEFAULT_REGION_NAME) -> List[JumpStartModelHeader]: + """Return entire JumpStart models manifest. 
+ + Raises: + ValueError: If region in `cache_kwargs` is inconsistent with `region` argument. + + Args: + region (str): Optional. The region to use for the cache. + """ + cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs( + JumpStartModelsAccessor._cache_kwargs, region + ) + JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs) + return JumpStartModelsAccessor._cache.get_manifest() # type: ignore + @staticmethod def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader: """Returns model header from JumpStart models cache. @@ -152,6 +170,7 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region) @staticmethod + @deprecated() def get_manifest( cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None ) -> List[JumpStartModelHeader]: diff --git a/src/sagemaker/jumpstart/artifacts.py b/src/sagemaker/jumpstart/artifacts.py index a61f46702f..cf63a46a7b 100644 --- a/src/sagemaker/jumpstart/artifacts.py +++ b/src/sagemaker/jumpstart/artifacts.py @@ -12,9 +12,12 @@ # language governing permissions and limitations under the License. """This module contains functions for obtaining JumpStart ECR and S3 URIs.""" from __future__ import absolute_import +import os from typing import Dict, Optional from sagemaker import image_uris from sagemaker.jumpstart.constants import ( + ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE, + ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE, JUMPSTART_DEFAULT_REGION_NAME, ) from sagemaker.jumpstart.enums import ( @@ -176,6 +179,8 @@ def _retrieve_model_uri( ): """Retrieves the model artifact S3 URI for the model matching the given arguments. + Optionally uses a bucket override specified by environment variable. + Args: model_id (str): JumpStart model ID of the JumpStart model for which to retrieve the model artifact S3 URI. @@ -217,7 +222,9 @@ def _retrieve_model_uri( elif model_scope == JumpStartScriptScope.TRAINING: model_artifact_key = model_specs.training_artifact_key - bucket = get_jumpstart_content_bucket(region) + bucket = os.environ.get( + ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE + ) or get_jumpstart_content_bucket(region) model_s3_uri = f"s3://{bucket}/{model_artifact_key}" @@ -234,6 +241,8 @@ def _retrieve_script_uri( ): """Retrieves the script S3 URI associated with the model matching the given arguments. + Optionally uses a bucket override specified by environment variable. + Args: model_id (str): JumpStart model ID of the JumpStart model for which to retrieve the script S3 URI. 
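
Both artifact retrieval paths in this file now honor a bucket override taken from the environment. A usage sketch of the intended effect follows; it is illustrative only (not part of the patch), it assumes the public `sagemaker.model_uris.retrieve` wrapper delegates to the `_retrieve_model_uri` function above, and the bucket name and model_id are placeholders.

    import os
    from sagemaker import model_uris

    # Point model artifact retrieval at a private mirror of the JumpStart content bucket.
    os.environ["AWS_JUMPSTART_MODEL_BUCKET_OVERRIDE"] = "my-private-jumpstart-mirror"

    uri = model_uris.retrieve(
        region="us-west-2",
        model_id="pytorch-ic-mobilenet-v2",   # example model_id
        model_version="*",
        model_scope="inference",
    )
    # uri now points at "s3://my-private-jumpstart-mirror/..." instead of the regional
    # JumpStart content bucket; AWS_JUMPSTART_SCRIPT_BUCKET_OVERRIDE plays the same role
    # for script_uris.retrieve.
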
@@ -275,7 +284,9 @@ def _retrieve_script_uri( elif script_scope == JumpStartScriptScope.TRAINING: model_script_key = model_specs.training_script_key - bucket = get_jumpstart_content_bucket(region) + bucket = os.environ.get( + ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE + ) or get_jumpstart_content_bucket(region) script_s3_uri = f"s3://{bucket}/{model_script_key}" diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index ac1ed5a17f..202edff9ad 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -14,13 +14,16 @@ from __future__ import absolute_import import datetime from difflib import get_close_matches -from typing import List, Optional +import os +from typing import List, Optional, Tuple, Union import json import boto3 import botocore from packaging.version import Version from packaging.specifiers import SpecifierSet from sagemaker.jumpstart.constants import ( + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JUMPSTART_DEFAULT_REGION_NAME, ) @@ -90,7 +93,7 @@ def __init__( self._s3_cache = LRUCache[JumpStartCachedS3ContentKey, JumpStartCachedS3ContentValue]( max_cache_items=max_s3_cache_items, expiration_horizon=s3_cache_expiration_horizon, - retrieval_function=self._get_file_from_s3, + retrieval_function=self._retrieval_function, ) self._model_id_semantic_version_manifest_key_cache = LRUCache[ JumpStartVersionedModelId, JumpStartVersionedModelId @@ -235,7 +238,64 @@ def _get_manifest_key_from_model_id_semantic_version( raise KeyError(error_msg) - def _get_file_from_s3( + def _get_json_file_and_etag_from_s3(self, key: str) -> Tuple[Union[dict, list], str]: + """Returns json file from s3, along with its etag.""" + response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=key) + return json.loads(response["Body"].read().decode("utf-8")), response["ETag"] + + def _is_local_metadata_mode(self) -> bool: + """Returns True if the cache should use local metadata mode, based off env variables.""" + return (ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE in os.environ + and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE]) + and ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE in os.environ + and os.path.isdir(os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE])) + + def _get_json_file( + self, + key: str, + filetype: JumpStartS3FileType + ) -> Tuple[Union[dict, list], Optional[str]]: + """Returns json file either from s3 or local file system. + + Returns etag along with json object for s3, or just the json + object and None when reading from the local file system. + """ + if self._is_local_metadata_mode(): + file_content, etag = self._get_json_file_from_local_override(key, filetype), None + else: + file_content, etag = self._get_json_file_and_etag_from_s3(key) + return file_content, etag + + def _get_json_md5_hash(self, key: str): + """Retrieves md5 object hash for s3 objects, using `s3.head_object`. + + Raises: + ValueError: if the cache should use local metadata mode. 
+ """ + if self._is_local_metadata_mode(): + raise ValueError("Cannot get md5 hash of local file.") + return self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=key)["ETag"] + + def _get_json_file_from_local_override( + self, + key: str, + filetype: JumpStartS3FileType + ) -> Union[dict, list]: + """Reads json file from local filesystem and returns data.""" + if filetype == JumpStartS3FileType.MANIFEST: + metadata_local_root = ( + os.environ[ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE] + ) + elif filetype == JumpStartS3FileType.SPECS: + metadata_local_root = os.environ[ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE] + else: + raise ValueError(f"Unsupported file type for local override: {filetype}") + file_path = os.path.join(metadata_local_root, key) + with open(file_path, 'r') as f: + data = json.load(f) + return data + + def _retrieval_function( self, key: JumpStartCachedS3ContentKey, value: Optional[JumpStartCachedS3ContentValue], @@ -256,20 +316,17 @@ def _get_file_from_s3( file_type, s3_key = key.file_type, key.s3_key if file_type == JumpStartS3FileType.MANIFEST: - if value is not None: - etag = self._s3_client.head_object(Bucket=self.s3_bucket_name, Key=s3_key)["ETag"] + if value is not None and not self._is_local_metadata_mode(): + etag = self._get_json_md5_hash(s3_key) if etag == value.md5_hash: return value - response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) - formatted_body = json.loads(response["Body"].read().decode("utf-8")) - etag = response["ETag"] + formatted_body, etag = self._get_json_file(s3_key, file_type) return JumpStartCachedS3ContentValue( formatted_content=utils.get_formatted_manifest(formatted_body), md5_hash=etag, ) if file_type == JumpStartS3FileType.SPECS: - response = self._s3_client.get_object(Bucket=self.s3_bucket_name, Key=s3_key) - formatted_body = json.loads(response["Body"].read().decode("utf-8")) + formatted_body, _ = self._get_json_file(s3_key, file_type) return JumpStartCachedS3ContentValue( formatted_content=JumpStartModelSpecs(formatted_body) ) diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 2b0fb4ee12..7736487359 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -124,5 +124,11 @@ SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope) ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE = "AWS_JUMPSTART_CONTENT_BUCKET_OVERRIDE" +ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_MODEL_BUCKET_OVERRIDE" +ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE = "AWS_JUMPSTART_SCRIPT_BUCKET_OVERRIDE" +ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE = ( + "AWS_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE" +) +ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE = "AWS_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE" JUMPSTART_RESOURCE_BASE_NAME = "sagemaker-jumpstart" diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 09e812ee4d..773ea9df41 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -284,7 +284,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin if isinstance(filter, str): filter = Identity(filter) - models_manifest_list = accessors.JumpStartModelsAccessor.get_manifest(region=region) + models_manifest_list = accessors.JumpStartModelsAccessor._get_manifest(region=region) manifest_keys = set(models_manifest_list[0].__slots__) 
all_keys: Set[str] = set() diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 36604fccdc..5fd9b319f9 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -65,7 +65,7 @@ def __str__(self) -> str: {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" """ - att_dict = {att: getattr(self, att) for att in self.__slots__} + att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} return f"{type(self).__name__}: {str(att_dict)}" def __repr__(self) -> str: @@ -75,7 +75,7 @@ def __repr__(self) -> str: {'content_bucket': 'bucket', 'region_name': 'us-west-2'}" """ - att_dict = {att: getattr(self, att) for att in self.__slots__} + att_dict = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)} return f"{type(self).__name__} at {hex(id(self))}: {str(att_dict)}" diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index b8ba98bf9c..2de0351103 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -16,6 +16,7 @@ import pytest from sagemaker.jumpstart import accessors +from tests.unit.sagemaker.jumpstart.constants import BASE_MANIFEST from tests.unit.sagemaker.jumpstart.utils import ( get_header_from_base_header, get_spec_from_base_spec, @@ -36,9 +37,12 @@ def test_jumpstart_sagemaker_settings(): reload(accessors) -@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_header", get_header_from_base_header) -@patch("sagemaker.jumpstart.cache.JumpStartModelsCache.get_specs", get_spec_from_base_spec) -def test_jumpstart_models_cache_get_fxs(): +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") +def test_jumpstart_models_cache_get_fxs(mock_cache): + + mock_cache.get_manifest = Mock(return_value=BASE_MANIFEST) + mock_cache.get_header = Mock(side_effect=get_header_from_base_header) + mock_cache.get_specs = Mock(side_effect=get_spec_from_base_spec) assert get_header_from_base_header( region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" @@ -51,7 +55,7 @@ def test_jumpstart_models_cache_get_fxs(): region="us-west-2", model_id="pytorch-ic-mobilenet-v2", version="*" ) - assert len(accessors.JumpStartModelsAccessor.get_manifest()) > 0 + assert len(accessors.JumpStartModelsAccessor._get_manifest()) > 0 # necessary because accessors is a static module reload(accessors) diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py index f87820114d..58a8e34d25 100644 --- a/tests/unit/sagemaker/jumpstart/test_cache.py +++ b/tests/unit/sagemaker/jumpstart/test_cache.py @@ -15,6 +15,7 @@ import datetime import io import json +from unittest.mock import Mock, call, mock_open from botocore.stub import Stubber import botocore @@ -23,13 +24,18 @@ from mock import patch from sagemaker.jumpstart.cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, JumpStartModelsCache +from sagemaker.jumpstart.constants import ( + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, +) from sagemaker.jumpstart.types import ( JumpStartModelHeader, + JumpStartModelSpecs, JumpStartVersionedModelId, ) from tests.unit.sagemaker.jumpstart.utils import ( get_spec_from_base_spec, - patched_get_file_from_s3, + patched_retrieval_function, ) from tests.unit.sagemaker.jumpstart.constants import ( @@ -38,7 +44,7 @@ ) -@patch.object(JumpStartModelsCache, "_get_file_from_s3", 
patched_get_file_from_s3) +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_header(): @@ -582,7 +588,7 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client): mock_boto3_client.return_value.head_object.assert_not_called() -@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") @@ -625,7 +631,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache(): cache.clear.assert_called_once() -@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_get_full_manifest(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") @@ -634,7 +640,7 @@ def test_jumpstart_get_full_manifest(): raw_manifest == BASE_MANIFEST -@patch.object(JumpStartModelsCache, "_get_file_from_s3", patched_get_file_from_s3) +@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") def test_jumpstart_cache_get_specs(): cache = JumpStartModelsCache(s3_bucket_name="some_bucket") @@ -690,3 +696,124 @@ def test_jumpstart_cache_get_specs(): model_id=model_id, semantic_version_str="5.*", ) + + +@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +@patch.dict( + "sagemaker.jumpstart.cache.os.environ", + { + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root", + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root", + }, +) +@patch("sagemaker.jumpstart.cache.os.path.isdir") +@patch("builtins.open") +def test_jumpstart_local_metadata_override_header( + mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock +): + mocked_open.side_effect = mock_open(read_data=json.dumps(BASE_MANIFEST)) + mocked_is_dir.return_value = True + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" + assert JumpStartModelHeader( + { + "model_id": "tensorflow-ic-imagenet-inception-v3-classification-4", + "version": "2.0.0", + "min_version": "2.49.0", + "spec_key": "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json", + } + ) == cache.get_header(model_id=model_id, semantic_version_str=version) + + mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") + mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") + assert mocked_is_dir.call_count == 2 + mocked_open.assert_called_once_with( + "/some/directory/metadata/manifest/root/models_manifest.json", "r" + ) + mocked_get_json_file_and_etag_from_s3.assert_not_called() + + +@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +@patch.dict( + "sagemaker.jumpstart.cache.os.environ", + { + 
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root", + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root", + }, +) +@patch("sagemaker.jumpstart.cache.os.path.isdir") +@patch("builtins.open") +def test_jumpstart_local_metadata_override_specs( + mocked_open: Mock, mocked_is_dir: Mock, mocked_get_json_file_and_etag_from_s3: Mock +): + + mocked_open.side_effect = [ + mock_open(read_data=json.dumps(BASE_MANIFEST)).return_value, + mock_open(read_data=json.dumps(BASE_SPEC)).return_value, + ] + + mocked_is_dir.return_value = True + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" + assert JumpStartModelSpecs(BASE_SPEC) == cache.get_specs( + model_id=model_id, semantic_version_str=version + ) + + mocked_is_dir.assert_any_call("/some/directory/metadata/specs/root") + mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") + assert mocked_is_dir.call_count == 4 + mocked_open.assert_any_call("/some/directory/metadata/manifest/root/models_manifest.json", "r") + mocked_open.assert_any_call( + "/some/directory/metadata/specs/root/community_models_specs/tensorflow-ic-imagenet-" + "inception-v3-classification-4/specs_v2.0.0.json", + "r", + ) + assert mocked_open.call_count == 2 + mocked_get_json_file_and_etag_from_s3.assert_not_called() + + +@patch.object(JumpStartModelsCache, "_get_json_file_and_etag_from_s3") +@patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") +@patch.dict( + "sagemaker.jumpstart.cache.os.environ", + { + ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/manifest/root", + ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE: "/some/directory/metadata/specs/root", + }, +) +@patch("sagemaker.jumpstart.cache.os.path.isdir") +@patch("builtins.open") +def test_jumpstart_local_metadata_override_specs_not_exist_both_directories( + mocked_open: Mock, + mocked_is_dir: Mock, + mocked_get_json_file_and_etag_from_s3: Mock, +): + model_id, version = "tensorflow-ic-imagenet-inception-v3-classification-4", "2.0.0" + + mocked_get_json_file_and_etag_from_s3.side_effect = [ + (BASE_MANIFEST, "blah1"), + (get_spec_from_base_spec(model_id=model_id, version=version).to_json(), "blah2"), + ] + + mocked_is_dir.side_effect = [False, False] + cache = JumpStartModelsCache(s3_bucket_name="some_bucket") + + assert get_spec_from_base_spec(model_id=model_id, version=version) == cache.get_specs( + model_id=model_id, semantic_version_str=version + ) + + mocked_is_dir.assert_any_call("/some/directory/metadata/manifest/root") + assert mocked_is_dir.call_count == 2 + mocked_open.assert_not_called() + mocked_get_json_file_and_etag_from_s3.assert_has_calls( + calls=[ + call("models_manifest.json"), + call( + "community_models_specs/tensorflow-ic-imagenet-inception-v3-classification-4/specs_v2.0.0.json" + ), + ] + ) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 76ae1072fd..3ac8973ad3 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -22,7 +22,7 @@ ) -@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") 
@patch("sagemaker.jumpstart.notebook_utils._generate_jumpstart_model_versions") def test_list_jumpstart_scripts( @@ -66,7 +66,7 @@ def test_list_jumpstart_scripts( assert patched_get_model_specs.call_count == len(PROTOTYPICAL_MODEL_SPECS_DICT) -@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @patch("sagemaker.jumpstart.notebook_utils._generate_jumpstart_model_versions") def test_list_jumpstart_tasks( @@ -106,7 +106,7 @@ def test_list_jumpstart_tasks( patched_get_model_specs.assert_not_called() -@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @patch("sagemaker.jumpstart.notebook_utils._generate_jumpstart_model_versions") def test_list_jumpstart_frameworks( @@ -161,7 +161,7 @@ def test_list_jumpstart_frameworks( class ListJumpStartModels(TestCase): - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_simple_case( self, patched_get_model_specs: Mock, patched_get_manifest: Mock @@ -182,7 +182,7 @@ def test_list_jumpstart_models_simple_case( patched_get_manifest.assert_called() patched_get_model_specs.assert_not_called() - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_script_filter( self, patched_get_model_specs: Mock, patched_get_manifest: Mock @@ -232,7 +232,7 @@ def test_list_jumpstart_models_script_filter( assert patched_get_model_specs.call_count == manifest_length patched_get_manifest.assert_called_once() - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_task_filter( self, patched_get_model_specs: Mock, patched_get_manifest: Mock @@ -287,7 +287,7 @@ def test_list_jumpstart_models_task_filter( patched_get_model_specs.assert_not_called() patched_get_manifest.assert_called_once() - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_framework_filter( self, patched_get_model_specs: Mock, patched_get_manifest: Mock @@ -367,7 +367,7 @@ def test_list_jumpstart_models_framework_filter( patched_get_model_specs.assert_not_called() patched_get_manifest.assert_called_once() - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_region( self, patched_get_model_specs: Mock, patched_get_manifest: Mock @@ -380,7 +380,7 @@ def test_list_jumpstart_models_region( 
patched_get_manifest.assert_called_once_with(region="some-region") - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") @patch("sagemaker.jumpstart.notebook_utils.get_sagemaker_version") def test_list_jumpstart_models_unsupported_models( @@ -412,7 +412,7 @@ def test_list_jumpstart_models_unsupported_models( assert [] != list_jumpstart_models("training_supported in [False, True]") patched_get_model_specs.assert_called() - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_old_models( self, @@ -483,7 +483,7 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): list_old_models=False, list_versions=True ) == list_jumpstart_models(list_versions=True) - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_vulnerable_models( self, @@ -532,7 +532,7 @@ def vulnerable_training_model_spec(*args, **kwargs): assert patched_get_model_specs.call_count == 0 - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_deprecated_models( self, @@ -562,7 +562,7 @@ def deprecated_model_spec(*args, **kwargs): assert patched_get_model_specs.call_count == 0 - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_no_versions( self, @@ -587,7 +587,7 @@ def test_list_jumpstart_models_no_versions( assert list_jumpstart_models(list_versions=False) == all_model_ids - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_complex_queries( self, @@ -630,7 +630,7 @@ def test_list_jumpstart_models_complex_queries( ) ) == ["tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1"] - @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_manifest") + @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_list_jumpstart_models_multiple_level_index( self, diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index c7da962b49..7b1fc45aeb 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -131,7 +131,7 @@ def get_spec_from_base_spec( return JumpStartModelSpecs(spec) -def patched_get_file_from_s3( +def patched_retrieval_function( _modelCacheObj: JumpStartModelsCache, key: JumpStartCachedS3ContentKey, value: JumpStartCachedS3ContentValue, diff --git 
a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 699f5836f3..396132ae52 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -127,3 +127,26 @@ def test_jumpstart_common_model_uri( model_scope="training", model_id="pytorch-ic-mobilenet-v2", ) + + +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch.dict( + "sagemaker.jumpstart.cache.os.environ", + { + sagemaker_constants.ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE: "some-cool-bucket-name" + }, +) +def test_jumpstart_artifact_bucket_override( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs + patched_get_model_specs.side_effect = get_spec_from_base_spec + + uri = model_uris.retrieve( + model_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + assert uri == "s3://some-cool-bucket-name/pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz" diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 05d8368bf3..ca45b3729d 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -127,3 +127,29 @@ def test_jumpstart_common_script_uri( script_scope="training", model_id="pytorch-ic-mobilenet-v2", ) + + +@patch("sagemaker.jumpstart.artifacts.verify_model_region_and_return_specs") +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") +@patch.dict( + "sagemaker.jumpstart.cache.os.environ", + { + sagemaker_constants.ENV_VARIABLE_JUMPSTART_SCRIPT_ARTIFACT_BUCKET_OVERRIDE: "some-cool-bucket-name" + }, +) +def test_jumpstart_artifact_bucket_override( + patched_get_model_specs, patched_verify_model_region_and_return_specs +): + + patched_verify_model_region_and_return_specs.side_effect = verify_model_region_and_return_specs + patched_get_model_specs.side_effect = get_spec_from_base_spec + + uri = script_uris.retrieve( + script_scope="training", + model_id="pytorch-ic-mobilenet-v2", + model_version="*", + ) + assert ( + uri + == "s3://some-cool-bucket-name/source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz" + ) From c2d1cfdbc2ea4af7cd673843e9f138ca1e582f80 Mon Sep 17 00:00:00 2001 From: Haixin Wang <98612668+haixiw@users.noreply.github.com> Date: Wed, 27 Jul 2022 12:21:51 -0700 Subject: [PATCH 137/526] feature: Algorithms region launch on CGK (#3234) --- src/sagemaker/image_uri_config/blazingtext.json | 1 + .../image_uri_config/factorization-machines.json | 1 + src/sagemaker/image_uri_config/forecasting-deepar.json | 1 + src/sagemaker/image_uri_config/image-classification.json | 1 + src/sagemaker/image_uri_config/ipinsights.json | 1 + src/sagemaker/image_uri_config/kmeans.json | 1 + src/sagemaker/image_uri_config/knn.json | 1 + src/sagemaker/image_uri_config/linear-learner.json | 1 + src/sagemaker/image_uri_config/ntm.json | 1 + src/sagemaker/image_uri_config/object-detection.json | 1 + src/sagemaker/image_uri_config/object2vec.json | 1 + src/sagemaker/image_uri_config/pca.json | 1 + src/sagemaker/image_uri_config/randomcutforest.json | 1 + src/sagemaker/image_uri_config/semantic-segmentation.json | 1 + 
src/sagemaker/image_uri_config/seq2seq.json | 1 + src/sagemaker/image_uri_config/sklearn.json | 3 +++ src/sagemaker/image_uri_config/xgboost.json | 8 ++++++++ tests/unit/sagemaker/image_uris/test_algos.py | 2 ++ tests/unit/sagemaker/image_uris/test_sklearn.py | 1 + tests/unit/sagemaker/image_uris/test_xgboost.py | 2 ++ 20 files changed, 31 insertions(+) diff --git a/src/sagemaker/image_uri_config/blazingtext.json b/src/sagemaker/image_uri_config/blazingtext.json index 061d312375..c588d65c73 100644 --- a/src/sagemaker/image_uri_config/blazingtext.json +++ b/src/sagemaker/image_uri_config/blazingtext.json @@ -8,6 +8,7 @@ "ap-northeast-1": "501404015308", "ap-northeast-2": "306986355934", "ap-northeast-3": "867004704886", + "ap-southeast-3": "951798379941", "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "544295431143", diff --git a/src/sagemaker/image_uri_config/factorization-machines.json b/src/sagemaker/image_uri_config/factorization-machines.json index 1a07b50488..0f9930357f 100644 --- a/src/sagemaker/image_uri_config/factorization-machines.json +++ b/src/sagemaker/image_uri_config/factorization-machines.json @@ -8,6 +8,7 @@ "ap-northeast-1": "351501993468", "ap-northeast-2": "835164637446", "ap-northeast-3": "867004704886", + "ap-southeast-3": "951798379941", "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "712309505854", diff --git a/src/sagemaker/image_uri_config/forecasting-deepar.json b/src/sagemaker/image_uri_config/forecasting-deepar.json index b63cb1a99f..1acc96ed3e 100644 --- a/src/sagemaker/image_uri_config/forecasting-deepar.json +++ b/src/sagemaker/image_uri_config/forecasting-deepar.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "514117268639", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/image-classification.json b/src/sagemaker/image_uri_config/image-classification.json index 2373928397..44ccb3f08d 100644 --- a/src/sagemaker/image_uri_config/image-classification.json +++ b/src/sagemaker/image_uri_config/image-classification.json @@ -10,6 +10,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "544295431143", + "ap-southeast-3": "951798379941", "ap-northeast-3": "867004704886", "ca-central-1": "469771592824", "cn-north-1": "390948362332", diff --git a/src/sagemaker/image_uri_config/ipinsights.json b/src/sagemaker/image_uri_config/ipinsights.json index 52a4995479..4e56c149dc 100644 --- a/src/sagemaker/image_uri_config/ipinsights.json +++ b/src/sagemaker/image_uri_config/ipinsights.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "712309505854", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/kmeans.json b/src/sagemaker/image_uri_config/kmeans.json index 691ce50c45..952724ce11 100644 --- a/src/sagemaker/image_uri_config/kmeans.json +++ b/src/sagemaker/image_uri_config/kmeans.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "712309505854", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/knn.json 
b/src/sagemaker/image_uri_config/knn.json index 38308e1d0e..79b239966d 100644 --- a/src/sagemaker/image_uri_config/knn.json +++ b/src/sagemaker/image_uri_config/knn.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "712309505854", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/linear-learner.json b/src/sagemaker/image_uri_config/linear-learner.json index f1cb97bf1a..bb027284ab 100644 --- a/src/sagemaker/image_uri_config/linear-learner.json +++ b/src/sagemaker/image_uri_config/linear-learner.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "712309505854", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/ntm.json b/src/sagemaker/image_uri_config/ntm.json index de02461f66..115264b346 100644 --- a/src/sagemaker/image_uri_config/ntm.json +++ b/src/sagemaker/image_uri_config/ntm.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "712309505854", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/object-detection.json b/src/sagemaker/image_uri_config/object-detection.json index 6876e97529..6a7ba03695 100644 --- a/src/sagemaker/image_uri_config/object-detection.json +++ b/src/sagemaker/image_uri_config/object-detection.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "544295431143", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/object2vec.json b/src/sagemaker/image_uri_config/object2vec.json index 1b0ab26e51..39614d1273 100644 --- a/src/sagemaker/image_uri_config/object2vec.json +++ b/src/sagemaker/image_uri_config/object2vec.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "712309505854", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/pca.json b/src/sagemaker/image_uri_config/pca.json index b39217e573..5f87d8528c 100644 --- a/src/sagemaker/image_uri_config/pca.json +++ b/src/sagemaker/image_uri_config/pca.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "712309505854", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/randomcutforest.json b/src/sagemaker/image_uri_config/randomcutforest.json index a03160e9cf..ae7a3574be 100644 --- a/src/sagemaker/image_uri_config/randomcutforest.json +++ b/src/sagemaker/image_uri_config/randomcutforest.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "712309505854", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/semantic-segmentation.json b/src/sagemaker/image_uri_config/semantic-segmentation.json index 
ebe58a9053..866dd606b4 100644 --- a/src/sagemaker/image_uri_config/semantic-segmentation.json +++ b/src/sagemaker/image_uri_config/semantic-segmentation.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "544295431143", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/seq2seq.json b/src/sagemaker/image_uri_config/seq2seq.json index 9d4e209d6c..bb3daf93b6 100644 --- a/src/sagemaker/image_uri_config/seq2seq.json +++ b/src/sagemaker/image_uri_config/seq2seq.json @@ -11,6 +11,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "544295431143", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/src/sagemaker/image_uri_config/sklearn.json b/src/sagemaker/image_uri_config/sklearn.json index a9e5abf9f0..3ea77181ba 100644 --- a/src/sagemaker/image_uri_config/sklearn.json +++ b/src/sagemaker/image_uri_config/sklearn.json @@ -13,6 +13,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", "cn-north-1": "450853457545", "cn-northwest-1": "451049120500", @@ -44,6 +45,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", "cn-north-1": "450853457545", "cn-northwest-1": "451049120500", @@ -75,6 +77,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", "cn-north-1": "450853457545", "cn-northwest-1": "451049120500", diff --git a/src/sagemaker/image_uri_config/xgboost.json b/src/sagemaker/image_uri_config/xgboost.json index ee0a6aff6b..fad22fb136 100644 --- a/src/sagemaker/image_uri_config/xgboost.json +++ b/src/sagemaker/image_uri_config/xgboost.json @@ -14,6 +14,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "544295431143", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", @@ -46,6 +47,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", "cn-north-1": "450853457545", "cn-northwest-1": "451049120500", @@ -78,6 +80,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", "cn-north-1": "450853457545", "cn-northwest-1": "451049120500", @@ -110,6 +113,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", "cn-north-1": "450853457545", "cn-northwest-1": "451049120500", @@ -140,6 +144,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", "cn-north-1": "450853457545", "cn-northwest-1": "451049120500", @@ -170,6 +175,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", 
"cn-north-1": "450853457545", "cn-northwest-1": "451049120500", @@ -200,6 +206,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", "cn-north-1": "450853457545", "cn-northwest-1": "451049120500", @@ -230,6 +237,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", "cn-north-1": "450853457545", "cn-northwest-1": "451049120500", diff --git a/tests/unit/sagemaker/image_uris/test_algos.py b/tests/unit/sagemaker/image_uris/test_algos.py index fbc81f7829..c1be52be42 100644 --- a/tests/unit/sagemaker/image_uris/test_algos.py +++ b/tests/unit/sagemaker/image_uris/test_algos.py @@ -57,6 +57,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "712309505854", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", @@ -142,6 +143,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "544295431143", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", diff --git a/tests/unit/sagemaker/image_uris/test_sklearn.py b/tests/unit/sagemaker/image_uris/test_sklearn.py index 58b668f3a2..d0fcbdb300 100644 --- a/tests/unit/sagemaker/image_uris/test_sklearn.py +++ b/tests/unit/sagemaker/image_uris/test_sklearn.py @@ -26,6 +26,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", "cn-north-1": "450853457545", "cn-northwest-1": "451049120500", diff --git a/tests/unit/sagemaker/image_uris/test_xgboost.py b/tests/unit/sagemaker/image_uris/test_xgboost.py index 0431afeef6..78ab7e10ee 100644 --- a/tests/unit/sagemaker/image_uris/test_xgboost.py +++ b/tests/unit/sagemaker/image_uris/test_xgboost.py @@ -24,6 +24,7 @@ "ap-south-1": "991648021394", "ap-southeast-1": "475088953585", "ap-southeast-2": "544295431143", + "ap-southeast-3": "951798379941", "ca-central-1": "469771592824", "cn-north-1": "390948362332", "cn-northwest-1": "387376663083", @@ -55,6 +56,7 @@ "ap-south-1": "720646828776", "ap-southeast-1": "121021644041", "ap-southeast-2": "783357654285", + "ap-southeast-3": "951798379941", "ca-central-1": "341280168497", "cn-north-1": "450853457545", "cn-northwest-1": "451049120500", From b98d6136359c2761d19ab00022ff8bd046bcac47 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 27 Jul 2022 21:10:53 +0000 Subject: [PATCH 138/526] prepare release v2.101.0 --- CHANGELOG.md | 20 ++++++++++++++++++++ VERSION | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f754b6fccb..3a151531af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,25 @@ # Changelog +## v2.101.0 (2022-07-27) + +### Features + + * Algorithms region launch on CGK + * enhance-bucket-override-support + * infer framework and version + * support clarify bias detection when facets not included + * Add CGK region to frameworks by DLC + +### Bug Fixes and Other Changes + + * Make repack step output path align with model repack path + * Support parameterized source code input for TrainingStep + +### Documentation Changes + + * heterogeneous cluster api doc fix + * smdmp v1.10 release note + ## v2.100.0 (2022-07-18) ### Features diff --git a/VERSION 
b/VERSION index 9d5c4490e9..e5e8eb7ce3 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.100.1.dev0 +2.101.0 From 765b7ef46c6762144766de0c47cdc0b009bc2944 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 27 Jul 2022 21:10:54 +0000 Subject: [PATCH 139/526] update development version to v2.101.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index e5e8eb7ce3..7ca00c8338 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.101.0 +2.101.1.dev0 From 9b306b9cf21a78cb147e2f3424550e6a6f943a7f Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Wed, 27 Jul 2022 14:34:17 -0700 Subject: [PATCH 140/526] change: Add PipelineVariable annotation in amazon models (#3187) Co-authored-by: Dewen Qi --- src/sagemaker/amazon/factorization_machines.py | 14 +++++++++++++- src/sagemaker/amazon/ipinsights.py | 14 +++++++++++++- src/sagemaker/amazon/kmeans.py | 14 +++++++++++++- src/sagemaker/amazon/knn.py | 14 +++++++++++++- src/sagemaker/amazon/lda.py | 14 +++++++++++++- src/sagemaker/amazon/linear_learner.py | 14 +++++++++++++- src/sagemaker/amazon/ntm.py | 14 +++++++++++++- src/sagemaker/amazon/object2vec.py | 14 +++++++++++++- src/sagemaker/amazon/pca.py | 14 +++++++++++++- src/sagemaker/amazon/randomcutforest.py | 14 +++++++++++++- src/sagemaker/sparkml/model.py | 13 ++++++++++++- src/sagemaker/utils.py | 18 ++++++++++++++++++ tests/unit/test_utils.py | 12 ++++++++++++ 13 files changed, 172 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index 6d4dedf86a..5e9c2098b9 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ b/src/sagemaker/amazon/factorization_machines.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class FactorizationMachines(AmazonAlgorithmEstimatorBase): @@ -319,7 +323,13 @@ class FactorizationMachinesModel(Model): returns :class:`FactorizationMachinesPredictor`. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for FactorizationMachinesModel class. 
Args: @@ -343,6 +353,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=FactorizationMachines.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, FactorizationMachinesPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(FactorizationMachinesModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/ipinsights.py b/src/sagemaker/amazon/ipinsights.py index 8bc9103876..097f6b45dc 100644 --- a/src/sagemaker/amazon/ipinsights.py +++ b/src/sagemaker/amazon/ipinsights.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa @@ -22,7 +24,9 @@ from sagemaker.model import Model from sagemaker.serializers import CSVSerializer from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class IPInsights(AmazonAlgorithmEstimatorBase): @@ -222,7 +226,13 @@ class IPInsightsModel(Model): Predictor that calculates anomaly scores for data points. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Creates object to get insights on S3 model data. Args: @@ -246,6 +256,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=IPInsights.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, IPInsightsPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(IPInsightsModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index 286fe0c026..581e93e02a 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class KMeans(AmazonAlgorithmEstimatorBase): @@ -246,7 +250,13 @@ class KMeansModel(Model): Predictor to performs k-means cluster assignment. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for KMeansModel class. 
Args: @@ -270,6 +280,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=KMeans.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, KMeansPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(KMeansModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py index 10fe640b68..14ba404ebf 100644 --- a/src/sagemaker/amazon/knn.py +++ b/src/sagemaker/amazon/knn.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class KNN(AmazonAlgorithmEstimatorBase): @@ -238,7 +242,13 @@ class KNNModel(Model): and returns :class:`KNNPredictor`. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Function to initialize KNNModel. Args: @@ -262,6 +272,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=KNN.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, KNNPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(KNNModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index 2d7c4aa58b..4158b6cc27 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class LDA(AmazonAlgorithmEstimatorBase): @@ -220,7 +224,13 @@ class LDAModel(Model): Predictor that transforms vectors to a lower-dimensional representation. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for LDAModel class. 
Args: @@ -244,6 +254,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=LDA.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, LDAPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(LDAModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index e0a93c0120..d02ed2875f 100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class LinearLearner(AmazonAlgorithmEstimatorBase): @@ -481,7 +485,13 @@ class LinearLearnerModel(Model): :class:`LinearLearnerPredictor` """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for LinearLearnerModel. Args: @@ -505,6 +515,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=LinearLearner.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, LinearLearnerPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(LinearLearnerModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py index 12f3fc635c..83c2f97348 100644 --- a/src/sagemaker/amazon/ntm.py +++ b/src/sagemaker/amazon/ntm.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class NTM(AmazonAlgorithmEstimatorBase): @@ -249,7 +253,13 @@ class NTMModel(Model): Predictor that transforms vectors to a lower-dimensional representation. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for NTMModel class. 
Args: @@ -273,6 +283,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=NTM.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, NTMPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(NTMModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/object2vec.py b/src/sagemaker/amazon/object2vec.py index bd34eb7d19..1fbd846cbf 100644 --- a/src/sagemaker/amazon/object2vec.py +++ b/src/sagemaker/amazon/object2vec.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa @@ -20,7 +22,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable def _list_check_subset(valid_super_list): @@ -344,7 +348,13 @@ class Object2VecModel(Model): Predictor that calculates anomaly scores for datapoints. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for Object2VecModel class. Args: @@ -368,6 +378,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=Object2Vec.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, Predictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(Object2VecModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index 93f8e25caa..e3127fd7a1 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class PCA(AmazonAlgorithmEstimatorBase): @@ -237,7 +241,13 @@ class PCAModel(Model): Predictor that transforms vectors to a lower-dimensional representation. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for PCAModel. 
Args: @@ -261,6 +271,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=PCA.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, PCAPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(PCAModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index a1c3e7d171..c38d75e3e4 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -13,6 +13,8 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Optional, Union + from sagemaker import image_uris from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase from sagemaker.amazon.common import RecordSerializer, RecordDeserializer @@ -21,7 +23,9 @@ from sagemaker.predictor import Predictor from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT +from sagemaker.workflow.entities import PipelineVariable class RandomCutForest(AmazonAlgorithmEstimatorBase): @@ -209,7 +213,13 @@ class RandomCutForestModel(Model): Predictor that calculates anomaly scores for datapoints. """ - def __init__(self, model_data, role, sagemaker_session=None, **kwargs): + def __init__( + self, + model_data: Union[str, PipelineVariable], + role: str, + sagemaker_session: Optional[Session] = None, + **kwargs + ): """Initialization for RandomCutForestModel class. Args: @@ -233,6 +243,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs): sagemaker_session.boto_region_name, version=RandomCutForest.repo_version, ) + pop_out_unused_kwarg("predictor_cls", kwargs, RandomCutForestPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(RandomCutForestModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/sparkml/model.py b/src/sagemaker/sparkml/model.py index f0c32fede8..527cae0957 100644 --- a/src/sagemaker/sparkml/model.py +++ b/src/sagemaker/sparkml/model.py @@ -13,8 +13,12 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional + from sagemaker import Model, Predictor, Session, image_uris from sagemaker.serializers import CSVSerializer +from sagemaker.utils import pop_out_unused_kwarg +from sagemaker.workflow.entities import PipelineVariable framework_name = "sparkml-serving" @@ -71,7 +75,12 @@ class SparkMLModel(Model): """ def __init__( - self, model_data, role=None, spark_version="2.4", sagemaker_session=None, **kwargs + self, + model_data: Union[str, PipelineVariable], + role: Optional[str] = None, + spark_version: str = "2.4", + sagemaker_session: Optional[Session] = None, + **kwargs, ): """Initialize a SparkMLModel. 
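For reference, a minimal sketch of the behavior the new ``pop_out_unused_kwarg(...)`` calls above rely on; the helper itself is added to ``src/sagemaker/utils.py`` further down in this patch, and the dict contents here are made up for illustration:

    from sagemaker.utils import pop_out_unused_kwarg

    kwargs = dict(predictor_cls=None, env={"STAGE": "test"})  # made-up example kwargs

    # "predictor_cls" is popped with a logged warning that it will be ignored
    # (and, because an override value is given, further overridden); entries
    # that are not named, such as "env", are left untouched.
    pop_out_unused_kwarg("predictor_cls", kwargs, "SparkMLPredictor")

    assert "predictor_cls" not in kwargs and "env" in kwargs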
@@ -104,6 +113,8 @@ def __init__( # boto_region_name region_name = (sagemaker_session or Session()).boto_region_name image_uri = image_uris.retrieve(framework_name, region_name, version=spark_version) + pop_out_unused_kwarg("predictor_cls", kwargs, SparkMLPredictor.__name__) + pop_out_unused_kwarg("image_uri", kwargs, image_uri) super(SparkMLModel, self).__init__( image_uri, model_data, diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 1998525a98..4365d22f2d 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -27,6 +27,7 @@ import abc import uuid from datetime import datetime +from typing import Optional import botocore from six.moves.urllib import parse @@ -827,3 +828,20 @@ def construct_container_object( ) return obj + + +def pop_out_unused_kwarg(arg_name: str, kwargs: dict, override_val: Optional[str] = None): + """Pop out the unused key-word argument and give a warning. + + Args: + arg_name (str): The name of the argument to be checked if it is unused. + kwargs (dict): The key-word argument dict. + override_val (str): The value used to override the unused argument (default: None). + """ + if arg_name not in kwargs: + return + warn_msg = "{} supplied in kwargs will be ignored".format(arg_name) + if override_val: + warn_msg += " and further overridden with {}.".format(override_val) + logging.warning(warn_msg) + kwargs.pop(arg_name) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 4e6ba92730..5302e21fb8 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -761,3 +761,15 @@ def test_partition_by_region(): assert sagemaker.utils._aws_partition("us-gov-east-1") == "aws-us-gov" assert sagemaker.utils._aws_partition("us-iso-east-1") == "aws-iso" assert sagemaker.utils._aws_partition("us-isob-east-1") == "aws-iso-b" + + +def test_pop_out_unused_kwarg(): + # The given arg_name is in kwargs + kwargs = dict(arg1=1, arg2=2) + sagemaker.utils.pop_out_unused_kwarg("arg1", kwargs) + assert "arg1" not in kwargs + + # The given arg_name is not in kwargs + kwargs = dict(arg1=1, arg2=2) + sagemaker.utils.pop_out_unused_kwarg("arg3", kwargs) + assert len(kwargs) == 2 From 62a72b383cf48e67f2bd8bdcc98e6b1ec99bdb78 Mon Sep 17 00:00:00 2001 From: Basil Beirouti Date: Wed, 27 Jul 2022 16:10:36 -0700 Subject: [PATCH 141/526] test: Vspecinteg2 (#3249) Co-authored-by: Basil Beirouti --- tests/data/marketplace/iris/Dockerfile | 23 +++ .../marketplace/iris/model-artifacts.joblib | Bin 0 -> 2522 bytes tests/data/marketplace/iris/scoring_logic.py | 108 +++++++++++++ tests/data/marketplace/iris/serve | 6 + tests/data/marketplace/iris/wsgi.py | 4 + tests/integ/test_marketplace.py | 142 +++++++++++++++++- tests/integ/test_multidatamodel.py | 3 +- tests/unit/sagemaker/model/test_model.py | 1 + 8 files changed, 284 insertions(+), 3 deletions(-) create mode 100644 tests/data/marketplace/iris/Dockerfile create mode 100644 tests/data/marketplace/iris/model-artifacts.joblib create mode 100644 tests/data/marketplace/iris/scoring_logic.py create mode 100644 tests/data/marketplace/iris/serve create mode 100644 tests/data/marketplace/iris/wsgi.py diff --git a/tests/data/marketplace/iris/Dockerfile b/tests/data/marketplace/iris/Dockerfile new file mode 100644 index 0000000000..3b0c2fab57 --- /dev/null +++ b/tests/data/marketplace/iris/Dockerfile @@ -0,0 +1,23 @@ +FROM public.ecr.aws/ubuntu/ubuntu:18.04 + +# Specify encoding +ENV LC_ALL=C.UTF-8 +ENV LANG=C.UTF-8 + +# Install python-pip +RUN apt-get update \ +&& apt-get install -y python3.6 python3-pip \ +&& ln 
-s /usr/bin/python3.6 /usr/bin/python \ +&& ln -s /usr/bin/pip3 /usr/bin/pip; + +# Install flask server +RUN pip install -U flask gunicorn joblib sklearn; + +#Copy scoring logic and model artifacts into the docker image +COPY scoring_logic.py /scoring_logic.py +COPY wsgi.py /wsgi.py +COPY model-artifacts.joblib /opt/ml/model/model-artifacts.joblib +COPY serve /opt/program/serve + +RUN chmod 755 /opt/program/serve +ENV PATH=/opt/program:${PATH} diff --git a/tests/data/marketplace/iris/model-artifacts.joblib b/tests/data/marketplace/iris/model-artifacts.joblib new file mode 100644 index 0000000000000000000000000000000000000000..97e2377c0dd2097ca4b89ec46213ae4c87ef14ea GIT binary patch literal 2522 zcmd5;&2Jk;6t`n1Stn_52q}@MpwLP+A2Oj5RSuPQQJ_vT$Px(JavII9$DM`s?y?^( zkw7v?O^Y<@g$YtQaezC2LNCCj2PCc>kg6WI^#}+7Z+72?^@buJ8Cm|#d-L9#_wie= z$IXMqnOe5wTKt~JS>$(dTdp>(iKxLXXEZGgV|pw#!;b zWvED9u50Br?p_fI9Y)Ma;2N48_VT=IHMyi?4>?QH2omngWi4wP>IZ{NvlKCZ5R7P? zu!Jj8lWRJxI|@@seGl?-BFVOCg>jO}XIHAWgzF zi9xJO1KUWITzxmNJ=gB~=_uTzq3b}qm0am-@s%iIdzT^>hULZ@acXD&O3r3pO3wQO zrfs47ayAw$glZbsVIM>T-R=4o><7sn=*0a!}q8lbwqEzm&vDQh+_5dqwetEC1vI-;n4c{+8SYp`{3X;EaxVU1=J>NY z!1Em5v6?8)BNprh@rS#2S1s}5@k`sEd~2S=^BjmZi}D;|>Ar9{{N1!}-P^o%?<;d- z!f$)(`d%#ijeAr36@Aa$AD9NlX(5*Kg!+CpPkjGAe10;EI&!o1sqZVr16RXZz<2vq&h zV7zVRyT=d + response = sm_client.list_model_packages( + MaxResults=10, + NameContains=MODEL_NAME, + SortBy="CreationTime", + SortOrder="Descending", + ) + + if len(response["ModelPackageSummaryList"]) > 0: + sm_client.delete_model_package(ModelPackageName=model_name) + + # assert that response is non-empty + assert len(response["ModelPackageSummaryList"]) > 0 + + @pytest.mark.skipif( tests.integ.test_region() in tests.integ.NO_MARKET_PLACE_REGIONS, reason="Marketplace is not available in {}".format(tests.integ.test_region()), diff --git a/tests/integ/test_multidatamodel.py b/tests/integ/test_multidatamodel.py index ea6db02d26..78ba62c3db 100644 --- a/tests/integ/test_multidatamodel.py +++ b/tests/integ/test_multidatamodel.py @@ -16,7 +16,6 @@ import os import requests -import botocore import docker import numpy import pytest @@ -116,7 +115,7 @@ def _delete_repository(ecr_client, repository_name): try: ecr_client.describe_repositories(repositoryNames=[repository_name]) ecr_client.delete_repository(repositoryName=repository_name, force=True) - except botocore.errorfactory.ResourceNotFoundException: + except ecr_client.exceptions.RepositoryNotFoundException: pass diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 30f4a20f49..0b04d3c8bc 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -742,6 +742,7 @@ def test_script_mode_model_uses_proper_sagemaker_submit_dir(repack_model, sagema @patch("sagemaker.get_model_package_args") def test_register_calls_model_package_args(get_model_package_args, sagemaker_session): + """model.register() should pass the ValidationSpecification to get_model_package_args()""" source_dir = "s3://blah/blah/blah" t = Model( From 7da6d7b387fbccd25610fa4046ed0967eaf9a37b Mon Sep 17 00:00:00 2001 From: stacicho Date: Thu, 28 Jul 2022 09:15:10 -0700 Subject: [PATCH 142/526] fix: added more ml frameworks supported by SageMaker Workflows (#3263) --- doc/overview.rst | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/doc/overview.rst b/doc/overview.rst index 14b7d47cda..a6deb7b988 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -1713,11 +1713,15 @@ in the AWS documentation. 
SageMaker Workflow ****************** -You can use Apache Airflow to author, schedule and monitor SageMaker workflow. +You can use the following machine learning frameworks to author, schedule and monitor SageMaker workflow. -For more information, see `SageMaker Workflow in Apache Airflow`_. +.. toctree:: + :maxdepth: 2 -.. _SageMaker Workflow in Apache Airflow: https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/workflow/README.rst + workflows/airflow/index + workflows/step_functions/index + workflows/pipelines/index + workflows/lineage/index ************************************ SageMaker Model Building Pipeline From f065f5fd9732dd2970745ca73d857766028a90ca Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 28 Jul 2022 20:52:02 +0000 Subject: [PATCH 143/526] prepare release v2.101.1 --- CHANGELOG.md | 8 ++++++++ VERSION | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a151531af..b40005e7fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## v2.101.1 (2022-07-28) + +### Bug Fixes and Other Changes + + * added more ml frameworks supported by SageMaker Workflows + * test: Vspecinteg2 + * Add PipelineVariable annotation in amazon models + ## v2.101.0 (2022-07-27) ### Features diff --git a/VERSION b/VERSION index 7ca00c8338..5e7a5b3e32 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.101.1.dev0 +2.101.1 From 8984f638d0da4cb4356f72d509a7bca535472540 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 28 Jul 2022 20:52:03 +0000 Subject: [PATCH 144/526] update development version to v2.101.2.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 5e7a5b3e32..bd7dda9e04 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.101.1 +2.101.2.dev0 From f2bc173eb966db1677e640f0bdaadf612efffca6 Mon Sep 17 00:00:00 2001 From: zaoliu-aws <101844763+zaoliu-aws@users.noreply.github.com> Date: Fri, 29 Jul 2022 14:49:31 -0700 Subject: [PATCH 145/526] feature: Add test for profiler enablement with debugger_hook false (#3256) Co-authored-by: Liu --- tests/integ/test_profiler.py | 41 ++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 0c074a9881..db47786d39 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -433,3 +433,44 @@ def test_mxnet_with_disable_profiler_then_enable_default_profiling( job_description = mx.latest_training_job.describe() assert job_description["ProfilerConfig"]["S3OutputPath"] == mx.output_path + + +def test_mxnet_profiling_with_disable_debugger_hook( + sagemaker_session, + mxnet_training_latest_version, + mxnet_training_latest_py_version, + cpu_instance_type, +): + with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): + script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py") + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + framework_version=mxnet_training_latest_version, + py_version=mxnet_training_latest_py_version, + instance_count=1, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + debugger_hook_config=False, + ) + + train_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train" + ) + test_input = mx.sagemaker_session.upload_data( + path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test" + ) + + 
training_job_name = unique_name_from_base("test-profiler-mxnet-training") + mx.fit( + inputs={"train": train_input, "test": test_input}, + job_name=training_job_name, + wait=False, + ) + + job_description = mx.latest_training_job.describe() + # setting debugger_hook_config to false would not disable profiling + # https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-turn-off.html + assert job_description.get("ProfilingStatus") == "Enabled" From 3f7a7e80f98ee3dd803a00e6450bc0c6cb456709 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Mon, 1 Aug 2022 13:19:59 -0700 Subject: [PATCH 146/526] change: skip managed spot training mxnet nb (#3273) --- tests/scripts/run-notebook-test.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/scripts/run-notebook-test.sh b/tests/scripts/run-notebook-test.sh index db6c73250b..528961c338 100755 --- a/tests/scripts/run-notebook-test.sh +++ b/tests/scripts/run-notebook-test.sh @@ -124,7 +124,6 @@ echo "set SAGEMAKER_ROLE_ARN=$SAGEMAKER_ROLE_ARN" ./amazon-sagemaker-examples/advanced_functionality/tensorflow_iris_byom/tensorflow_BYOM_iris.ipynb \ ./amazon-sagemaker-examples/sagemaker-python-sdk/1P_kmeans_highlevel/kmeans_mnist.ipynb \ ./amazon-sagemaker-examples/sagemaker-python-sdk/1P_kmeans_lowlevel/kmeans_mnist_lowlevel.ipynb \ -./amazon-sagemaker-examples/sagemaker-python-sdk/managed_spot_training_mxnet/managed_spot_training_mxnet.ipynb \ ./amazon-sagemaker-examples/sagemaker-python-sdk/mxnet_gluon_sentiment/mxnet_sentiment_analysis_with_gluon.ipynb \ ./amazon-sagemaker-examples/sagemaker-python-sdk/mxnet_onnx_export/mxnet_onnx_export.ipynb \ ./amazon-sagemaker-examples/sagemaker-python-sdk/scikit_learn_randomforest/Sklearn_on_SageMaker_end2end.ipynb \ From 2a400031398b23f8382561bdb06e9723fcef4bcf Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Mon, 1 Aug 2022 15:07:54 -0700 Subject: [PATCH 147/526] change: Add PipelineVariable annotation in framework models (#3188) Co-authored-by: Dewen Qi --- src/sagemaker/chainer/model.py | 74 ++++++++++++++----------- src/sagemaker/estimator.py | 6 +- src/sagemaker/huggingface/model.py | 88 +++++++++++++++++------------- src/sagemaker/multidatamodel.py | 20 ++++--- src/sagemaker/mxnet/model.py | 76 +++++++++++++++----------- src/sagemaker/parameter.py | 12 ++-- src/sagemaker/pipeline.py | 53 +++++++++--------- src/sagemaker/pytorch/model.py | 77 +++++++++++++++----------- src/sagemaker/sklearn/model.py | 74 ++++++++++++++----------- src/sagemaker/tensorflow/model.py | 66 ++++++++++++---------- src/sagemaker/tuner.py | 6 +- src/sagemaker/utils.py | 11 ++++ src/sagemaker/xgboost/model.py | 74 ++++++++++++++----------- tests/unit/test_utils.py | 16 +++++- 14 files changed, 372 insertions(+), 281 deletions(-) diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 3f22e22d5d..1986febaaf 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -14,19 +14,25 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, List, Dict import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( model_code_key_prefix, python_deprecation_warning, validate_version_or_image_args, ) +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, 
MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.chainer import defaults from sagemaker.deserializers import NumpyDeserializer from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -75,14 +81,14 @@ class ChainerModel(FrameworkModel): def __init__( self, - model_data, - role, - entry_point, - image_uri=None, - framework_version=None, - py_version=None, - predictor_cls=ChainerPredictor, - model_server_workers=None, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: str, + image_uri: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[str] = None, + py_version: Optional[str] = None, + predictor_cls: callable = ChainerPredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): """Initialize an ChainerModel. @@ -142,27 +148,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances, - transform_instances, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. 
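The next hunk, and the matching ones in the other framework models below, wrap the ``.upper()`` normalization of ``framework`` in an ``is_pipeline_variable`` check. A rough sketch of the reasoning, using an illustrative parameter name:

    from sagemaker.workflow import is_pipeline_variable
    from sagemaker.workflow.parameters import ParameterString

    framework = ParameterString(name="Framework", default_value="CHAINER")  # illustrative

    # A plain string can be upper-cased eagerly, but a pipeline variable has no
    # concrete value at pipeline-definition time (and step properties or Join
    # expressions expose no .upper() at all), so it is passed through unchanged
    # and resolved only when the pipeline executes.
    if not is_pipeline_variable(framework):
        framework = framework.upper()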
@@ -218,6 +224,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(ChainerModel, self).register( content_types, response_types, @@ -236,7 +244,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, @@ -282,7 +290,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) return sagemaker.container_def(deploy_image, self.model_data, deploy_env) def serving_image_uri( diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 1ab122b2e0..dee102999b 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -76,6 +76,7 @@ build_dict, get_config_value, name_from_base, + to_string, ) from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable @@ -1947,10 +1948,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config): current_hyperparameters = estimator.hyperparameters() if current_hyperparameters is not None: - hyperparameters = { - str(k): (v.to_string() if is_pipeline_variable(v) else str(v)) - for (k, v) in current_hyperparameters.items() - } + hyperparameters = {str(k): to_string(v) for (k, v) in current_hyperparameters.items()} train_args = config.copy() train_args["input_mode"] = estimator.input_mode diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 04af57b566..6f810dc5e2 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -14,18 +14,24 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, List, Dict import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics from sagemaker.deserializers import JSONDeserializer +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( model_code_key_prefix, validate_version_or_image_args, ) +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer from sagemaker.session import Session +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -100,16 +106,16 @@ class HuggingFaceModel(FrameworkModel): def __init__( self, - role, - model_data=None, - entry_point=None, - transformers_version=None, - tensorflow_version=None, - pytorch_version=None, - py_version=None, - image_uri=None, - predictor_cls=HuggingFacePredictor, - model_server_workers=None, + role: str, + model_data: Optional[Union[str, PipelineVariable]] = None, + entry_point: Optional[str] = None, + transformers_version: Optional[str] = None, + tensorflow_version: Optional[str] = None, + pytorch_version: Optional[str] = None, + py_version: Optional[str] 
= None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + predictor_cls: callable = HuggingFacePredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): """Initialize a HuggingFaceModel. @@ -299,27 +305,27 @@ def deploy( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. 
@@ -377,6 +383,13 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = ( + framework + or fetch_framework_and_framework_version( + self.tensorflow_version, self.pytorch_version + )[0] + ).upper() return super(HuggingFaceModel, self).register( content_types, response_types, @@ -395,12 +408,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=( - framework - or fetch_framework_and_framework_version( - self.tensorflow_version, self.pytorch_version - )[0] - ).upper(), + framework=framework, framework_version=framework_version or fetch_framework_and_framework_version(self.tensorflow_version, self.pytorch_version)[ 1 @@ -449,7 +457,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) return sagemaker.container_def( deploy_image, self.repacked_model_data or self.model_data, deploy_env ) diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index a3cd17cd8c..d90a5ca76f 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import os +from typing import Union, Optional from six.moves.urllib.parse import urlparse @@ -22,6 +23,8 @@ from sagemaker.deprecations import removed_kwargs from sagemaker.model import Model from sagemaker.session import Session +from sagemaker.utils import pop_out_unused_kwarg +from sagemaker.workflow.entities import PipelineVariable MULTI_MODEL_CONTAINER_MODE = "MultiModel" @@ -34,12 +37,12 @@ class MultiDataModel(Model): def __init__( self, - name, - model_data_prefix, - model=None, - image_uri=None, - role=None, - sagemaker_session=None, + name: str, + model_data_prefix: str, + model: Optional[Model] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + role: Optional[str] = None, + sagemaker_session: Optional[Session] = None, **kwargs, ): """Initialize a ``MultiDataModel``. @@ -106,6 +109,7 @@ def __init__( # Set the ``Model`` parameters if the model parameter is not specified if not self.model: + pop_out_unused_kwarg("model_data", kwargs, self.model_data_prefix) super(MultiDataModel, self).__init__( image_uri, self.model_data_prefix, @@ -115,7 +119,9 @@ def __init__( **kwargs, ) - def prepare_container_def(self, instance_type=None, accelerator_type=None): + def prepare_container_def( + self, instance_type=None, accelerator_type=None, serverless_inference_config=None + ): """Return a container definition set. 
Definition set includes MultiModel mode, model data and other parameters diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 4aaf6a8acc..f2e18c009e 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -14,21 +14,27 @@ from __future__ import absolute_import import logging +from typing import Union, Optional, List, Dict import packaging.version import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics from sagemaker.deserializers import JSONDeserializer +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( model_code_key_prefix, python_deprecation_warning, validate_version_or_image_args, ) +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.mxnet import defaults from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -77,14 +83,14 @@ class MXNetModel(FrameworkModel): def __init__( self, - model_data, - role, - entry_point, - framework_version=None, - py_version=None, - image_uri=None, - predictor_cls=MXNetPredictor, - model_server_workers=None, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: str, + framework_version: str = _LOWEST_MMS_VERSION, + py_version: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + predictor_cls: callable = MXNetPredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): """Initialize an MXNetModel. @@ -102,7 +108,7 @@ def __init__( hosting. If ``source_dir`` is specified, then ``entry_point`` must point to a file located at the root of ``source_dir``. framework_version (str): MXNet version you want to use for executing - your model training code. Defaults to ``None``. Required unless + your model training code. Defaults to ``1.4.0``. Required unless ``image_uri`` is provided. py_version (str): Python version you want to use for executing your model training code. Defaults to ``None``. 
Required unless @@ -144,27 +150,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. 
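Several ``prepare_container_def`` hunks in this commit switch from ``str(self.model_server_workers)`` to ``to_string(...)``, a helper added to ``src/sagemaker/utils.py`` near the end of the commit. A small sketch of the difference, with an illustrative parameter name:

    from sagemaker.model import MODEL_SERVER_WORKERS_PARAM_NAME
    from sagemaker.utils import to_string
    from sagemaker.workflow.parameters import ParameterInteger

    workers = ParameterInteger(name="ModelServerWorkers", default_value=2)  # illustrative

    # to_string() keeps pipeline variables as deferred Join expressions that are
    # resolved at execution time, and falls back to str() for plain values, so
    # both of these produce a usable environment entry:
    env = {MODEL_SERVER_WORKERS_PARAM_NAME.upper(): to_string(workers)}
    env_plain = {MODEL_SERVER_WORKERS_PARAM_NAME.upper(): to_string(2)}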
@@ -220,6 +226,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(MXNetModel, self).register( content_types, response_types, @@ -238,7 +246,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, @@ -286,7 +294,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) return sagemaker.container_def( deploy_image, self.repacked_model_data or self.model_data, deploy_env ) diff --git a/src/sagemaker/parameter.py b/src/sagemaker/parameter.py index 79bbc62da2..b44e6f9ef2 100644 --- a/src/sagemaker/parameter.py +++ b/src/sagemaker/parameter.py @@ -16,8 +16,8 @@ import json from typing import Union -from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable +from sagemaker.utils import to_string class ParameterRange(object): @@ -78,12 +78,8 @@ def as_tuning_range(self, name): """ return { "Name": name, - "MinValue": str(self.min_value) - if not is_pipeline_variable(self.min_value) - else self.min_value.to_string(), - "MaxValue": str(self.max_value) - if not is_pipeline_variable(self.max_value) - else self.max_value.to_string(), + "MinValue": to_string(self.min_value), + "MaxValue": to_string(self.max_value), "ScalingType": self.scaling_type, } @@ -117,7 +113,7 @@ def __init__(self, values): # pylint: disable=super-init-not-called This input will be converted into a list of strings. """ values = values if isinstance(values, list) else [values] - self.values = [str(v) if not is_pipeline_variable(v) else v.to_string() for v in values] + self.values = [to_string(v) for v in values] def as_tuning_range(self, name): """Represent the parameter range as a dictionary. 
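With ``as_tuning_range`` and ``CategoricalParameter`` above now rendering values through ``to_string()``, a hyperparameter range bound can itself be a pipeline variable. A brief sketch; the parameter name is illustrative:

    from sagemaker.parameter import ContinuousParameter
    from sagemaker.workflow.parameters import ParameterFloat

    max_lr = ParameterFloat(name="MaxLearningRate", default_value=0.1)  # illustrative

    # The lower bound is rendered with str(); the upper bound stays a deferred
    # pipeline expression instead of being str()-ed at definition time.
    lr_range = ContinuousParameter(1e-5, max_lr)
    print(lr_range.as_tuning_range("learning_rate"))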
diff --git a/src/sagemaker/pipeline.py b/src/sagemaker/pipeline.py index 5047e6351a..f7c1bded9a 100644 --- a/src/sagemaker/pipeline.py +++ b/src/sagemaker/pipeline.py @@ -13,10 +13,10 @@ """Placeholder docstring""" from __future__ import absolute_import -from typing import Optional, Dict +from typing import Optional, Dict, List, Union import sagemaker -from sagemaker import ModelMetrics +from sagemaker import ModelMetrics, Model from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.metadata_properties import MetadataProperties from sagemaker.session import Session @@ -25,6 +25,7 @@ update_container_with_inference_params, ) from sagemaker.transformer import Transformer +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -36,13 +37,13 @@ class PipelineModel(object): def __init__( self, - models, - role, - predictor_cls=None, - name=None, - vpc_config=None, - sagemaker_session=None, - enable_network_isolation=False, + models: List[Model], + role: str, + predictor_cls: Optional[callable] = None, + name: Optional[str] = None, + vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, + sagemaker_session: Optional[Session] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, ): """Initialize a SageMaker `Model` instance. @@ -267,27 +268,27 @@ def _create_sagemaker_pipeline_model(self, instance_type): @runnable_by_pipeline def register( self, - content_types: list, - response_types: list, - inference_instances: Optional[list] = None, - transform_instances: Optional[list] = None, - model_package_name: Optional[str] = None, - model_package_group_name: Optional[str] = None, - image_uri: Optional[str] = None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, model_metrics: Optional[ModelMetrics] = None, metadata_properties: Optional[MetadataProperties] = None, marketplace_cert: bool = False, - approval_status: Optional[str] = None, + approval_status: Optional[Union[str, PipelineVariable]] = None, description: Optional[str] = None, drift_check_baselines: Optional[DriftCheckBaselines] = None, - customer_metadata_properties: Optional[Dict[str, str]] = None, - domain: Optional[str] = None, - sample_payload_url: Optional[str] = None, - task: Optional[str] = None, - framework: Optional[str] = None, - framework_version: Optional[str] = None, - nearest_model_name: Optional[str] = None, - data_input_configuration: Optional[str] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. 
@@ -345,7 +346,7 @@ def register( framework_version=framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, - container_def=container_def, + container_list=container_def, ) else: container_def = [ diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index fcbfd1da84..a16fc4d5e2 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -14,20 +14,27 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, List, Dict + import packaging.version import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics from sagemaker.deserializers import NumpyDeserializer +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( model_code_key_prefix, python_deprecation_warning, validate_version_or_image_args, ) +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.pytorch import defaults from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -77,14 +84,14 @@ class PyTorchModel(FrameworkModel): def __init__( self, - model_data, - role, - entry_point, - framework_version=None, - py_version=None, - image_uri=None, - predictor_cls=PyTorchPredictor, - model_server_workers=None, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: str, + framework_version: str = "1.3", + py_version: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + predictor_cls: callable = PyTorchPredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): """Initialize a PyTorchModel. @@ -102,7 +109,7 @@ def __init__( hosting. If ``source_dir`` is specified, then ``entry_point`` must point to a file located at the root of ``source_dir``. framework_version (str): PyTorch version you want to use for - executing your model training code. Defaults to None. Required + executing your model training code. Defaults to 1.3. Required unless ``image_uri`` is provided. py_version (str): Python version you want to use for executing your model training code. Defaults to ``None``. 
Required unless @@ -145,27 +152,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. 
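Because ``register()`` now accepts pipeline variables for most of its arguments, a framework model can be registered from inside a pipeline definition. A rough sketch with ``PyTorchModel``; the role, entry point, versions, and the ``step_train`` training step are placeholders:

    from sagemaker.pytorch.model import PyTorchModel
    from sagemaker.workflow.model_step import ModelStep
    from sagemaker.workflow.parameters import ParameterString
    from sagemaker.workflow.pipeline_context import PipelineSession

    pipeline_session = PipelineSession()
    group_name = ParameterString(name="ModelPackageGroup", default_value="my-models")

    # role and step_train are placeholders that would come from the surrounding
    # pipeline definition.
    model = PyTorchModel(
        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
        role=role,
        entry_point="inference.py",
        framework_version="1.11.0",
        py_version="py38",
        sagemaker_session=pipeline_session,
    )

    # With a PipelineSession, register() returns step arguments instead of
    # creating the model package right away; ModelStep runs them at execution time.
    step_register = ModelStep(
        name="RegisterPyTorchModel",
        step_args=model.register(
            content_types=["application/json"],
            response_types=["application/json"],
            inference_instances=["ml.m5.large"],
            transform_instances=["ml.m5.large"],
            model_package_group_name=group_name,
        ),
    )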
@@ -221,6 +228,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(PyTorchModel, self).register( content_types, response_types, @@ -239,7 +248,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, @@ -285,7 +294,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) return sagemaker.container_def( deploy_image, self.repacked_model_data or self.model_data, deploy_env ) diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 70ea22908e..5bb469991a 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -14,15 +14,21 @@ from __future__ import absolute_import import logging +from typing import Union, Optional, List, Dict import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics from sagemaker.deserializers import NumpyDeserializer +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import Predictor from sagemaker.serializers import NumpySerializer from sagemaker.sklearn import defaults +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -71,14 +77,14 @@ class SKLearnModel(FrameworkModel): def __init__( self, - model_data, - role, - entry_point, - framework_version=None, - py_version="py3", - image_uri=None, - predictor_cls=SKLearnPredictor, - model_server_workers=None, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: str, + framework_version: Optional[str] = None, + py_version: str = "py3", + image_uri: Optional[Union[str, PipelineVariable]] = None, + predictor_cls: callable = SKLearnPredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): """Initialize an SKLearnModel. 
@@ -139,27 +145,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. 
@@ -215,6 +221,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(SKLearnModel, self).register( content_types, response_types, @@ -233,7 +241,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, @@ -274,7 +282,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) model_data_uri = ( self.repacked_model_data if self.enable_network_isolation() else self.model_data ) diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 401ae04b23..82885995b7 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -14,14 +14,18 @@ from __future__ import absolute_import import logging +from typing import Union, Optional, List, Dict import sagemaker -from sagemaker import image_uris, s3 +from sagemaker import image_uris, s3, ModelMetrics from sagemaker.deserializers import JSONDeserializer from sagemaker.deprecations import removed_kwargs +from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.metadata_properties import MetadataProperties from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import PipelineSession logger = logging.getLogger(__name__) @@ -126,13 +130,13 @@ class TensorFlowModel(sagemaker.model.FrameworkModel): def __init__( self, - model_data, - role, - entry_point=None, - image_uri=None, - framework_version=None, - container_log_level=None, - predictor_cls=TensorFlowPredictor, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[str] = None, + container_log_level: Optional[int] = None, + predictor_cls: callable = TensorFlowPredictor, **kwargs, ): """Initialize a Model. 
@@ -193,27 +197,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. 
@@ -269,6 +273,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(TensorFlowModel, self).register( content_types, response_types, @@ -287,7 +293,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 58c875f8d9..0440cee3b8 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -44,8 +44,7 @@ from sagemaker.workflow.pipeline_context import runnable_by_pipeline from sagemaker.session import Session -from sagemaker.utils import base_from_name, base_name_from_image, name_from_base -from sagemaker.workflow import is_pipeline_variable +from sagemaker.utils import base_from_name, base_name_from_image, name_from_base, to_string AMAZON_ESTIMATOR_MODULE = "sagemaker" AMAZON_ESTIMATOR_CLS_NAMES = { @@ -414,8 +413,7 @@ def _prepare_static_hyperparameters( """Prepare static hyperparameters for one estimator before tuning.""" # Remove any hyperparameter that will be tuned static_hyperparameters = { - str(k): str(v) if not is_pipeline_variable(v) else v.to_string() - for (k, v) in estimator.hyperparameters().items() + str(k): to_string(v) for (k, v) in estimator.hyperparameters().items() } for hyperparameter_name in hyperparameter_ranges.keys(): static_hyperparameters.pop(hyperparameter_name, None) diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 4365d22f2d..d71b8e1433 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -845,3 +845,14 @@ def pop_out_unused_kwarg(arg_name: str, kwargs: dict, override_val: Optional[str warn_msg += " and further overridden with {}.".format(override_val) logging.warning(warn_msg) kwargs.pop(arg_name) + + +def to_string(obj: object): + """Convert an object to string + + This helper function handles converting PipelineVariable object to string as well + + Args: + obj (object): The object to be converted + """ + return obj.to_string() if is_pipeline_variable(obj) else str(obj) diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 6e56230234..5279c07c50 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -14,14 +14,20 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, List, Dict import sagemaker -from sagemaker import image_uris +from sagemaker import image_uris, ModelMetrics from sagemaker.deserializers import CSVDeserializer +from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import model_code_key_prefix +from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import Predictor from sagemaker.serializers import LibSVMSerializer +from sagemaker.utils import to_string +from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable from sagemaker.xgboost.defaults import XGBOOST_NAME from sagemaker.xgboost.utils import validate_py_version, validate_framework_version @@ -70,14 +76,14 @@ class XGBoostModel(FrameworkModel): def __init__( self, - model_data, 
- role, - entry_point, - framework_version, - image_uri=None, - py_version="py3", - predictor_cls=XGBoostPredictor, - model_server_workers=None, + model_data: Union[str, PipelineVariable], + role: str, + entry_point: str, + framework_version: str, + image_uri: Optional[Union[str, PipelineVariable]] = None, + py_version: str = "py3", + predictor_cls: callable = XGBoostPredictor, + model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): """Initialize an XGBoostModel. @@ -126,27 +132,27 @@ def __init__( def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - domain=None, - sample_payload_url=None, - task=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. 
@@ -202,6 +208,8 @@ def register( region_name=self.sagemaker_session.boto_session.region_name, instance_type=instance_type, ) + if not is_pipeline_variable(framework): + framework = (framework or self._framework_name).upper() return super(XGBoostModel, self).register( content_types, response_types, @@ -220,7 +228,7 @@ def register( domain=domain, sample_payload_url=sample_payload_url, task=task, - framework=(framework or self._framework_name).upper(), + framework=framework, framework_version=framework_version or self.framework_version, nearest_model_name=nearest_model_name, data_input_configuration=data_input_configuration, @@ -259,7 +267,9 @@ def prepare_container_def( deploy_env.update(self._script_mode_env_vars()) if self.model_server_workers: - deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers) + deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = to_string( + self.model_server_workers + ) model_data = ( self.repacked_model_data if self.enable_network_isolation() else self.model_data ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 5302e21fb8..b0b5045b94 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -30,7 +30,8 @@ import sagemaker from sagemaker.session_settings import SessionSettings from tests.unit.sagemaker.workflow.helpers import CustomStep -from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.parameters import ParameterString, ParameterInteger + BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission" @@ -773,3 +774,16 @@ def test_pop_out_unused_kwarg(): kwargs = dict(arg1=1, arg2=2) sagemaker.utils.pop_out_unused_kwarg("arg3", kwargs) assert len(kwargs) == 2 + + +def test_to_string(): + var = 1 + assert sagemaker.utils.to_string(var) == "1" + + var = ParameterInteger(name="MyInt") + assert sagemaker.utils.to_string(var).expr == { + "Std:Join": { + "On": "", + "Values": [{"Get": "Parameters.MyInt"}], + }, + } From 1df557faeb35e6ee944faaa0dfc44cd04eb16b16 Mon Sep 17 00:00:00 2001 From: Vishwa Karia <45134824+vishwakaria@users.noreply.github.com> Date: Mon, 1 Aug 2022 15:38:22 -0700 Subject: [PATCH 148/526] feature: Add PyTorch DDP distribution support (#3270) Co-authored-by: Ubuntu --- doc/frameworks/pytorch/using_pytorch.rst | 91 ++++++++- src/sagemaker/fw_utils.py | 82 ++++++++ src/sagemaker/pytorch/estimator.py | 39 +++- tests/conftest.py | 12 ++ tests/data/_repack_model.py | 110 ++++++++++ tests/data/pytorch_ddp/mnist_pt.py | 246 +++++++++++++++++++++++ tests/integ/test_pytorchddp.py | 53 +++++ tests/unit/test_fw_utils.py | 62 ++++++ tests/unit/test_pytorch.py | 41 +++- 9 files changed, 722 insertions(+), 14 deletions(-) create mode 100644 tests/data/_repack_model.py create mode 100644 tests/data/pytorch_ddp/mnist_pt.py create mode 100644 tests/integ/test_pytorchddp.py diff --git a/doc/frameworks/pytorch/using_pytorch.rst b/doc/frameworks/pytorch/using_pytorch.rst index 9d4a4de3de..52720fe12b 100644 --- a/doc/frameworks/pytorch/using_pytorch.rst +++ b/doc/frameworks/pytorch/using_pytorch.rst @@ -200,17 +200,32 @@ fit Optional Arguments Distributed PyTorch Training ============================ -You can run a multi-machine, distributed PyTorch training using the PyTorch Estimator. By default, PyTorch objects will -submit single-machine training jobs to SageMaker. If you set ``instance_count`` to be greater than one, multi-machine -training jobs will be launched when ``fit`` is called. 
When you run multi-machine training, SageMaker will import your -training script and run it on each host in the cluster. +SageMaker supports the `PyTorch DistributedDataParallel (DDP) +`_ +package. You simply need to check the variables in your training script, +such as the world size and the rank of the current host, when initializing +process groups for distributed training. +And then, launch the training job using the +:class:`sagemaker.pytorch.estimator.PyTorch` estimator class +with the ``pytorchddp`` option as the distribution strategy. -To initialize distributed training in your script you would call ``dist.init_process_group`` providing desired backend -and rank and setting 'WORLD_SIZE' environment variable similar to how you would do it outside of SageMaker using -environment variable initialization: +.. note:: + + This PyTorch DDP support is available + in the SageMaker PyTorch Deep Learning Containers v1.12 and later. + +Adapt Your Training Script +-------------------------- + +To initialize distributed training in your script, call +`torch.distributed.init_process_group +`_ +with the desired backend and the rank of the current host. .. code:: python + import torch.distributed as dist + if args.distributed: # Initialize the distributed environment. world_size = len(args.hosts) @@ -218,11 +233,65 @@ environment variable initialization: host_rank = args.hosts.index(args.current_host) dist.init_process_group(backend=args.backend, rank=host_rank) -SageMaker sets 'MASTER_ADDR' and 'MASTER_PORT' environment variables for you, but you can overwrite them. +SageMaker sets ``'MASTER_ADDR'`` and ``'MASTER_PORT'`` environment variables for you, +but you can also overwrite them. + +**Supported backends:** + +- ``gloo`` and ``tcp`` for CPU instances +- ``gloo`` and ``nccl`` for GPU instances + +Launching a Distributed Training Job +------------------------------------ + +You can run multi-node distributed PyTorch training jobs using the +:class:`sagemaker.pytorch.estimator.PyTorch` estimator class. +With ``instance_count=1``, the estimator submits a +single-node training job to SageMaker; with ``instance_count`` greater +than one, a multi-node training job is launched. + +To run a distributed training script that adopts +the `PyTorch DistributedDataParallel (DDP) package +`_, +choose the ``pytorchddp`` as the distributed training option in the ``PyTorch`` estimator. + +With the ``pytorchddp`` option, the SageMaker PyTorch estimator runs a SageMaker +training container for PyTorch, sets up the environment for MPI, and launches +the training job using the ``mpirun`` command on each worker with the given information +during the PyTorch DDP initialization. + +.. note:: + + The SageMaker PyTorch estimator doesn’t use ``torchrun`` for distributed training. + +For more information about setting up PyTorch DDP in your training script, +see `Getting Started with Distributed Data Parallel +`_ in the +PyTorch documentation. + +The following example shows how to run a PyTorch DDP training in SageMaker +using two ``ml.p4d.24xlarge`` instances: + +.. 
code:: python + + from sagemaker.pytorch import PyTorch + + pt_estimator = PyTorch( + entry_point="train_ptddp.py", + role="SageMakerRole", + framework_version="1.12.0", + py_version="py38", + instance_count=2, + instance_type="ml.p4d.24xlarge", + distribution={ + "pytorchddp": { + "enabled": True + } + } + ) + + pt_estimator.fit("s3://bucket/path/to/training/data") -Supported backends: -- `gloo` and `tcp` for cpu instances -- `gloo` and `nccl` for gpu instances ********************* diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 40787d4440..ef99454a45 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -103,6 +103,17 @@ "1.11.0", ], } + +PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS = [ + "1.10", + "1.10.0", + "1.10.2", + "1.11", + "1.11.0", + "1.12", + "1.12.0", +] + SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"] @@ -728,6 +739,13 @@ def validate_distribution( distribution=distribution, image_uri=image_uri, ) + validate_pytorch_distribution( + distribution=distribution, + framework_name=framework_name, + framework_version=framework_version, + py_version=py_version, + image_uri=image_uri, + ) warn_if_parameter_server_with_multi_gpu( training_instance_type=instance_type, distribution=distribution ) @@ -747,12 +765,76 @@ def validate_distribution( distribution=distribution, image_uri=image_uri, ) + validate_pytorch_distribution( + distribution=distribution, + framework_name=framework_name, + framework_version=framework_version, + py_version=py_version, + image_uri=image_uri, + ) warn_if_parameter_server_with_multi_gpu( training_instance_type=instance_type, distribution=distribution ) return distribution +def validate_pytorch_distribution( + distribution, framework_name, framework_version, py_version, image_uri +): + """Check if pytorch distribution strategy is correctly invoked by the user. + + Args: + distribution (dict): A dictionary with information to enable distributed training. + (Defaults to None if distributed training is not enabled.) For example: + + .. code:: python + + { + "pytorchddp": { + "enabled": True + } + } + framework_name (str): A string representing the name of framework selected. + framework_version (str): A string representing the framework version selected. + py_version (str): A string representing the python version selected. + image_uri (str): A string representing a Docker image URI. 
+ + Raises: + ValueError: if + `py_version` is not python3 or + `framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS + """ + if framework_name and framework_name != "pytorch": + # We need to validate only for PyTorch framework + return + + pytorch_ddp_enabled = False + if "pytorchddp" in distribution: + pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) + if not pytorch_ddp_enabled: + # Distribution strategy other than pytorchddp is selected + return + + err_msg = "" + if not image_uri: + # ignore framework_version and py_version if image_uri is set + # in case image_uri is not set, then both are mandatory + if framework_version not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS: + err_msg += ( + f"Provided framework_version {framework_version} is not supported by" + " pytorchddp.\n" + "Please specify one of the supported framework versions:" + f" {PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS} \n" + ) + if "py3" not in py_version: + err_msg += ( + f"Provided py_version {py_version} is not supported by pytorchddp.\n" + "Please specify py_version>=py3" + ) + if err_msg: + raise ValueError(err_msg) + + def python_deprecation_warning(framework, latest_supported_version): """Placeholder docstring""" return PYTHON_2_DEPRECATION_WARNING.format( diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 07554ca798..153d4656d4 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -38,6 +38,8 @@ class PyTorch(Framework): """Handle end-to-end training and deployment of custom PyTorch code.""" _framework_name = "pytorch" + LAUNCH_PYTORCH_DDP_ENV_NAME = "sagemaker_pytorch_ddp_enabled" + INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type" def __init__( self, @@ -153,6 +155,19 @@ def __init__( To find a complete list of parameters for SageMaker model parallelism, see :ref:`sm-sdk-modelparallel-general`. + **To enable PyTorch DDP:** + + .. code:: python + + { + "pytorchddp": { + "enabled": True + } + } + + To learn more, see `Distributed PyTorch Training + `_. + **To enable MPI:** .. code:: python @@ -217,10 +232,32 @@ def __init__( self.distribution = distribution or {} + def _pytorch_distribution_configuration(self, distribution): + """Returns a dict of distribution config for PyTorch training + + Args: + distribution (dict): A dictionary with information on how to run distributed training. 
+ Returns: + dict containing Pytorch DDP config + """ + distribution_config = {} + pytorch_ddp_enabled = False + if "pytorchddp" in distribution: + pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False) + + if pytorch_ddp_enabled: + distribution_config[self.LAUNCH_PYTORCH_DDP_ENV_NAME] = pytorch_ddp_enabled + if self.instance_type is not None: + distribution_config[self.INSTANCE_TYPE_ENV_NAME] = self.instance_type + else: + distribution_config = self._distribution_configuration(distribution=distribution) + + return distribution_config + def hyperparameters(self): """Return hyperparameters used by your custom PyTorch code during model training.""" hyperparameters = super(PyTorch, self).hyperparameters() - additional_hyperparameters = self._distribution_configuration( + additional_hyperparameters = self._pytorch_distribution_configuration( distribution=self.distribution ) hyperparameters.update( diff --git a/tests/conftest.py b/tests/conftest.py index 8ccf443133..25f594a74b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -411,6 +411,18 @@ def tf_full_py_version(tf_full_version): return "py39" +@pytest.fixture(scope="module") +def pytorch_ddp_py_version(): + return "py3" + + +@pytest.fixture( + scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"] +) +def pytorch_ddp_framework_version(request): + return request.param + + @pytest.fixture(scope="session") def cpu_instance_type(sagemaker_session, request): region = sagemaker_session.boto_session.region_name diff --git a/tests/data/_repack_model.py b/tests/data/_repack_model.py new file mode 100644 index 0000000000..3cfa6760b3 --- /dev/null +++ b/tests/data/_repack_model.py @@ -0,0 +1,110 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Repack model script for training jobs to inject entry points""" +from __future__ import absolute_import + +import argparse +import os +import shutil +import tarfile +import tempfile + +# Repack Model +# The following script is run via a training job which takes an existing model and a custom +# entry point script as arguments. The script creates a new model archive with the custom +# entry point in the "code" directory along with the existing model. Subsequently, when the model +# is unpacked for inference, the custom entry point will be used. +# Reference: https://docs.aws.amazon.com/sagemaker/latest/dg/amazon-sagemaker-toolkits.html + +# distutils.dir_util.copy_tree works way better than the half-baked +# shutil.copytree which bombs on previously existing target dirs... +# alas ... https://bugs.python.org/issue10948 +# we'll go ahead and use the copy_tree function anyways because this +# repacking is some short-lived hackery, right?? +from distutils.dir_util import copy_tree + + +def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover + """Repack custom dependencies and code into an existing model TAR archive + + Args: + inference_script (str): The path to the custom entry point. 
+ model_archive (str): The name or path (e.g. s3 uri) of the model TAR archive. + dependencies (str): A space-delimited string of paths to custom dependencies. + source_dir (str): The path to a custom source directory. + """ + + # the data directory contains a model archive generated by a previous training job + data_directory = "/opt/ml/input/data/training" + model_path = os.path.join(data_directory, model_archive.split("/")[-1]) + + # create a temporary directory + with tempfile.TemporaryDirectory() as tmp: + local_path = os.path.join(tmp, "local.tar.gz") + # copy the previous training job's model archive to the temporary directory + shutil.copy2(model_path, local_path) + src_dir = os.path.join(tmp, "src") + # create the "code" directory which will contain the inference script + code_dir = os.path.join(src_dir, "code") + os.makedirs(code_dir) + # extract the contents of the previous training job's model archive to the "src" + # directory of this training job + with tarfile.open(name=local_path, mode="r:gz") as tf: + tf.extractall(path=src_dir) + + if source_dir: + # copy /opt/ml/code to code/ + if os.path.exists(code_dir): + shutil.rmtree(code_dir) + shutil.copytree("/opt/ml/code", code_dir) + else: + # copy the custom inference script to code/ + entry_point = os.path.join("/opt/ml/code", inference_script) + shutil.copy2(entry_point, os.path.join(code_dir, inference_script)) + + # copy any dependencies to code/lib/ + if dependencies: + for dependency in dependencies.split(" "): + actual_dependency_path = os.path.join("/opt/ml/code", dependency) + lib_dir = os.path.join(code_dir, "lib") + if not os.path.exists(lib_dir): + os.mkdir(lib_dir) + if os.path.isfile(actual_dependency_path): + shutil.copy2(actual_dependency_path, lib_dir) + else: + if os.path.exists(lib_dir): + shutil.rmtree(lib_dir) + # a directory is in the dependencies. we have to copy + # all of /opt/ml/code into the lib dir because the original directory + # was flattened by the SDK training job upload.. + shutil.copytree("/opt/ml/code", lib_dir) + break + + # copy the "src" dir, which includes the previous training job's model and the + # custom inference script, to the output of this training job + copy_tree(src_dir, "/opt/ml/model") + + +if __name__ == "__main__": # pragma: no cover + parser = argparse.ArgumentParser() + parser.add_argument("--inference_script", type=str, default="inference.py") + parser.add_argument("--dependencies", type=str, default=None) + parser.add_argument("--source_dir", type=str, default=None) + parser.add_argument("--model_archive", type=str, default="model.tar.gz") + args, extra = parser.parse_known_args() + repack( + inference_script=args.inference_script, + dependencies=args.dependencies, + source_dir=args.source_dir, + model_archive=args.model_archive, + ) diff --git a/tests/data/pytorch_ddp/mnist_pt.py b/tests/data/pytorch_ddp/mnist_pt.py new file mode 100644 index 0000000000..6c37f9102b --- /dev/null +++ b/tests/data/pytorch_ddp/mnist_pt.py @@ -0,0 +1,246 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. 
See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import print_function + +import argparse +import os +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.optim.lr_scheduler import StepLR +from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP +import smdistributed.dataparallel.torch.distributed as dist + +dist.init_process_group(backend="nccl") + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout2d(0.25) + self.dropout2 = nn.Dropout2d(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def train(args, model, device, train_loader, optimizer, epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0 and args.rank == 0: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data) * args.world_size, + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + if args.verbose: + print("Batch", batch_idx, "from rank", args.rank) + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) + ) + ) + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", + type=int, + default=64, + metavar="N", + help="input batch size for training (default: 64)", + ) + parser.add_argument( + "--test-batch-size", + type=int, + default=1000, + metavar="N", + help="input batch size for testing (default: 1000)", + ) + parser.add_argument( + "--epochs", + type=int, + default=14, + metavar="N", + help="number of epochs to train (default: 14)", + ) + parser.add_argument( + "--lr", type=float, default=1.0, metavar="LR", help="learning rate (default: 1.0)" + ) + parser.add_argument( + "--gamma", + type=float, + default=0.7, + metavar="M", + help="Learning rate step gamma (default: 0.7)", + ) + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + 
parser.add_argument( + "--save-model", action="store_true", default=False, help="For Saving the current Model" + ) + parser.add_argument( + "--verbose", + action="store_true", + default=False, + help="For displaying SM Distributed Data Parallel-specific logs", + ) + parser.add_argument( + "--data-path", + type=str, + default=os.environ["SM_CHANNEL_TRAINING"], + help="Path for downloading the MNIST dataset", + ) + + args = parser.parse_args() + args.world_size = dist.get_world_size() + args.rank = rank = dist.get_rank() + args.local_rank = local_rank = dist.get_local_rank() + args.lr = 1.0 + args.batch_size //= args.world_size // 8 + args.batch_size = max(args.batch_size, 1) + data_path = args.data_path + + if args.verbose: + print( + "Hello from rank", + rank, + "of local_rank", + local_rank, + "in world size of", + args.world_size, + ) + + if not torch.cuda.is_available(): + raise Exception( + "Must run SM Distributed Data Parallel MNIST example on CUDA-capable devices." + ) + + torch.manual_seed(args.seed) + + device = torch.device("cuda") + + if local_rank == 0: + train_dataset = datasets.MNIST( + data_path, + train=True, + download=False, # True sets a dependency on an external site for our tests. + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + else: + time.sleep(8) + train_dataset = datasets.MNIST( + data_path, + train=True, + download=False, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ) + + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_dataset, num_replicas=args.world_size, rank=rank + ) + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=0, + pin_memory=True, + sampler=train_sampler, + ) + if rank == 0: + test_loader = torch.utils.data.DataLoader( + datasets.MNIST( + data_path, + train=False, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ), + ), + batch_size=args.test_batch_size, + shuffle=True, + ) + + model = DDP(Net().to(device)) + torch.cuda.set_device(local_rank) + model.cuda(local_rank) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch) + if rank == 0: + test(model, device, test_loader) + scheduler.step() + + if args.save_model: + torch.save(model.state_dict(), "mnist_cnn.pt") + + +if __name__ == "__main__": + main() diff --git a/tests/integ/test_pytorchddp.py b/tests/integ/test_pytorchddp.py new file mode 100644 index 0000000000..c580fdebc2 --- /dev/null +++ b/tests/integ/test_pytorchddp.py @@ -0,0 +1,53 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import os +import pytest +import sagemaker.utils +import tests.integ as integ +from sagemaker.pytorch import PyTorch +from tests.integ import timeout +from tests.integ.test_pytorch import _upload_training_data + +pytorchddp_dir = os.path.join(os.path.dirname(__file__), "..", "data", "pytorch_ddp") + + +@pytest.mark.skip( + reason="This test is skipped for now due ML capacity error." + "This test should be re-enabled later." +) +@pytest.mark.skipif( + integ.test_region() not in integ.DATA_PARALLEL_TESTING_REGIONS, + reason="Only allow this test to run in IAD and CMH to limit usage of p3.16xlarge", +) +def test_pytorchddp_pt_mnist( + sagemaker_session, + pytorch_ddp_framework_version, + pytorch_ddp_py_version, +): + job_name = sagemaker.utils.unique_name_from_base("pt-pytorch-ddp") + estimator = PyTorch( + entry_point="mnist_pt.py", + role="SageMakerRole", + source_dir=pytorchddp_dir, + instance_count=2, + instance_type="ml.p3.16xlarge", + sagemaker_session=sagemaker_session, + framework_version=pytorch_ddp_framework_version, + py_version=pytorch_ddp_py_version, + distribution={"pytorchddp": {"enabled": True}}, + ) + + with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): + estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 24bb7368a4..018255cf47 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -847,3 +847,65 @@ def test_validate_smdataparallel_args_not_raises(): fw_utils._validate_smdataparallel_args( instance_type, framework_name, framework_version, py_version, distribution ) + + +def test_validate_pytorchddp_not_raises(): + # Case 1: Framework is not PyTorch + fw_utils.validate_pytorch_distribution( + distribution=None, + framework_name="tensorflow", + framework_version="2.9.1", + py_version="py3", + image_uri="custom-container", + ) + # Case 2: Framework is PyTorch, but distribution is not PyTorchDDP + pytorchddp_disabled = {"pytorchddp": {"enabled": False}} + fw_utils.validate_pytorch_distribution( + distribution=pytorchddp_disabled, + framework_name="pytorch", + framework_version="1.10", + py_version="py3", + image_uri="custom-container", + ) + # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions + pytorchddp_enabled = {"pytorchddp": {"enabled": True}} + pytorchddp_supported_fw_versions = [ + "1.10", + "1.10.0", + "1.10.2", + "1.11", + "1.11.0", + "1.12", + "1.12.0", + ] + for framework_version in pytorchddp_supported_fw_versions: + fw_utils.validate_pytorch_distribution( + distribution=pytorchddp_enabled, + framework_name="pytorch", + framework_version=framework_version, + py_version="py3", + image_uri="custom-container", + ) + + +def test_validate_pytorchddp_raises(): + pytorchddp_enabled = {"pytorchddp": {"enabled": True}} + # Case 1: Unsupported framework version + with pytest.raises(ValueError): + fw_utils.validate_pytorch_distribution( + distribution=pytorchddp_enabled, + framework_name="pytorch", + framework_version="1.8", + py_version="py3", + image_uri=None, + ) + + # Case 2: Unsupported Py version + with pytest.raises(ValueError): + fw_utils.validate_pytorch_distribution( + distribution=pytorchddp_enabled, + framework_name="pytorch", + framework_version="1.10", + py_version="py2", + image_uri=None, + ) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 8b8541e816..082f699d63 100644 --- 
a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -56,6 +56,8 @@ "TrialComponentDisplayName": "tc", } +DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}} + @pytest.fixture(name="sagemaker_session") def fixture_sagemaker_session(): @@ -97,7 +99,7 @@ def _pytorch_estimator( py_version, instance_type=None, base_job_name=None, - **kwargs + **kwargs, ): return PyTorch( entry_point=SCRIPT_PATH, @@ -108,7 +110,7 @@ def _pytorch_estimator( instance_count=INSTANCE_COUNT, instance_type=instance_type if instance_type else INSTANCE_TYPE, base_job_name=base_job_name, - **kwargs + **kwargs, ) @@ -763,3 +765,38 @@ def test_register_pytorch_model_auto_infer_framework( sagemaker_session.create_model_package_from_containers.assert_called_with( **expected_create_model_package_request ) + + +def test_pytorch_ddp_distribution_configuration( + sagemaker_session, pytorch_ddp_framework_version, pytorch_ddp_py_version +): + test_instance_type = "ml.p4d.24xlarge" + pytorch = _pytorch_estimator( + sagemaker_session, + framework_version=pytorch_ddp_framework_version, + py_version=pytorch_ddp_py_version, + distribution=DISTRIBUTION_PYTORCH_DDP_ENABLED, + instance_type=test_instance_type, + ) + actual_pytorch_ddp = pytorch._pytorch_distribution_configuration( + distribution=pytorch.distribution + ) + expected_torch_ddp = { + "sagemaker_pytorch_ddp_enabled": True, + "sagemaker_instance_type": test_instance_type, + } + assert actual_pytorch_ddp == expected_torch_ddp + + +def test_pytorch_ddp_distribution_configuration_unsupported(sagemaker_session): + unsupported_framework_version = "1.9.1" + unsupported_py_version = "py2" + with pytest.raises(ValueError) as error: + _pytorch_estimator( + sagemaker_session, + framework_version=unsupported_framework_version, + py_version=unsupported_py_version, + distribution=DISTRIBUTION_PYTORCH_DDP_ENABLED, + ) + assert (f"framework_version {unsupported_framework_version} is not supported") in str(error) + assert (f"py_version {unsupported_py_version} is not supported") in str(error) From 5d8b65c2c0757b3eda2491dc370e29f9c4bf2efa Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Mon, 1 Aug 2022 16:25:59 -0700 Subject: [PATCH 149/526] fix: Allow StepCollection added in ConditionStep to be depended on (#3261) Co-authored-by: Dewen Qi --- src/sagemaker/workflow/pipeline.py | 127 ++++++++------- .../unit/sagemaker/workflow/test_pipeline.py | 33 +++- .../sagemaker/workflow/test_pipeline_graph.py | 26 +-- .../workflow/test_step_collections.py | 150 ++++++++++++++++++ 4 files changed, 267 insertions(+), 69 deletions(-) diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index f560945752..275d952f81 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -37,48 +37,58 @@ from sagemaker.workflow.pipeline_experiment_config import PipelineExperimentConfig from sagemaker.workflow.parallelism_config import ParallelismConfiguration from sagemaker.workflow.properties import Properties -from sagemaker.workflow.steps import Step +from sagemaker.workflow.steps import Step, StepTypeEnum from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.utilities import list_to_request +_DEFAULT_EXPERIMENT_CFG = PipelineExperimentConfig( + ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID +) + -@attr.s class Pipeline(Entity): - """Pipeline for workflow. 
+ """Pipeline for workflow.""" - Attributes: - name (str): The name of the pipeline. - parameters (Sequence[Parameter]): The list of the parameters. - pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set, - the workflow will attempt to create an experiment and trial before - executing the steps. Creation will be skipped if an experiment or a trial with - the same name already exists. By default, pipeline name is used as - experiment name and execution id is used as the trial name. - If set to None, no experiment or trial will be created automatically. - steps (Sequence[Union[Step, StepCollection]]): The list of the non-conditional steps - associated with the pipeline. Any steps that are within the - `if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a - pipeline. Of particular note, the workflow service rejects any pipeline definitions that - specify a step in the list of steps of a pipeline and that step in the `if_steps` or - `else_steps` of any `ConditionStep`. - sagemaker_session (sagemaker.session.Session): Session object that manages interactions - with Amazon SageMaker APIs and any other AWS services needed. If not specified, the - pipeline creates one using the default AWS configuration chain. - """ + def __init__( + self, + name: str = "", + parameters: Optional[Sequence[Parameter]] = None, + pipeline_experiment_config: Optional[PipelineExperimentConfig] = _DEFAULT_EXPERIMENT_CFG, + steps: Optional[Sequence[Union[Step, StepCollection]]] = None, + sagemaker_session: Optional[Session] = None, + ): + """Initialize a Pipeline - name: str = attr.ib(factory=str) - parameters: Sequence[Parameter] = attr.ib(factory=list) - pipeline_experiment_config: Optional[PipelineExperimentConfig] = attr.ib( - default=PipelineExperimentConfig( - ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID - ) - ) - steps: Sequence[Union[Step, StepCollection]] = attr.ib(factory=list) - sagemaker_session: Session = attr.ib(factory=Session) + Args: + name (str): The name of the pipeline. + parameters (Sequence[Parameter]): The list of the parameters. + pipeline_experiment_config (Optional[PipelineExperimentConfig]): If set, + the workflow will attempt to create an experiment and trial before + executing the steps. Creation will be skipped if an experiment or a trial with + the same name already exists. By default, pipeline name is used as + experiment name and execution id is used as the trial name. + If set to None, no experiment or trial will be created automatically. + steps (Sequence[Union[Step, StepCollection]]): The list of the non-conditional steps + associated with the pipeline. Any steps that are within the + `if_steps` or `else_steps` of a `ConditionStep` cannot be listed in the steps of a + pipeline. Of particular note, the workflow service rejects any pipeline definitions + that specify a step in the list of steps of a pipeline and that step in the + `if_steps` or `else_steps` of any `ConditionStep`. + sagemaker_session (sagemaker.session.Session): Session object that manages interactions + with Amazon SageMaker APIs and any other AWS services needed. If not specified, the + pipeline creates one using the default AWS configuration chain. 
+ """ + self.name = name + self.parameters = parameters if parameters else [] + self.pipeline_experiment_config = pipeline_experiment_config + self.steps = steps if steps else [] + self.sagemaker_session = sagemaker_session if sagemaker_session else Session() - _version: str = "2020-12-01" - _metadata: Dict[str, Any] = dict() + self._version = "2020-12-01" + self._metadata = dict() + self._step_map = dict() + _generate_step_map(self.steps, self._step_map) def to_request(self) -> RequestType: """Gets the request structure for workflow service calls.""" @@ -193,6 +203,8 @@ def update( Returns: A response dict from the service. """ + self._step_map = dict() + _generate_step_map(self.steps, self._step_map) kwargs = self._create_args(role_arn, description, parallelism_config) return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs) @@ -305,23 +317,27 @@ def definition(self) -> str: return json.dumps(request_dict) - def _interpolate_step_collection_name_in_depends_on(self, step_requests: dict): + def _interpolate_step_collection_name_in_depends_on(self, step_requests: list): """Insert step names as per `StepCollection` name in depends_on list Args: - step_requests (dict): The raw step request dict without any interpolation. + step_requests (list): The list of raw step request dicts without any interpolation. """ - step_name_map = {s.name: s for s in self.steps} for step_request in step_requests: - if not step_request.get("DependsOn", None): - continue depends_on = [] - for depend_step_name in step_request["DependsOn"]: - if isinstance(step_name_map[depend_step_name], StepCollection): - depends_on.extend([s.name for s in step_name_map[depend_step_name].steps]) + for depend_step_name in step_request.get("DependsOn", []): + if isinstance(self._step_map[depend_step_name], StepCollection): + depends_on.extend([s.name for s in self._step_map[depend_step_name].steps]) else: depends_on.append(depend_step_name) - step_request["DependsOn"] = depends_on + if depends_on: + step_request["DependsOn"] = depends_on + + if step_request["Type"] == StepTypeEnum.CONDITION.value: + sub_step_requests = ( + step_request["Arguments"]["IfSteps"] + step_request["Arguments"]["ElseSteps"] + ) + self._interpolate_step_collection_name_in_depends_on(sub_step_requests) def format_start_parameters(parameters: Dict[str, Any]) -> List[Dict[str, Any]]: @@ -448,6 +464,20 @@ def update_args(args: Dict[str, Any], **kwargs): args.update({key: value}) +def _generate_step_map( + steps: Sequence[Union[Step, StepCollection]], step_map: dict +) -> Dict[str, Any]: + """Helper method to create a mapping from Step/Step Collection name to itself.""" + for step in steps: + if step.name in step_map: + raise ValueError("Pipeline steps cannot have duplicate names.") + step_map[step.name] = step + if isinstance(step, ConditionStep): + _generate_step_map(step.if_steps + step.else_steps, step_map) + if isinstance(step, StepCollection): + _generate_step_map(step.steps, step_map) + + @attr.s class _PipelineExecution: """Internal class for encapsulating pipeline execution instances. 
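Net effect of the changes above: a ``DependsOn`` entry that names a ``StepCollection`` is rewritten to the names of the collection's member steps, and the interpolation is now applied recursively inside the ``IfSteps``/``ElseSteps`` of a Condition step, so steps nested in a condition branch can depend on, or be depended on by, a collection. In condensed form the behavior is roughly the following sketch, assuming ``step_map`` maps each step or collection name to its object and that anything exposing a ``.steps`` list is treated as a collection:

.. code:: python

    def expand_depends_on(step_requests, step_map):
        """Simplified mirror of the DependsOn interpolation above."""
        for request in step_requests:
            expanded = []
            for name in request.get("DependsOn", []):
                members = getattr(step_map[name], "steps", None)
                if members is not None:  # a StepCollection-like object
                    expanded.extend(step.name for step in members)
                else:
                    expanded.append(name)
            if expanded:
                request["DependsOn"] = expanded
            # Recurse into condition branches so nested step requests are handled too.
            if request.get("Type") == "Condition":
                arguments = request["Arguments"]
                expand_depends_on(arguments["IfSteps"] + arguments["ElseSteps"], step_map)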
@@ -547,22 +577,11 @@ class PipelineGraph: def __init__(self, steps: Sequence[Union[Step, StepCollection]]): self.step_map = {} - self._generate_step_map(steps) + _generate_step_map(steps, self.step_map) self.adjacency_list = self._initialize_adjacency_list() if self.is_cyclic(): raise ValueError("Cycle detected in pipeline step graph.") - def _generate_step_map(self, steps: Sequence[Union[Step, StepCollection]]): - """Helper method to create a mapping from Step/Step Collection name to itself.""" - for step in steps: - if step.name in self.step_map: - raise ValueError("Pipeline steps cannot have duplicate names.") - self.step_map[step.name] = step - if isinstance(step, ConditionStep): - self._generate_step_map(step.if_steps + step.else_steps) - if isinstance(step, StepCollection): - self._generate_step_map(step.steps) - @classmethod def from_pipeline(cls, pipeline: Pipeline): """Create a PipelineGraph object from the Pipeline object.""" diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index a9e9474013..5cd94dd76a 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -20,6 +20,8 @@ from mock import Mock from sagemaker import s3 +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ConditionEquals from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.pipeline import Pipeline, PipelineGraph @@ -28,6 +30,7 @@ PipelineExperimentConfig, PipelineExperimentConfigProperties, ) +from sagemaker.workflow.step_collections import StepCollection from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep @@ -78,7 +81,7 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn): pipeline = Pipeline( name="MyPipeline", parameters=[parameter], - steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000, + steps=_generate_large_pipeline_steps(parameter), sagemaker_session=sagemaker_session_mock, ) @@ -105,6 +108,25 @@ def test_pipeline_update(sagemaker_session_mock, role_arn): sagemaker_session=sagemaker_session_mock, ) pipeline.update(role_arn=role_arn) + assert len(json.loads(pipeline.definition())["Steps"]) == 0 + assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( + PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn + ) + + step1 = CustomStep(name="MyStep1") + step2 = CustomStep(name="MyStep2", input_data=step1.properties) + step_collection = StepCollection(name="MyStepCollection", steps=[step1, step2]) + cond_step = ConditionStep( + name="MyConditionStep", + depends_on=[], + conditions=[ConditionEquals(left=2, right=1)], + if_steps=[step_collection], + else_steps=[], + ) + step3 = CustomStep(name="MyStep3", depends_on=[step_collection]) + pipeline.steps = [cond_step, step3] + pipeline.update(role_arn=role_arn) + assert len(json.loads(pipeline.definition())["Steps"]) > 0 assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn ) @@ -132,7 +154,7 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn): pipeline = Pipeline( name="MyPipeline", parameters=[parameter], - steps=[CustomStep(name="MyStep", input_data=parameter)] * 2000, + steps=_generate_large_pipeline_steps(parameter), sagemaker_session=sagemaker_session_mock, ) @@ -437,3 +459,10 @@ 
def test_pipeline_execution_basics(sagemaker_session_mock): PipelineExecutionArn="my:arn" ) assert len(steps) == 1 + + +def _generate_large_pipeline_steps(input_data: object): + steps = [] + for i in range(2000): + steps.append(CustomStep(name=f"MyStep{i}", input_data=input_data)) + return steps diff --git a/tests/unit/sagemaker/workflow/test_pipeline_graph.py b/tests/unit/sagemaker/workflow/test_pipeline_graph.py index b7d69e617a..003dd8d048 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_graph.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_graph.py @@ -45,10 +45,10 @@ def role_arn(): def test_pipeline_duplicate_step_name(sagemaker_session_mock): step1 = CustomStep(name="foo") step2 = CustomStep(name="foo") - pipeline = Pipeline( - name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock - ) with pytest.raises(ValueError) as error: + pipeline = Pipeline( + name="MyPipeline", steps=[step1, step2], sagemaker_session=sagemaker_session_mock + ) PipelineGraph.from_pipeline(pipeline) assert "Pipeline steps cannot have duplicate names." in str(error.value) @@ -61,12 +61,12 @@ def test_pipeline_duplicate_step_name_in_condition_step(sagemaker_session_mock): condition_step = ConditionStep( name="condStep", conditions=[cond], depends_on=[custom_step], if_steps=[custom_step2] ) - pipeline = Pipeline( - name="MyPipeline", - steps=[custom_step, condition_step], - sagemaker_session=sagemaker_session_mock, - ) with pytest.raises(ValueError) as error: + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step, condition_step], + sagemaker_session=sagemaker_session_mock, + ) PipelineGraph.from_pipeline(pipeline) assert "Pipeline steps cannot have duplicate names." in str(error.value) @@ -74,12 +74,12 @@ def test_pipeline_duplicate_step_name_in_condition_step(sagemaker_session_mock): def test_pipeline_duplicate_step_name_in_step_collection(sagemaker_session_mock): custom_step = CustomStep(name="foo-1") custom_step_collection = CustomStepCollection(name="foo", depends_on=[custom_step]) - pipeline = Pipeline( - name="MyPipeline", - steps=[custom_step, custom_step_collection], - sagemaker_session=sagemaker_session_mock, - ) with pytest.raises(ValueError) as error: + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step, custom_step_collection], + sagemaker_session=sagemaker_session_mock, + ) PipelineGraph.from_pipeline(pipeline) assert "Pipeline steps cannot have duplicate names." 
in str(error.value) diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index d3b2a19fe3..d3d1ab022b 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -20,6 +20,8 @@ import pytest from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ConditionEquals from sagemaker.workflow.model_step import ( ModelStep, _CREATE_MODEL_NAME_BASE, @@ -360,6 +362,154 @@ def test_step_collection_is_depended_on(pipeline_session, sagemaker_session): ) +def test_step_collection_in_condition_branch_is_depended_on(pipeline_session, sagemaker_session): + custom_step1 = CustomStep(name="MyStep1") + + # Define a step collection which will be inserted into the ConditionStep + model_name = "MyModel" + model = Model( + name=model_name, + image_uri=IMAGE_URI, + model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"), + sagemaker_session=pipeline_session, + entry_point=f"{DATA_DIR}/dummy_script.py", + source_dir=f"{DATA_DIR}", + role=ROLE, + ) + step_args = model.create( + instance_type="c4.4xlarge", + accelerator_type="ml.eia1.medium", + ) + model_step_name = "MyModelStep" + model_step = ModelStep( + name=model_step_name, + step_args=step_args, + ) + + # Define another step collection which will be inserted into the ConditionStep + # This StepCollection object depends on a StepCollection object in the ConditionStep + # And a normal step outside ConditionStep + model.sagemaker_session = sagemaker_session + register_model_name = "RegisterModelStep" + register_model = RegisterModel( + name=register_model_name, + model=model, + model_data="s3://", + content_types=["content_type"], + response_types=["response_type"], + inference_instances=["inference_instance"], + transform_instances=["transform_instance"], + model_package_group_name="mpg", + depends_on=["MyStep1", model_step], + ) + + # StepCollection objects are depended on by a normal step in the ConditionStep + custom_step2 = CustomStep( + name="MyStep2", depends_on=["MyStep1", model_step, register_model_name] + ) + # StepCollection objects are depended on by a normal step outside the ConditionStep + custom_step3 = CustomStep( + name="MyStep3", depends_on=[custom_step1, model_step_name, register_model] + ) + + cond_step = ConditionStep( + name="CondStep", + conditions=[ConditionEquals(left=2, right=1)], + if_steps=[], + else_steps=[model_step, register_model, custom_step2], + ) + + pipeline = Pipeline( + name="MyPipeline", + steps=[cond_step, custom_step1, custom_step3], + ) + step_list = json.loads(pipeline.definition())["Steps"] + assert len(step_list) == 3 + for step in step_list: + if step["Name"] == "MyStep1": + assert "DependsOn" not in step + elif step["Name"] == "CondStep": + assert not step["Arguments"]["IfSteps"] + for sub_step in step["Arguments"]["ElseSteps"]: + if sub_step["Name"] == f"{model_name}-RepackModel": + assert set(sub_step["DependsOn"]) == { + "MyStep1", + f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", + f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", + } + if sub_step["Name"] == "MyStep2": + assert set(sub_step["DependsOn"]) == { + "MyStep1", + f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", + f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", + f"{model_name}-RepackModel", + f"{register_model_name}-RegisterModel", + } + else: + assert 
set(step["DependsOn"]) == { + "MyStep1", + f"{model_step_name}-{_REPACK_MODEL_NAME_BASE}-{model_name}", + f"{model_step_name}-{_CREATE_MODEL_NAME_BASE}", + f"{model_name}-RepackModel", + f"{register_model_name}-RegisterModel", + } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + { + "CondStep": ["MyModel-RepackModel", "MyModelStep-RepackModel-MyModel", "MyStep2"], + "MyStep1": ["MyStep2", "MyStep3", "MyModel-RepackModel"], + "MyStep2": [], + "MyStep3": [], + "MyModelStep-RepackModel-MyModel": ["MyModelStep-CreateModel"], + "MyModelStep-CreateModel": ["MyStep2", "MyStep3", "MyModel-RepackModel"], + "MyModel-RepackModel": [], + "RegisterModelStep-RegisterModel": ["MyStep2", "MyStep3"], + } + ) + + +def test_condition_step_depends_on_step_collection(): + step1 = CustomStep(name="MyStep1") + step2 = CustomStep(name="MyStep2", input_data=step1.properties) + step_collection = StepCollection(name="MyStepCollection", steps=[step1, step2]) + cond_step = ConditionStep( + name="MyConditionStep", + depends_on=[step_collection], + conditions=[ConditionEquals(left=2, right=1)], + if_steps=[], + else_steps=[], + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[step_collection, cond_step], + ) + step_list = json.loads(pipeline.definition())["Steps"] + assert len(step_list) == 3 + for step in step_list: + if step["Name"] != "MyConditionStep": + continue + assert step == { + "Name": "MyConditionStep", + "Type": "Condition", + "DependsOn": ["MyStep1", "MyStep2"], + "Arguments": { + "Conditions": [ + { + "Type": "Equals", + "LeftValue": 2, + "RightValue": 1, + }, + ], + "IfSteps": [], + "ElseSteps": [], + }, + } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list + assert ordered(adjacency_list) == ordered( + [("MyConditionStep", []), ("MyStep1", ["MyStep2"]), ("MyStep2", ["MyConditionStep"])] + ) + + def test_register_model(estimator, model_metrics, drift_check_baselines): model_data = f"s3://{BUCKET}/model.tar.gz" register_model = RegisterModel( From e368be01dc5c6c4fa8f81be8267f9d2ab8535e1d Mon Sep 17 00:00:00 2001 From: yongyanrao <75150929+yongyanrao@users.noreply.github.com> Date: Mon, 1 Aug 2022 20:54:08 -0500 Subject: [PATCH 150/526] change: add a check to prevent launching a modelparallel job on CPU only instances (#3262) Co-authored-by: Yongyan Rao --- src/sagemaker/fw_utils.py | 46 ++++++++++++++++++++++++++++++ src/sagemaker/pytorch/estimator.py | 7 +++++ tests/unit/test_fw_utils.py | 34 ++++++++++++++++++++++ 3 files changed, 87 insertions(+) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index ef99454a45..613bbd3742 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -835,6 +835,52 @@ def validate_pytorch_distribution( raise ValueError(err_msg) +def validate_distribution_instance(sagemaker_session, distribution, instance_type): + """Check to prevent launching a modelparallel job on CPU only instances. + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + distribution (dict): A dictionary with information to enable distributed training. + distribution = { + "smdistributed": { + "modelparallel": { + "enabled": True, + "parameters": { + ... + }, + }, + }, + ... + } + instance_type (str): A string representing the type of training instance selected. + + Raises: + ValueError: when modelparallel is enabled, if the instance_type does not support GPU. 
+ """ + if "smdistributed" not in distribution: + # Distribution strategy other than smdistributed is selected + return + + if "modelparallel" not in distribution["smdistributed"]: + # Strategy other than modelparallel is selected + return + + if not distribution["smdistributed"]["modelparallel"]["enabled"]: + # Strategy modelparallel is not enabled + return + + instance_desc = sagemaker_session.boto_session.client("ec2").describe_instance_types( + InstanceTypes=[f"{instance_type}"] + ) + if "GpuInfo" not in instance_desc["InstanceTypes"][0]: + raise ValueError( + f"modelparallel only runs on GPU-enabled instances. " + f"{instance_type} does not support GPU." + ) + + def python_deprecation_warning(framework, latest_supported_version): """Placeholder docstring""" return PYTHON_2_DEPRECATION_WARNING.format( diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 153d4656d4..622e79084c 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -25,6 +25,7 @@ python_deprecation_warning, validate_version_or_image_args, validate_distribution, + validate_distribution_instance, ) from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel @@ -220,6 +221,12 @@ def __init__( entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs ) if distribution is not None: + instance_type = self._get_instance_type() + # remove "ml." prefix + if instance_type[:3] == "ml.": + instance_type = instance_type[3:] + validate_distribution_instance(self.sagemaker_session, distribution, instance_type) + distribution = validate_distribution( distribution, self.instance_groups, diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 018255cf47..5ecf196731 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -49,6 +49,15 @@ def sagemaker_session(): session_mock.sagemaker_client.describe_training_job = Mock( return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} ) + session_mock.boto_session.client("ec2").describe_instance_types = Mock( + return_value={ + "InstanceTypes": [ + { + "CpuInfo": {}, + }, + ], + } + ) return session_mock @@ -733,6 +742,31 @@ def test_validate_smdistributed_not_raises(): ) +def test_validate_distribution_instance_no_smdistributed(sagemaker_session): + distribution = {} + instance_type = "mock_type" + fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) + + +def test_validate_distribution_instance_no_modelparallel(sagemaker_session): + distribution = {"smdistributed": {}} + instance_type = "mock_type" + fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) + + +def test_validate_distribution_instance_disabled_modelparallel(sagemaker_session): + distribution = {"smdistributed": {"modelparallel": {"enabled": False}}} + instance_type = "mock_type" + fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) + + +def test_validate_distribution_instance_raise(sagemaker_session): + distribution = {"smdistributed": {"modelparallel": {"enabled": True}}} + instance_type = "mock_type" + with pytest.raises(ValueError): + fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) + + def test_validate_smdistributed_raises(): bad_args = [ {"smdistributed": "dummy"}, From 23d5885887cb285aca42d870770d4c663daf416f Mon Sep 17 00:00:00 2001 From: keerthanvasist Date: Tue, 2 Aug 2022 12:09:31 -0700 Subject: [PATCH 
151/526] fix: Two letter language code must be supported (#3258) --- src/sagemaker/clarify.py | 67 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 873a87ca57..6590d30514 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -512,68 +512,131 @@ class TextConfig: _SUPPORTED_GRANULARITIES = ["token", "sentence", "paragraph"] _SUPPORTED_LANGUAGES = [ "chinese", + "zh", "danish", + "da", "dutch", + "nl", "english", + "en", "french", + "fr", "german", + "de", "greek", + "el", "italian", + "it", "japanese", + "ja", "lithuanian", + "lt", "multi-language", + "xx", "norwegian bokmål", + "nb", "polish", + "pl", "portuguese", + "pt", "romanian", + "ro", "russian", + "ru", "spanish", + "es", "afrikaans", + "af", "albanian", + "sq", "arabic", + "ar", "armenian", + "hy", "basque", + "eu", "bengali", + "bn", "bulgarian", + "bg", "catalan", + "ca", "croatian", + "hr", "czech", + "cs", "estonian", + "et", "finnish", + "fi", "gujarati", + "gu", "hebrew", + "he", "hindi", + "hi", "hungarian", + "hu", "icelandic", + "is", "indonesian", + "id", "irish", + "ga", "kannada", + "kn", "kyrgyz", + "ky", "latvian", + "lv", "ligurian", + "lij", "luxembourgish", + "lb", "macedonian", + "mk", "malayalam", + "ml", "marathi", + "mr", "nepali", + "ne", "persian", + "fa", "sanskrit", + "sa", "serbian", + "sr", "setswana", + "tn", "sinhala", + "si", "slovak", + "sk", "slovenian", + "sl", "swedish", + "sv", "tagalog", + "tl", "tamil", + "ta", "tatar", + "tt", "telugu", + "te", "thai", + "th", "turkish", + "tr", "ukrainian", + "uk", "urdu", + "ur", "vietnamese", + "vi", "yoruba", + "yo", ] def __init__( @@ -602,8 +665,8 @@ def __init__( ``"persian"``, ``"sanskrit"``, ``"serbian"``, ``"setswana"``, ``"sinhala"``, ``"slovak"``, ``"slovenian"``, ``"swedish"``, ``"tagalog"``, ``"tamil"``, ``"tatar"``, ``"telugu"``, ``"thai"``, ``"turkish"``, ``"ukrainian"``, ``"urdu"``, - ``"vietnamese"``, ``"yoruba"``. - Use ``"multi-language"`` for a mix of multiple languages. + ``"vietnamese"``, ``"yoruba"``. Use "multi-language" for a mix of multiple + languages. The corresponding two-letter ISO codes are also accepted. Raises: ValueError: when ``granularity`` is not in list of supported values From 8a0cf5c42fe41c054cc07b32d7a3bacf7ad9643f Mon Sep 17 00:00:00 2001 From: Miyoung Date: Tue, 2 Aug 2022 13:13:57 -0700 Subject: [PATCH 152/526] documentation: smdistributed libraries currency updates (#3266) --- doc/api/training/sdp_versions/latest.rst | 4 +- .../smd_data_parallel_change_log.rst | 48 ++++++++++++++++--- .../smd_model_parallel_change_log.rst | 17 ++++++- 3 files changed, 58 insertions(+), 11 deletions(-) diff --git a/doc/api/training/sdp_versions/latest.rst b/doc/api/training/sdp_versions/latest.rst index 7ca0061aa9..c3fcc5f78e 100644 --- a/doc/api/training/sdp_versions/latest.rst +++ b/doc/api/training/sdp_versions/latest.rst @@ -26,8 +26,8 @@ depending on the version of the library you use. `_ for more information. -Version 1.4.0, 1.4.1 (Latest) -============================= +Version 1.4.0, 1.4.1, 1.5.0 (Latest) +==================================== .. 
toctree:: :maxdepth: 1 diff --git a/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst b/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst index 289074255d..05eb7220e0 100644 --- a/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst +++ b/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst @@ -7,9 +7,46 @@ Release Notes New features, bug fixes, and improvements are regularly made to the SageMaker distributed data parallel library. -SageMaker Distributed Data Parallel 1.4.1 Release Notes +SageMaker Distributed Data Parallel 1.5.0 Release Notes ======================================================= +*Date: Jul. 26. 2022* + +**Currency Updates** + +* Added support for PyTorch 1.12.0. + +**Bug Fixes** + +* Improved stability for long-running training jobs. + + +**Migration to AWS Deep Learning Containers** + +This version passed benchmark testing and is migrated to the following AWS Deep Learning Containers (DLC): + +- PyTorch 1.12.0 DLC + + .. code:: + + 763104351884.dkr.ecr..amazonaws.com/pytorch-training:1.12.0-gpu-py38-cu113-ubuntu20.04-sagemaker + +Binary file of this version of the library for `custom container +`_ users: + + .. code:: + + https://smdataparallel.s3.amazonaws.com/binary/pytorch/1.12.0/cu113/2022-07-01/smdistributed_dataparallel-1.5.0-cp38-cp38-linux_x86_64.whl + + +---- + +Release History +=============== + +SageMaker Distributed Data Parallel 1.4.1 Release Notes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + *Date: May. 3. 2022* **Currency Updates** @@ -18,7 +55,9 @@ SageMaker Distributed Data Parallel 1.4.1 Release Notes **Known Issues** -* The library currently does not support the PyTorch sub-process groups API (torch.distributed.new_group (https://pytorch.org/docs/stable/distributed.html#torch.distributed.new_group)). +* The library currently does not support the PyTorch sub-process groups API + (`torch.distributed.new_group + `_). **Migration to AWS Deep Learning Containers** @@ -38,11 +77,6 @@ Binary file of this version of the library for custom container users: https://smdataparallel.s3.amazonaws.com/binary/pytorch/1.11.0/cu113/2022-04-14/smdistributed_dataparallel-1.4.1-cp38-cp38-linux_x86_64.whl ----- - -Release History -=============== - SageMaker Distributed Data Parallel 1.4.0 Release Notes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst b/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst index 12ed10049a..d6b93c3de3 100644 --- a/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst +++ b/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst @@ -35,19 +35,32 @@ The following new features are added for PyTorch. This version passed benchmark testing and is migrated to the following AWS Deep Learning Containers (DLC): -- PyTorch 1.11.0 DLC +- DLC for PyTorch 1.11.0 .. code:: 763104351884.dkr.ecr..amazonaws.com/pytorch-training:1.11.0-gpu-py38-cu113-ubuntu20.04-sagemaker -Binary file of this version of the library for custom container users: +- DLC for PyTorch 1.12.0 + + .. code:: + + 763104351884.dkr.ecr..amazonaws.com/pytorch-training:1.12.0-gpu-py38-cu113-ubuntu20.04-sagemaker + +Binary file of this version of the library for `custom container +`_ users: + +- For PyTorch 1.11.0 .. 
code:: https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.11.0/build-artifacts/2022-07-11-19-23/smdistributed_modelparallel-1.10.0-cp38-cp38-linux_x86_64.whl +- For PyTorch 1.12.0 + + .. code:: + https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.12.0/build-artifacts/2022-07-11-19-23/smdistributed_modelparallel-1.10.0-cp38-cp38-linux_x86_64.whl ---- From 3b61415d5189f83794db428b498d61d9f5bc7ad6 Mon Sep 17 00:00:00 2001 From: Sara <40234180+ZiweiG@users.noreply.github.com> Date: Tue, 2 Aug 2022 14:23:31 -0700 Subject: [PATCH 153/526] feature: add warnings for xgboost specific rules in debugger rules (#3255) --- src/sagemaker/estimator.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index dee102999b..9d0c30ff27 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -792,6 +792,8 @@ def _prepare_rules(self): if self.rules is not None: for rule in self.rules: if isinstance(rule, Rule): + # Add check for xgboost rules + self._check_debugger_rule(rule) self.debugger_rules.append(rule) elif isinstance(rule, ProfilerRule): self.profiler_rules.append(rule) @@ -801,6 +803,16 @@ def _prepare_rules(self): + "and sagemaker.debugger.ProfilerRule" ) + def _check_debugger_rule(self, rule): + """Add warning for incorrectly used xgboost rules.""" + _xgboost_specific_rules = ["FeatureImportanceOverweight", "TreeDepth"] + if rule.name in _xgboost_specific_rules: + logger.warning( + "TreeDepth and FeatureImportanceOverweight rules are valid " + "only for the XGBoost algorithm. Please make sure this estimator " + "is used for XGBoost algorithm. " + ) + def _prepare_debugger_for_training(self): """Prepare debugger rules and debugger configs for training.""" if self.debugger_rules and self.debugger_hook_config is None: From efc5275bc3126dad0756f138361f77028b0f91c9 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 4 Aug 2022 02:09:28 +0000 Subject: [PATCH 154/526] prepare release v2.102.0 --- CHANGELOG.md | 20 ++++++++++++++++++++ VERSION | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b40005e7fa..5a5e9ed081 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,25 @@ # Changelog +## v2.102.0 (2022-08-04) + +### Features + + * add warnings for xgboost specific rules in debugger rules + * Add PyTorch DDP distribution support + * Add test for profiler enablement with debugger_hook false + +### Bug Fixes and Other Changes + + * Two letter language code must be supported + * add a check to prevent launching a modelparallel job on CPU only instances + * Allow StepCollection added in ConditionStep to be depended on + * Add PipelineVariable annotation in framework models + * skip managed spot training mxnet nb + +### Documentation Changes + + * smdistributed libraries currency updates + ## v2.101.1 (2022-07-28) ### Bug Fixes and Other Changes diff --git a/VERSION b/VERSION index bd7dda9e04..fcf29bedc5 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.101.2.dev0 +2.102.0 From 2070b7e8b647c40fe5cd4177053aeedacc2a0809 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 4 Aug 2022 02:09:29 +0000 Subject: [PATCH 155/526] update development version to v2.102.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index fcf29bedc5..88a21ca390 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.102.0 +2.102.1.dev0 From ead5c4ef8526c48f9610b3edf3c5a439e633d51c Mon Sep 17 00:00:00 2001 
From: Sai Parthasarathy Miduthuri <54188298+saimidu@users.noreply.github.com> Date: Thu, 4 Aug 2022 09:46:30 -0700 Subject: [PATCH 156/526] fix: Link PyTorch 1.11 to 1.11.0 (#3259) --- src/sagemaker/image_uri_config/pytorch.json | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 1bea5ffc30..5a2fb90202 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -66,7 +66,8 @@ "1.7": "1.7.1", "1.8": "1.8.1", "1.9": "1.9.1", - "1.10": "1.10.0" + "1.10": "1.10.0", + "1.11": "1.11.0" }, "versions": { "0.4.0": { @@ -635,7 +636,8 @@ "1.7": "1.7.1", "1.8": "1.8.1", "1.9": "1.9.1", - "1.10": "1.10.0" + "1.10": "1.10.0", + "1.11": "1.11.0" }, "versions": { "0.4.0": { From 20ae406faa908aa5368bb4d83ebc08998ab2e454 Mon Sep 17 00:00:00 2001 From: Alexander Shirkov <10080307+gradientsky@users.noreply.github.com> Date: Thu, 4 Aug 2022 09:49:21 -0700 Subject: [PATCH 157/526] feature: AutoGluon 0.4.3 and 0.5.2 image_uris (#3274) --- src/sagemaker/image_uri_config/autogluon.json | 132 +++++++++++++++++- .../sagemaker/image_uris/test_autogluon.py | 2 +- 2 files changed, 131 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 505f1d1f7e..3cc488c55d 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -3,7 +3,8 @@ "processors": ["cpu", "gpu"], "version_aliases": { "0.3": "0.3.2", - "0.4": "0.4.2" + "0.4": "0.4.3", + "0.5": "0.5.2" }, "versions": { "0.3.1": { @@ -125,13 +126,74 @@ }, "repository": "autogluon-training", "py_versions": ["py38"] + }, + "0.4.3": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "autogluon-training", + "py_versions": ["py38"] + }, + "0.5.2": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "autogluon-training", + "py_versions": 
["py38"] } } }, "inference": { "version_aliases": { "0.3": "0.3.2", - "0.4": "0.4.2" + "0.4": "0.4.3", + "0.5": "0.5.2" }, "versions": { "0.3.1": { @@ -265,6 +327,72 @@ "repository": "autogluon-inference", "processors": ["cpu", "gpu"], "py_versions": ["py38"] + }, + "0.4.3": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "autogluon-inference", + "processors": ["cpu", "gpu"], + "py_versions": ["py38"] + }, + "0.5.2": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "autogluon-inference", + "processors": ["cpu", "gpu"], + "py_versions": ["py38"] } } } diff --git a/tests/unit/sagemaker/image_uris/test_autogluon.py b/tests/unit/sagemaker/image_uris/test_autogluon.py index 8ce12ec670..7f7aea2850 100644 --- a/tests/unit/sagemaker/image_uris/test_autogluon.py +++ b/tests/unit/sagemaker/image_uris/test_autogluon.py @@ -42,7 +42,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884", } -VERSIONS = ["0.3.1", "0.3.2", "0.4.0", "0.4.2", "0.3", "0.4"] +VERSIONS = ["0.3.1", "0.3.2", "0.4.0", "0.4.2", "0.4.3", "0.3", "0.4", "0.5.2", "0.5"] SCOPES = ["training", "inference"] PROCESSORS = ["cpu", "gpu"] From cba4c20e64a5926165333c65814c032566013bfc Mon Sep 17 00:00:00 2001 From: Robert Dargavel Smith Date: Thu, 4 Aug 2022 18:53:52 +0100 Subject: [PATCH 158/526] fix: Add gpu capability to local (#3136) --- src/sagemaker/local/image.py | 4 +++- tests/unit/test_image.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 38d963bfda..1fbc93fca0 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -759,7 +759,9 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes # for GPU support pass in nvidia as the runtime, this is equivalent # to setting --runtime=nvidia in the docker commandline. 
if self.instance_type == "local_gpu": - host_config["runtime"] = "nvidia" + host_config["deploy"] = { + "resources": {"reservations": {"devices": [{"capabilities": ["gpu"]}]}} + } if command == "serve": serving_port = ( diff --git a/tests/unit/test_image.py b/tests/unit/test_image.py index 020f648834..e7bde99610 100644 --- a/tests/unit/test_image.py +++ b/tests/unit/test_image.py @@ -574,8 +574,10 @@ def test_container_has_gpu_support(tmpdir, sagemaker_session): ) docker_host = sagemaker_container._create_docker_host("host-1", {}, set(), "train", []) - assert "runtime" in docker_host - assert docker_host["runtime"] == "nvidia" + assert "deploy" in docker_host + assert docker_host["deploy"] == { + "resources": {"reservations": {"devices": [{"capabilities": ["gpu"]}]}} + } def test_container_does_not_enable_nvidia_docker_for_cpu_containers(sagemaker_session): From 2af5d430c3192df5cbea80546d61c5fb893e5750 Mon Sep 17 00:00:00 2001 From: Navin Soni Date: Thu, 4 Aug 2022 11:45:17 -0700 Subject: [PATCH 159/526] fix: Revert "change: add a check to prevent launching a modelparallel job on CPU only instances" (#3280) --- src/sagemaker/fw_utils.py | 46 ------------------------------ src/sagemaker/pytorch/estimator.py | 7 ----- tests/unit/test_fw_utils.py | 34 ---------------------- 3 files changed, 87 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 613bbd3742..ef99454a45 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -835,52 +835,6 @@ def validate_pytorch_distribution( raise ValueError(err_msg) -def validate_distribution_instance(sagemaker_session, distribution, instance_type): - """Check to prevent launching a modelparallel job on CPU only instances. - - Args: - sagemaker_session (sagemaker.session.Session): Session object which - manages interactions with Amazon SageMaker APIs and any other - AWS services needed. - distribution (dict): A dictionary with information to enable distributed training. - distribution = { - "smdistributed": { - "modelparallel": { - "enabled": True, - "parameters": { - ... - }, - }, - }, - ... - } - instance_type (str): A string representing the type of training instance selected. - - Raises: - ValueError: when modelparallel is enabled, if the instance_type does not support GPU. - """ - if "smdistributed" not in distribution: - # Distribution strategy other than smdistributed is selected - return - - if "modelparallel" not in distribution["smdistributed"]: - # Strategy other than modelparallel is selected - return - - if not distribution["smdistributed"]["modelparallel"]["enabled"]: - # Strategy modelparallel is not enabled - return - - instance_desc = sagemaker_session.boto_session.client("ec2").describe_instance_types( - InstanceTypes=[f"{instance_type}"] - ) - if "GpuInfo" not in instance_desc["InstanceTypes"][0]: - raise ValueError( - f"modelparallel only runs on GPU-enabled instances. " - f"{instance_type} does not support GPU." 
- ) - - def python_deprecation_warning(framework, latest_supported_version): """Placeholder docstring""" return PYTHON_2_DEPRECATION_WARNING.format( diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 622e79084c..153d4656d4 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -25,7 +25,6 @@ python_deprecation_warning, validate_version_or_image_args, validate_distribution, - validate_distribution_instance, ) from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel @@ -221,12 +220,6 @@ def __init__( entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs ) if distribution is not None: - instance_type = self._get_instance_type() - # remove "ml." prefix - if instance_type[:3] == "ml.": - instance_type = instance_type[3:] - validate_distribution_instance(self.sagemaker_session, distribution, instance_type) - distribution = validate_distribution( distribution, self.instance_groups, diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 5ecf196731..018255cf47 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -49,15 +49,6 @@ def sagemaker_session(): session_mock.sagemaker_client.describe_training_job = Mock( return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} ) - session_mock.boto_session.client("ec2").describe_instance_types = Mock( - return_value={ - "InstanceTypes": [ - { - "CpuInfo": {}, - }, - ], - } - ) return session_mock @@ -742,31 +733,6 @@ def test_validate_smdistributed_not_raises(): ) -def test_validate_distribution_instance_no_smdistributed(sagemaker_session): - distribution = {} - instance_type = "mock_type" - fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) - - -def test_validate_distribution_instance_no_modelparallel(sagemaker_session): - distribution = {"smdistributed": {}} - instance_type = "mock_type" - fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) - - -def test_validate_distribution_instance_disabled_modelparallel(sagemaker_session): - distribution = {"smdistributed": {"modelparallel": {"enabled": False}}} - instance_type = "mock_type" - fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) - - -def test_validate_distribution_instance_raise(sagemaker_session): - distribution = {"smdistributed": {"modelparallel": {"enabled": True}}} - instance_type = "mock_type" - with pytest.raises(ValueError): - fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) - - def test_validate_smdistributed_raises(): bad_args = [ {"smdistributed": "dummy"}, From cea71bb4ad57d6b54d7764a07eabbaaf94a5fd48 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 5 Aug 2022 19:43:44 +0000 Subject: [PATCH 160/526] prepare release v2.103.0 --- CHANGELOG.md | 12 ++++++++++++ VERSION | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a5e9ed081..878fc2a6e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## v2.103.0 (2022-08-05) + +### Features + + * AutoGluon 0.4.3 and 0.5.2 image_uris + +### Bug Fixes and Other Changes + + * Revert "change: add a check to prevent launching a modelparallel job on CPU only instances" + * Add gpu capability to local + * Link PyTorch 1.11 to 1.11.0 + ## v2.102.0 (2022-08-04) ### Features diff --git a/VERSION b/VERSION index 88a21ca390..6e593dc5af 100644 --- a/VERSION +++ 
b/VERSION @@ -1 +1 @@ -2.102.1.dev0 +2.103.0 From 2106d06a835a960713081cc842185894bc2fc6c2 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 5 Aug 2022 19:43:45 +0000 Subject: [PATCH 161/526] update development version to v2.103.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 6e593dc5af..ef1a140e63 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.103.0 +2.103.1.dev0 From 952b052e5a9ed4e253f2b4376638795b2f4a4b1f Mon Sep 17 00:00:00 2001 From: Navin Soni Date: Fri, 5 Aug 2022 19:47:25 -0700 Subject: [PATCH 162/526] fix: Update localmode code to decode urllib response as UTF8 (#3284) --- src/sagemaker/local/entities.py | 2 +- tests/unit/test_local_entities.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 408385739d..3b8da1b46b 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -314,7 +314,7 @@ def start(self, input_data, output_data, transform_resources, **kwargs): endpoint_url = "http://%s:%d/execution-parameters" % (get_docker_host(), serving_port) response, code = _perform_request(endpoint_url) if code == 200: - execution_parameters = json.loads(response.read()) + execution_parameters = json.loads(response.data.decode("utf-8")) # MaxConcurrentTransforms is ignored because we currently only support 1 for setting in ("BatchStrategy", "MaxPayloadInMB"): if setting not in kwargs and setting in execution_parameters: diff --git a/tests/unit/test_local_entities.py b/tests/unit/test_local_entities.py index f7a56959db..6b62cd786b 100644 --- a/tests/unit/test_local_entities.py +++ b/tests/unit/test_local_entities.py @@ -106,7 +106,7 @@ def test_start_local_transform_job(_perform_batch_inference, _perform_request, l response = Mock() _perform_request.return_value = (response, 200) - response.read.return_value = '{"BatchStrategy": "SingleRecord"}' + response.data = '{"BatchStrategy": "SingleRecord"}'.encode("UTF-8") local_transform_job.primary_container["ModelDataUrl"] = "file:///some/model" local_transform_job.start(input_data, output_data, transform_resources, Environment={}) @@ -176,9 +176,9 @@ def test_start_local_transform_job_from_remote_docker_host( output_data = {} transform_resources = {"InstanceType": "local"} m_get_docker_host.return_value = "some_host" - perform_request_mock = Mock() - m_perform_request.return_value = (perform_request_mock, 200) - perform_request_mock.read.return_value = '{"BatchStrategy": "SingleRecord"}' + response = Mock() + m_perform_request.return_value = (response, 200) + response.data = '{"BatchStrategy": "SingleRecord"}'.encode("UTF-8") local_transform_job.primary_container["ModelDataUrl"] = "file:///some/model" local_transform_job.start(input_data, output_data, transform_resources, Environment={}) endpoints = [ From 875d0cb4dc4d44756f59356b14a086d8d8e44f4f Mon Sep 17 00:00:00 2001 From: Zhankui Lu Date: Mon, 15 Aug 2022 10:37:22 -0700 Subject: [PATCH 163/526] Allow users to customize trial component display names for pipeline launched jobs (#3230) Co-authored-by: Zhankui Lu --- ...azon_sagemaker_model_building_pipeline.rst | 2 + .../sagemaker.workflow.pipelines.rst | 2 +- src/sagemaker/estimator.py | 6 + src/sagemaker/processing.py | 19 ++- src/sagemaker/transformer.py | 6 + src/sagemaker/workflow/execution_variables.py | 4 + src/sagemaker/workflow/steps.py | 20 +++- .../workflow/test_processing_step.py | 43 ++++++- .../sagemaker/workflow/test_training_step.py | 52 +++++++- 
.../sagemaker/workflow/test_transform_step.py | 111 +++++++++++++++--- 10 files changed, 238 insertions(+), 27 deletions(-) diff --git a/doc/amazon_sagemaker_model_building_pipeline.rst b/doc/amazon_sagemaker_model_building_pipeline.rst index 9dfdf01d08..b85a9d9251 100644 --- a/doc/amazon_sagemaker_model_building_pipeline.rst +++ b/doc/amazon_sagemaker_model_building_pipeline.rst @@ -741,6 +741,8 @@ There are a number of properties for a pipeline execution that can only be resol - :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_EXECUTION_ARN`: The execution ARN for an execution. - :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_NAME`: The name of the pipeline. - :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PIPELINE_ARN`: The ARN of the pipeline. +- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.TRAINING_JOB_NAME`: The name of the training job launched by the training step. +- :class:`sagemaker.workflow.execution_variables.ExecutionVariables.PROCESSING_JOB_NAME`: The name of the processing job launched by the processing step. You can use these execution variables as you see fit. The following example uses the :code:`START_DATETIME` execution variable to construct a processing output path: diff --git a/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst b/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst index c8b68c0e6f..47e14b5e85 100644 --- a/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst +++ b/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst @@ -52,7 +52,7 @@ Execution Variables .. autoclass:: sagemaker.workflow.execution_variables.ExecutionVariable .. autoclass:: sagemaker.workflow.execution_variables.ExecutionVariables - :members: START_DATETIME, CURRENT_DATETIME, PIPELINE_EXECUTION_ID, PIPELINE_EXECUTION_ARN, PIPELINE_NAME, PIPELINE_ARN + :members: START_DATETIME, CURRENT_DATETIME, PIPELINE_EXECUTION_ID, PIPELINE_EXECUTION_ARN, PIPELINE_NAME, PIPELINE_ARN, TRAINING_JOB_NAME, PROCESSING_JOB_NAME Functions --------- diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 9d0c30ff27..715c329f47 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -1025,6 +1025,12 @@ def fit( * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * Both `ExperimentName` and `TrialName` will be ignored if the Estimator instance + is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. + However, the value of `TrialComponentDisplayName` is honored for display in Studio. + Returns: + None or pipeline step arguments in case the Estimator instance is built with + :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ self._prepare_for_training(job_name=job_name) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 8da6e04768..9a1d8bd431 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -173,9 +173,14 @@ def run( * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * Both `ExperimentName` and `TrialName` will be ignored if the Processor instance + is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. + However, the value of `TrialComponentDisplayName` is honored for display in Studio. 
kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). - + Returns: + None or pipeline step arguments in case the Processor instance is built with + :class:`~sagemaker.workflow.pipeline_context.PipelineSession` Raises: ValueError: if ``logs`` is True but ``wait`` is False. """ @@ -543,8 +548,14 @@ def run( * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * Both `ExperimentName` and `TrialName` will be ignored if the Processor instance + is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. + However, the value of `TrialComponentDisplayName` is honored for display in Studio. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). + Returns: + None or pipeline step arguments in case the Processor instance is built with + :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ normalized_inputs, normalized_outputs = self._normalize_args( job_name=job_name, @@ -1601,8 +1612,14 @@ def run( # type: ignore[override] * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * Both `ExperimentName` and `TrialName` will be ignored if the Processor instance + is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. + However, the value of `TrialComponentDisplayName` is honored for display in Studio. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). + Returns: + None or pipeline step arguments in case the Processor instance is built with + :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ s3_runproc_sh, inputs, job_name = self._pack_and_upload_code( code, source_dir, dependencies, git_config, job_name, inputs diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index dbe54c8d57..6df56ad154 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -186,6 +186,9 @@ def transform( * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * Both `ExperimentName` and `TrialName` will be ignored if the Transformer instance + is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`. + However, the value of `TrialComponentDisplayName` is honored for display in Studio. model_client_config (dict[str, str]): Model configuration. Dictionary contains two optional keys, 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. @@ -194,6 +197,9 @@ def transform( (default: ``True``). logs (bool): Whether to show the logs produced by the job. Only meaningful when wait is ``True`` (default: ``True``). 
+ Returns: + None or pipeline step arguments in case the Transformer instance is built with + :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ local_mode = self.sagemaker_session.local_mode if not local_mode and not is_pipeline_variable(data) and not data.startswith("s3://"): diff --git a/src/sagemaker/workflow/execution_variables.py b/src/sagemaker/workflow/execution_variables.py index 516efb784e..59ad1733ad 100644 --- a/src/sagemaker/workflow/execution_variables.py +++ b/src/sagemaker/workflow/execution_variables.py @@ -58,6 +58,8 @@ class ExecutionVariables: - ExecutionVariables.PIPELINE_ARN - ExecutionVariables.PIPELINE_EXECUTION_ID - ExecutionVariables.PIPELINE_EXECUTION_ARN + - ExecutionVariables.TRAINING_JOB_NAME + - ExecutionVariables.PROCESSING_JOB_NAME """ START_DATETIME = ExecutionVariable("StartDateTime") @@ -66,3 +68,5 @@ class ExecutionVariables: PIPELINE_ARN = ExecutionVariable("PipelineArn") PIPELINE_EXECUTION_ID = ExecutionVariable("PipelineExecutionId") PIPELINE_EXECUTION_ARN = ExecutionVariable("PipelineExecutionArn") + TRAINING_JOB_NAME = ExecutionVariable("TrainingJobName") + PROCESSING_JOB_NAME = ExecutionVariable("ProcessingJobName") diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index d73a899084..e979657bd4 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -223,6 +223,18 @@ def _get_step_name_from_str( return step_map[str_input].steps[-1].name return str_input + @staticmethod + def _trim_experiment_config(request_dict: Dict): + """For job steps, trim the experiment config to keep the trial component display name.""" + if request_dict.get("ExperimentConfig", {}).get("TrialComponentDisplayName"): + request_dict["ExperimentConfig"] = { + "TrialComponentDisplayName": request_dict["ExperimentConfig"][ + "TrialComponentDisplayName" + ] + } + else: + request_dict.pop("ExperimentConfig", None) + @attr.s class CacheConfig: @@ -432,7 +444,7 @@ def arguments(self) -> RequestType: request_dict["HyperParameters"].pop("sagemaker_job_name", None) request_dict.pop("TrainingJobName", None) - request_dict.pop("ExperimentConfig", None) + Step._trim_experiment_config(request_dict) return request_dict @@ -663,7 +675,8 @@ def arguments(self) -> RequestType: ) request_dict.pop("TransformJobName", None) - request_dict.pop("ExperimentConfig", None) + Step._trim_experiment_config(request_dict) + return request_dict @property @@ -811,7 +824,8 @@ def arguments(self) -> RequestType: request_dict = self.processor.sagemaker_session._get_process_request(**process_args) request_dict.pop("ProcessingJobName", None) - request_dict.pop("ExperimentConfig", None) + Step._trim_experiment_config(request_dict) + return request_dict @property diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py index 262d0eb558..93fd439468 100644 --- a/tests/unit/sagemaker/workflow/test_processing_step.py +++ b/tests/unit/sagemaker/workflow/test_processing_step.py @@ -18,6 +18,8 @@ import pytest import warnings +from copy import deepcopy + from sagemaker.estimator import Estimator from sagemaker.parameter import IntegerParameter from sagemaker.transformer import Transformer @@ -244,7 +246,34 @@ def network_config(): ) -def test_processing_step_with_processor(pipeline_session, processing_input): +@pytest.mark.parametrize( + "experiment_config, expected_experiment_config", + [ + ( + { + "ExperimentName": "experiment-name", + "TrialName": "trial-name", + 
"TrialComponentDisplayName": "display-name", + }, + {"TrialComponentDisplayName": "display-name"}, + ), + ( + {"TrialComponentDisplayName": "display-name"}, + {"TrialComponentDisplayName": "display-name"}, + ), + ( + { + "ExperimentName": "experiment-name", + "TrialName": "trial-name", + }, + None, + ), + (None, None), + ], +) +def test_processing_step_with_processor( + pipeline_session, processing_input, experiment_config, expected_experiment_config +): custom_step1 = CustomStep("TestStep") custom_step2 = CustomStep("SecondTestStep") processor = Processor( @@ -256,7 +285,7 @@ def test_processing_step_with_processor(pipeline_session, processing_input): ) with warnings.catch_warnings(record=True) as w: - step_args = processor.run(inputs=processing_input) + step_args = processor.run(inputs=processing_input, experiment_config=experiment_config) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) assert "Running within a PipelineSession" in str(w[-1].message) @@ -283,13 +312,21 @@ def test_processing_step_with_processor(pipeline_session, processing_input): steps=[step, custom_step1, custom_step2], sagemaker_session=pipeline_session, ) + + expected_step_arguments = deepcopy(step_args.args) + if expected_experiment_config is None: + expected_step_arguments.pop("ExperimentConfig", None) + else: + expected_step_arguments["ExperimentConfig"] = expected_experiment_config + del expected_step_arguments["ProcessingJobName"] + assert json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyProcessingStep", "Description": "ProcessingStep description", "DisplayName": "MyProcessingStep", "Type": "Processing", "DependsOn": ["TestStep", "SecondTestStep"], - "Arguments": step_args.args, + "Arguments": expected_step_arguments, "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, "PropertyFiles": [ { diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index f043048095..66a7c2fc43 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -19,6 +19,8 @@ import pytest import warnings +from copy import deepcopy + from sagemaker import Processor, Model from sagemaker.parameter import IntegerParameter from sagemaker.transformer import Transformer @@ -207,7 +209,34 @@ def hyperparameters(): return {"test-key": "test-val"} -def test_training_step_with_estimator(pipeline_session, training_input, hyperparameters): +@pytest.mark.parametrize( + "experiment_config, expected_experiment_config", + [ + ( + { + "ExperimentName": "experiment-name", + "TrialName": "trial-name", + "TrialComponentDisplayName": "display-name", + }, + {"TrialComponentDisplayName": "display-name"}, + ), + ( + {"TrialComponentDisplayName": "display-name"}, + {"TrialComponentDisplayName": "display-name"}, + ), + ( + { + "ExperimentName": "experiment-name", + "TrialName": "trial-name", + }, + None, + ), + (None, None), + ], +) +def test_training_step_with_estimator( + pipeline_session, training_input, hyperparameters, experiment_config, expected_experiment_config +): custom_step1 = CustomStep("TestStep") custom_step2 = CustomStep("SecondTestStep") enable_network_isolation = ParameterBoolean(name="enable_network_isolation") @@ -226,7 +255,9 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar with warnings.catch_warnings(record=True) as w: # TODO: remove job_name once we merge # https://github.com/aws/sagemaker-python-sdk/pull/3158/files - step_args = 
estimator.fit(inputs=training_input, job_name="TestJob") + step_args = estimator.fit( + inputs=training_input, job_name="TestJob", experiment_config=experiment_config + ) assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) assert "Running within a PipelineSession" in str(w[-1].message) @@ -247,17 +278,28 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar parameters=[enable_network_isolation, encrypt_container_traffic], sagemaker_session=pipeline_session, ) - step_args.args["EnableInterContainerTrafficEncryption"] = { + + expected_step_arguments = deepcopy(step_args.args) + + expected_step_arguments["EnableInterContainerTrafficEncryption"] = { "Get": "Parameters.encrypt_container_traffic" } - step_args.args["EnableNetworkIsolation"] = {"Get": "Parameters.encrypt_container_traffic"} + expected_step_arguments["EnableNetworkIsolation"] = { + "Get": "Parameters.enable_network_isolation" + } + if expected_experiment_config is None: + expected_step_arguments.pop("ExperimentConfig", None) + else: + expected_step_arguments["ExperimentConfig"] = expected_experiment_config + del expected_step_arguments["TrainingJobName"] + assert json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyTrainingStep", "Description": "TrainingStep description", "DisplayName": "MyTrainingStep", "Type": "Training", "DependsOn": ["TestStep", "SecondTestStep"], - "Arguments": step_args.args, + "Arguments": expected_step_arguments, } assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"} adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list diff --git a/tests/unit/sagemaker/workflow/test_transform_step.py b/tests/unit/sagemaker/workflow/test_transform_step.py index 3d0e25a2ee..3052a910de 100644 --- a/tests/unit/sagemaker/workflow/test_transform_step.py +++ b/tests/unit/sagemaker/workflow/test_transform_step.py @@ -18,6 +18,8 @@ import pytest import warnings +from copy import deepcopy + from sagemaker import Model, Processor from sagemaker.estimator import Estimator from sagemaker.parameter import IntegerParameter @@ -153,27 +155,108 @@ def test_transform_step_with_transformer(model_name, data, output_path, pipeline parameters=[model_name, data], sagemaker_session=pipeline_session, ) - step_args = step_args.args - step_def = json.loads(pipeline.definition())["Steps"][0] - step_args["ModelName"] = model_name.expr if is_pipeline_variable(model_name) else model_name - step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] = ( + + expected_step_arguments = deepcopy(step_args.args) + expected_step_arguments["ModelName"] = ( + model_name.expr if is_pipeline_variable(model_name) else model_name + ) + expected_step_arguments["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"] = ( data.expr if is_pipeline_variable(data) else data ) - step_args["TransformOutput"]["S3OutputPath"] = ( + expected_step_arguments["TransformOutput"]["S3OutputPath"] = ( output_path.expr if is_pipeline_variable(output_path) else output_path ) + del expected_step_arguments["TransformJobName"] + + step_def = json.loads(pipeline.definition())["Steps"][0] + assert step_def == { + "Name": "MyTransformStep", + "Type": "Transform", + "Arguments": expected_step_arguments, + } + - del ( - step_args["ModelName"], - step_args["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"], - step_args["TransformOutput"]["S3OutputPath"], +@pytest.mark.parametrize( + "experiment_config, expected_experiment_config", + [ + ( + { + "ExperimentName": 
"experiment-name", + "TrialName": "trial-name", + "TrialComponentDisplayName": "display-name", + }, + {"TrialComponentDisplayName": "display-name"}, + ), + ( + {"TrialComponentDisplayName": "display-name"}, + {"TrialComponentDisplayName": "display-name"}, + ), + ( + { + "ExperimentName": "experiment-name", + "TrialName": "trial-name", + }, + None, + ), + (None, None), + ], +) +def test_transform_step_with_transformer_experiment_config( + experiment_config, expected_experiment_config, pipeline_session +): + transformer = Transformer( + model_name="my_model", + instance_type="ml.m5.xlarge", + instance_count=1, + output_path="s3://my-bucket/my-output-path", + sagemaker_session=pipeline_session, ) - del ( - step_def["Arguments"]["ModelName"], - step_def["Arguments"]["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"], - step_def["Arguments"]["TransformOutput"]["S3OutputPath"], + transform_inputs = TransformInput(data="s3://my-bucket/my-data") + + with warnings.catch_warnings(record=True) as w: + step_args = transformer.transform( + data=transform_inputs.data, + data_type=transform_inputs.data_type, + content_type=transform_inputs.content_type, + compression_type=transform_inputs.compression_type, + split_type=transform_inputs.split_type, + input_filter=transform_inputs.input_filter, + output_filter=transform_inputs.output_filter, + join_source=transform_inputs.join_source, + model_client_config=transform_inputs.model_client_config, + experiment_config=experiment_config, + ) + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Running within a PipelineSession" in str(w[-1].message) + + with warnings.catch_warnings(record=True) as w: + step = TransformStep( + name="MyTransformStep", + step_args=step_args, + ) + assert len(w) == 0 + + pipeline = Pipeline( + name="MyPipeline", + steps=[step], + sagemaker_session=pipeline_session, ) - assert step_def == {"Name": "MyTransformStep", "Type": "Transform", "Arguments": step_args} + + expected_step_arguments = deepcopy(step_args.args) + if expected_experiment_config is None: + expected_step_arguments.pop("ExperimentConfig", None) + else: + expected_step_arguments["ExperimentConfig"] = expected_experiment_config + del expected_step_arguments["TransformJobName"] + + step_def = json.loads(pipeline.definition())["Steps"][0] + assert step_def == { + "Name": "MyTransformStep", + "Type": "Transform", + "Arguments": expected_step_arguments, + } + adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list assert adjacency_list == {"MyTransformStep": []} From 0f59733f814bdcd58c9a2d8e4b7e21b57b8c18ca Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Mon, 15 Aug 2022 12:25:52 -0700 Subject: [PATCH 164/526] change: Add Pipeline annotation in model base class and tensorflow estimator (#3190) Model annotate update change: Add PipelineVariable annotation to composite argument of training go with model base and tf Co-authored-by: Dewen Qi --- src/sagemaker/amazon/amazon_estimator.py | 9 +- src/sagemaker/debugger/debugger.py | 45 ++++--- src/sagemaker/debugger/profiler_config.py | 9 +- src/sagemaker/drift_check_baselines.py | 22 +-- src/sagemaker/estimator.py | 17 ++- .../huggingface/training_compiler/config.py | 6 +- src/sagemaker/inputs.py | 25 ++-- src/sagemaker/metadata_properties.py | 12 +- src/sagemaker/model.py | 126 +++++++++--------- src/sagemaker/model_metrics.py | 32 +++-- .../serverless/serverless_inference_config.py | 4 +- src/sagemaker/session.py | 4 +- 
src/sagemaker/tensorflow/estimator.py | 16 ++- 13 files changed, 194 insertions(+), 133 deletions(-) diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 09e77d612a..eaf4644da6 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -16,6 +16,7 @@ import json import logging import tempfile +from typing import Union from six.moves.urllib.parse import urlparse @@ -27,6 +28,7 @@ from sagemaker.estimator import EstimatorBase, _TrainingJob from sagemaker.inputs import FileSystemInput, TrainingInput from sagemaker.utils import sagemaker_timestamp +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline logger = logging.getLogger(__name__) @@ -304,7 +306,12 @@ class RecordSet(object): """Placeholder docstring""" def __init__( - self, s3_data, num_records, feature_dim, s3_data_type="ManifestFile", channel="train" + self, + s3_data: Union[str, PipelineVariable], + num_records: int, + feature_dim: int, + s3_data_type: Union[str, PipelineVariable] = "ManifestFile", + channel: Union[str, PipelineVariable] = "train", ): """A collection of Amazon :class:~`Record` objects serialized and stored in S3. diff --git a/src/sagemaker/debugger/debugger.py b/src/sagemaker/debugger/debugger.py index d2d53547f1..23f7b651a3 100644 --- a/src/sagemaker/debugger/debugger.py +++ b/src/sagemaker/debugger/debugger.py @@ -24,12 +24,15 @@ from abc import ABC +from typing import Union, Optional, List, Dict + import attr import smdebug_rulesconfig as rule_configs from sagemaker import image_uris from sagemaker.utils import build_dict +from sagemaker.workflow.entities import PipelineVariable framework_name = "debugger" DEBUGGER_FLAG = "USE_SMDEBUG" @@ -311,17 +314,17 @@ def sagemaker( @classmethod def custom( cls, - name, - image_uri, - instance_type, - volume_size_in_gb, - source=None, - rule_to_invoke=None, - container_local_output_path=None, - s3_output_path=None, - other_trials_s3_input_paths=None, - rule_parameters=None, - collections_to_save=None, + name: str, + image_uri: Union[str, PipelineVariable], + instance_type: Union[str, PipelineVariable], + volume_size_in_gb: Union[int, PipelineVariable], + source: Optional[str] = None, + rule_to_invoke: Optional[Union[str, PipelineVariable]] = None, + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + s3_output_path: Optional[Union[str, PipelineVariable]] = None, + other_trials_s3_input_paths: Optional[List[Union[str, PipelineVariable]]] = None, + rule_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + collections_to_save: Optional[List["CollectionConfig"]] = None, actions=None, ): """Initialize a ``Rule`` object for a *custom* debugging rule. @@ -610,10 +613,10 @@ class DebuggerHookConfig(object): def __init__( self, - s3_output_path=None, - container_local_output_path=None, - hook_parameters=None, - collection_configs=None, + s3_output_path: Optional[Union[str, PipelineVariable]] = None, + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + hook_parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + collection_configs: Optional[List["CollectionConfig"]] = None, ): """Initialize the DebuggerHookConfig instance. 
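A minimal sketch of the usage the relaxed ``DebuggerHookConfig``/``CollectionConfig`` annotations above describe; the bucket name and collection parameters are illustrative, and any of these arguments could instead be a pipeline variable such as a ``ParameterString``.

.. code:: python

    from sagemaker.debugger import CollectionConfig, DebuggerHookConfig

    # Plain strings shown here; Union[str, PipelineVariable] means a pipeline
    # parameter or step property is accepted in the same positions.
    hook_config = DebuggerHookConfig(
        s3_output_path="s3://my-bucket/debugger-output",  # illustrative bucket
        collection_configs=[
            CollectionConfig(name="losses"),
            CollectionConfig(name="weights", parameters={"save_interval": "50"}),
        ],
    )
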
@@ -679,7 +682,11 @@ def _to_request_dict(self): class TensorBoardOutputConfig(object): """Create a tensor ouput configuration object for debugging visualizations on TensorBoard.""" - def __init__(self, s3_output_path, container_local_output_path=None): + def __init__( + self, + s3_output_path: Union[str, PipelineVariable], + container_local_output_path: Optional[Union[str, PipelineVariable]] = None, + ): """Initialize the TensorBoardOutputConfig instance. Args: @@ -708,7 +715,11 @@ def _to_request_dict(self): class CollectionConfig(object): """Creates tensor collections for SageMaker Debugger.""" - def __init__(self, name, parameters=None): + def __init__( + self, + name: Union[str, PipelineVariable], + parameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + ): """Constructor for collection configuration. Args: diff --git a/src/sagemaker/debugger/profiler_config.py b/src/sagemaker/debugger/profiler_config.py index 371d161bbe..807ba91e79 100644 --- a/src/sagemaker/debugger/profiler_config.py +++ b/src/sagemaker/debugger/profiler_config.py @@ -13,7 +13,10 @@ """Configuration for collecting system and framework metrics in SageMaker training jobs.""" from __future__ import absolute_import +from typing import Optional, Union + from sagemaker.debugger.framework_profile import FrameworkProfile +from sagemaker.workflow.entities import PipelineVariable class ProfilerConfig(object): @@ -26,9 +29,9 @@ class ProfilerConfig(object): def __init__( self, - s3_output_path=None, - system_monitor_interval_millis=None, - framework_profile_params=None, + s3_output_path: Optional[Union[str, PipelineVariable]] = None, + system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None, + framework_profile_params: Optional[FrameworkProfile] = None, ): """Initialize a ``ProfilerConfig`` instance. diff --git a/src/sagemaker/drift_check_baselines.py b/src/sagemaker/drift_check_baselines.py index 24aa4787d0..9c3b8dbd57 100644 --- a/src/sagemaker/drift_check_baselines.py +++ b/src/sagemaker/drift_check_baselines.py @@ -13,21 +13,25 @@ """This file contains code related to drift check baselines""" from __future__ import absolute_import +from typing import Optional + +from sagemaker.model_metrics import MetricsSource, FileSource + class DriftCheckBaselines(object): """Accepts drift check baselines parameters for conversion to request dict.""" def __init__( self, - model_statistics=None, - model_constraints=None, - model_data_statistics=None, - model_data_constraints=None, - bias_config_file=None, - bias_pre_training_constraints=None, - bias_post_training_constraints=None, - explainability_constraints=None, - explainability_config_file=None, + model_statistics: Optional[MetricsSource] = None, + model_constraints: Optional[MetricsSource] = None, + model_data_statistics: Optional[MetricsSource] = None, + model_data_constraints: Optional[MetricsSource] = None, + bias_config_file: Optional[FileSource] = None, + bias_pre_training_constraints: Optional[MetricsSource] = None, + bias_post_training_constraints: Optional[MetricsSource] = None, + explainability_constraints: Optional[MetricsSource] = None, + explainability_config_file: Optional[FileSource] = None, ): """Initialize a ``DriftCheckBaselines`` instance and turn parameters into dict. 
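A short sketch of how the annotated ``DriftCheckBaselines`` is typically assembled and later handed to ``Model.register()``; the S3 URIs are placeholders and could equally be pipeline variables (for example a ``Join`` over a ``ParameterString``).

.. code:: python

    from sagemaker.drift_check_baselines import DriftCheckBaselines
    from sagemaker.model_metrics import MetricsSource

    drift_check_baselines = DriftCheckBaselines(
        model_statistics=MetricsSource(
            content_type="application/json",
            s3_uri="s3://my-bucket/baselines/model-statistics.json",   # placeholder
        ),
        model_constraints=MetricsSource(
            content_type="application/json",
            s3_uri="s3://my-bucket/baselines/model-constraints.json",  # placeholder
        ),
    )
    # Typically passed on via Model.register(..., drift_check_baselines=drift_check_baselines)
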
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 715c329f47..35c726cd0a 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -50,6 +50,7 @@ validate_source_code_input_against_pipeline_variables, ) from sagemaker.inputs import TrainingInput, FileSystemInput +from sagemaker.instance_group import InstanceGroup from sagemaker.job import _Job from sagemaker.jumpstart.utils import ( add_jumpstart_tags, @@ -149,7 +150,7 @@ def __init__( code_location: Optional[str] = None, entry_point: Optional[Union[str, PipelineVariable]] = None, dependencies: Optional[List[Union[str]]] = None, - instance_groups: Optional[Dict[str, Union[str, int]]] = None, + instance_groups: Optional[List[InstanceGroup]] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -1586,6 +1587,8 @@ def _get_instance_type(self): for instance_group in self.instance_groups: instance_type = instance_group.instance_type + if is_pipeline_variable(instance_type): + continue match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type) if match: @@ -2185,7 +2188,7 @@ def __init__( code_location: Optional[str] = None, entry_point: Optional[Union[str, PipelineVariable]] = None, dependencies: Optional[List[str]] = None, - instance_groups: Optional[Dict[str, Union[str, int]]] = None, + instance_groups: Optional[List[InstanceGroup]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -2880,7 +2883,15 @@ def _validate_and_set_debugger_configs(self): # Disable debugger if checkpointing is enabled by the customer if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config: if self._framework_name in {"mxnet", "pytorch", "tensorflow"}: - if self.instance_count > 1 or ( + if is_pipeline_variable(self.instance_count): + logger.warning( + "SMDebug does not currently support distributed training jobs " + "with checkpointing enabled. Therefore, to allow parameterized " + "instance_count and allow to change it to any values in execution time, " + "the debugger_hook_config is disabled." + ) + self.debugger_hook_config = False + elif self.instance_count > 1 or ( hasattr(self, "distribution") and self.distribution is not None # pylint: disable=no-member ): diff --git a/src/sagemaker/huggingface/training_compiler/config.py b/src/sagemaker/huggingface/training_compiler/config.py index 07a3bcf9b7..b19fb2be2b 100644 --- a/src/sagemaker/huggingface/training_compiler/config.py +++ b/src/sagemaker/huggingface/training_compiler/config.py @@ -13,8 +13,10 @@ """Configuration for the SageMaker Training Compiler.""" from __future__ import absolute_import import logging +from typing import Union from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger(__name__) @@ -26,8 +28,8 @@ class TrainingCompilerConfig(BaseConfig): def __init__( self, - enabled=True, - debug=False, + enabled: Union[bool, PipelineVariable] = True, + debug: Union[bool, PipelineVariable] = False, ): """This class initializes a ``TrainingCompilerConfig`` instance. 
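As a hedged illustration of the ``TrainingCompilerConfig`` usage these annotations cover: the config is passed to the Hugging Face estimator's ``compiler_config`` argument, and ``enabled``/``debug`` may now also be pipeline variables (for example a ``ParameterBoolean``). The role ARN and the framework version pair below are placeholders; use a combination supported by the Training Compiler containers.

.. code:: python

    from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig

    hf_estimator = HuggingFace(
        entry_point="train.py",
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
        instance_type="ml.p3.2xlarge",
        instance_count=1,
        transformers_version="4.17.0",   # illustrative version pair
        pytorch_version="1.10.2",
        py_version="py38",
        compiler_config=TrainingCompilerConfig(enabled=True, debug=False),
    )
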
diff --git a/src/sagemaker/inputs.py b/src/sagemaker/inputs.py index 3481c138bd..0fca307a97 100644 --- a/src/sagemaker/inputs.py +++ b/src/sagemaker/inputs.py @@ -13,8 +13,11 @@ """Amazon SageMaker channel configurations for S3 data sources and file system data sources""" from __future__ import absolute_import, print_function +from typing import Union, Optional, List import attr +from sagemaker.workflow.entities import PipelineVariable + FILE_SYSTEM_TYPES = ["FSxLustre", "EFS"] FILE_SYSTEM_ACCESS_MODES = ["ro", "rw"] @@ -29,17 +32,17 @@ class TrainingInput(object): def __init__( self, - s3_data, - distribution=None, - compression=None, - content_type=None, - record_wrapping=None, - s3_data_type="S3Prefix", - instance_groups=None, - input_mode=None, - attribute_names=None, - target_attribute_name=None, - shuffle_config=None, + s3_data: Union[str, PipelineVariable], + distribution: Optional[Union[str, PipelineVariable]] = None, + compression: Optional[Union[str, PipelineVariable]] = None, + content_type: Optional[Union[str, PipelineVariable]] = None, + record_wrapping: Optional[Union[str, PipelineVariable]] = None, + s3_data_type: Union[str, PipelineVariable] = "S3Prefix", + instance_groups: Optional[List[Union[str, PipelineVariable]]] = None, + input_mode: Optional[Union[str, PipelineVariable]] = None, + attribute_names: Optional[List[Union[str, PipelineVariable]]] = None, + target_attribute_name: Optional[Union[str, PipelineVariable]] = None, + shuffle_config: Optional["ShuffleConfig"] = None, ): r"""Create a definition for input data used by an SageMaker training job. diff --git a/src/sagemaker/metadata_properties.py b/src/sagemaker/metadata_properties.py index 4bc77ed0ee..b25aff9168 100644 --- a/src/sagemaker/metadata_properties.py +++ b/src/sagemaker/metadata_properties.py @@ -13,16 +13,20 @@ """This file contains code related to metadata properties.""" from __future__ import absolute_import +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + class MetadataProperties(object): """Accepts metadata properties parameters for conversion to request dict.""" def __init__( self, - commit_id=None, - repository=None, - generated_by=None, - project_id=None, + commit_id: Optional[Union[str, PipelineVariable]] = None, + repository: Optional[Union[str, PipelineVariable]] = None, + generated_by: Optional[Union[str, PipelineVariable]] = None, + project_id: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``MetadataProperties`` instance and turn parameters into dict. 
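A minimal sketch of a ``TrainingInput`` built from a pipeline parameter, which is the kind of call the new ``Union[str, PipelineVariable]`` annotations on ``TrainingInput`` document; the parameter name and S3 prefix are illustrative.

.. code:: python

    from sagemaker.inputs import TrainingInput
    from sagemaker.workflow.parameters import ParameterString

    train_uri = ParameterString(
        name="TrainDataUri", default_value="s3://my-bucket/train"  # placeholder prefix
    )

    # s3_data, content_type, distribution, etc. accept plain strings or pipeline variables.
    train_input = TrainingInput(
        s3_data=train_uri,
        content_type="text/csv",
        distribution="FullyReplicated",
    )
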
diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index a2c6da4bb7..8772fa724f 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -18,7 +18,7 @@ import logging import os import copy -from typing import List, Dict +from typing import List, Dict, Optional, Union import sagemaker from sagemaker import ( @@ -29,7 +29,11 @@ utils, git_utils, ) +from sagemaker.session import Session +from sagemaker.model_metrics import ModelMetrics from sagemaker.deprecations import removed_kwargs +from sagemaker.drift_check_baselines import DriftCheckBaselines +from sagemaker.metadata_properties import MetadataProperties from sagemaker.predictor import PredictorBase from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.transformer import Transformer @@ -37,10 +41,12 @@ from sagemaker.utils import ( unique_name_from_base, update_container_with_inference_params, + to_string, ) from sagemaker.async_inference import AsyncInferenceConfig from sagemaker.predictor_async import AsyncPredictor from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession LOGGER = logging.getLogger("sagemaker") @@ -82,23 +88,23 @@ class Model(ModelBase): def __init__( self, - image_uri, - model_data=None, - role=None, - predictor_cls=None, - env=None, - name=None, - vpc_config=None, - sagemaker_session=None, - enable_network_isolation=False, - model_kms_key=None, - image_config=None, - source_dir=None, - code_location=None, - entry_point=None, - container_log_level=logging.INFO, - dependencies=None, - git_config=None, + image_uri: Union[str, PipelineVariable], + model_data: Optional[Union[str, PipelineVariable]] = None, + role: Optional[str] = None, + predictor_cls: Optional[callable] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + name: Optional[str] = None, + vpc_config: Optional[Dict[str, List[Union[str, PipelineVariable]]]] = None, + sagemaker_session: Optional[Session] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + model_kms_key: Optional[str] = None, + image_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + source_dir: Optional[str] = None, + code_location: Optional[str] = None, + entry_point: Optional[str] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + dependencies: Optional[List[str]] = None, + git_config: Optional[Dict[str, str]] = None, ): """Initialize an SageMaker ``Model``. 
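A minimal sketch of a ``Model`` built with the annotated constructor above inside a pipeline context; the image URI, role, and model artifact path are placeholders:

    from sagemaker.model import Model
    from sagemaker.workflow.parameters import ParameterString
    from sagemaker.workflow.pipeline_context import PipelineSession

    model_data = ParameterString(name="ModelData", default_value="s3://my-bucket/model.tar.gz")  # placeholder

    model = Model(
        image_uri="inference-image-uri",  # placeholder
        model_data=model_data,  # a pipeline variable, accepted by the new annotation
        role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder
        sagemaker_session=PipelineSession(),
    )
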
@@ -298,28 +304,28 @@ def __init__( @runnable_by_pipeline def register( self, - content_types, - response_types, - inference_instances=None, - transform_instances=None, - model_package_name=None, - model_package_group_name=None, - image_uri=None, - model_metrics=None, - metadata_properties=None, - marketplace_cert=False, - approval_status=None, - description=None, - drift_check_baselines=None, - customer_metadata_properties=None, - validation_specification=None, - domain=None, - task=None, - sample_payload_url=None, - framework=None, - framework_version=None, - nearest_model_name=None, - data_input_configuration=None, + content_types: List[Union[str, PipelineVariable]], + response_types: List[Union[str, PipelineVariable]], + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, + model_package_name: Optional[Union[str, PipelineVariable]] = None, + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + model_metrics: Optional[ModelMetrics] = None, + metadata_properties: Optional[MetadataProperties] = None, + marketplace_cert: bool = False, + approval_status: Optional[Union[str, PipelineVariable]] = None, + description: Optional[str] = None, + drift_check_baselines: Optional[DriftCheckBaselines] = None, + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + validation_specification: Optional[Union[str, PipelineVariable]] = None, + domain: Optional[Union[str, PipelineVariable]] = None, + task: Optional[Union[str, PipelineVariable]] = None, + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, + framework: Optional[Union[str, PipelineVariable]] = None, + framework_version: Optional[Union[str, PipelineVariable]] = None, + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, ): """Creates a model package for creating SageMaker models or listing on Marketplace. @@ -349,11 +355,11 @@ def register( metadata properties (default: None). domain (str): Domain values can be "COMPUTER_VISION", "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). - sample_payload_url (str): The S3 path where the sample payload is stored - (default: None). task (str): Task values which are supported by Inference Recommender are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). + sample_payload_url (str): The S3 path where the sample payload is stored + (default: None). framework (str): Machine learning framework of the model package container image (default: None). 
framework_version (str): Framework version of the Model Package Container Image @@ -421,10 +427,10 @@ def register( @runnable_by_pipeline def create( self, - instance_type: str = None, - accelerator_type: str = None, - serverless_inference_config: ServerlessInferenceConfig = None, - tags: List[Dict[str, str]] = None, + instance_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + serverless_inference_config: Optional[ServerlessInferenceConfig] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, ): """Create a SageMaker Model Entity @@ -608,7 +614,7 @@ def _script_mode_env_vars(self): return { SCRIPT_PARAM_NAME.upper(): script_name or str(), DIR_PARAM_NAME.upper(): dir_name or str(), - CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level), + CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level), SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name, } @@ -1286,19 +1292,19 @@ class FrameworkModel(Model): def __init__( self, - model_data, - image_uri, - role, - entry_point, - source_dir=None, - predictor_cls=None, - env=None, - name=None, - container_log_level=logging.INFO, - code_location=None, - sagemaker_session=None, - dependencies=None, - git_config=None, + model_data: Union[str, PipelineVariable], + image_uri: Union[str, PipelineVariable], + role: str, + entry_point: str, + source_dir: Optional[str] = None, + predictor_cls: Optional[callable] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + name: Optional[str] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + code_location: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + dependencies: Optional[List[str]] = None, + git_config: Optional[Dict[str, str]] = None, **kwargs, ): """Initialize a ``FrameworkModel``. diff --git a/src/sagemaker/model_metrics.py b/src/sagemaker/model_metrics.py index acce4e13c9..83a43d3f18 100644 --- a/src/sagemaker/model_metrics.py +++ b/src/sagemaker/model_metrics.py @@ -13,20 +13,24 @@ """This file contains code related to model metrics, including metric source and file source.""" from __future__ import absolute_import +from typing import Optional, Union + +from sagemaker.workflow.entities import PipelineVariable + class ModelMetrics(object): """Accepts model metrics parameters for conversion to request dict.""" def __init__( self, - model_statistics=None, - model_constraints=None, - model_data_statistics=None, - model_data_constraints=None, - bias=None, - explainability=None, - bias_pre_training=None, - bias_post_training=None, + model_statistics: Optional["MetricsSource"] = None, + model_constraints: Optional["MetricsSource"] = None, + model_data_statistics: Optional["MetricsSource"] = None, + model_data_constraints: Optional["MetricsSource"] = None, + bias: Optional["MetricsSource"] = None, + explainability: Optional["MetricsSource"] = None, + bias_pre_training: Optional["MetricsSource"] = None, + bias_post_training: Optional["MetricsSource"] = None, ): """Initialize a ``ModelMetrics`` instance and turn parameters into dict. @@ -99,9 +103,9 @@ class MetricsSource(object): def __init__( self, - content_type, - s3_uri, - content_digest=None, + content_type: Union[str, PipelineVariable], + s3_uri: Union[str, PipelineVariable], + content_digest: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``MetricsSource`` instance and turn parameters into dict. 
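A minimal sketch pairing the annotated ``register()`` signature above with a ``ModelMetrics`` built from a ``MetricsSource``; the S3 URI is a placeholder:

    from sagemaker.model_metrics import ModelMetrics, MetricsSource

    model_metrics = ModelMetrics(
        model_statistics=MetricsSource(
            content_type="application/json",
            s3_uri="s3://my-bucket/evaluation/statistics.json",  # placeholder
        ),
    )

The resulting object can then be passed as ``model_metrics=model_metrics`` to ``Model.register()``, alongside the ``content_types``/``response_types`` arguments that may now also be pipeline variables.
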
@@ -127,9 +131,9 @@ class FileSource(object): def __init__( self, - s3_uri, - content_digest=None, - content_type=None, + s3_uri: Union[str, PipelineVariable], + content_digest: Optional[Union[str, PipelineVariable]] = None, + content_type: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``FileSource`` instance and turn parameters into dict. diff --git a/src/sagemaker/serverless/serverless_inference_config.py b/src/sagemaker/serverless/serverless_inference_config.py index 39950f4f84..adc98a319a 100644 --- a/src/sagemaker/serverless/serverless_inference_config.py +++ b/src/sagemaker/serverless/serverless_inference_config.py @@ -27,8 +27,8 @@ class ServerlessInferenceConfig(object): def __init__( self, - memory_size_in_mb=2048, - max_concurrency=5, + memory_size_in_mb: int = 2048, + max_concurrency: int = 5, ): """Initialize a ServerlessInferenceConfig object for serverless inference configuration. diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 145bf41cbe..221434d7db 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2633,7 +2633,9 @@ def _create_model_request( request["VpcConfig"] = vpc_config if enable_network_isolation: - request["EnableNetworkIsolation"] = True + # enable_network_isolation may be a pipeline variable which is + # parsed in execution time + request["EnableNetworkIsolation"] = enable_network_isolation return request diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index 4db647e140..9533f475a1 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging +from typing import Optional, Union, Dict from packaging import version @@ -27,6 +28,7 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow import is_pipeline_variable from sagemaker.tensorflow.training_compiler.config import TrainingCompilerConfig +from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -41,12 +43,12 @@ class TensorFlow(Framework): def __init__( self, - py_version=None, - framework_version=None, - model_dir=None, - image_uri=None, - distribution=None, - compiler_config=None, + py_version: Optional[str] = None, + framework_version: Optional[str] = None, + model_dir: Optional[Union[str, PipelineVariable]] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + distribution: Optional[Dict[str, str]] = None, + compiler_config: Optional[TrainingCompilerConfig] = None, **kwargs, ): """Initialize a ``TensorFlow`` estimator. 
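A minimal sketch of the ``TensorFlow`` constructor with the pipeline-variable-friendly annotations above; the entry point, role, and version pairing are placeholders chosen for illustration:

    from sagemaker.tensorflow import TensorFlow
    from sagemaker.workflow.parameters import ParameterString

    model_dir = ParameterString(name="ModelDir", default_value="s3://my-bucket/model-dir/")  # placeholder

    estimator = TensorFlow(
        entry_point="train.py",  # placeholder script
        framework_version="2.9",
        py_version="py39",
        instance_type="ml.m5.xlarge",
        instance_count=1,
        role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder
        model_dir=model_dir,  # a pipeline variable, accepted by the new annotation
    )

The ``_create_model_request`` change shown earlier in this patch follows the same idea: ``enable_network_isolation`` is now forwarded as-is so a pipeline variable can be resolved by the service at execution time.
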
@@ -251,6 +253,8 @@ def _only_legacy_mode_supported(self): def _only_python_3_supported(self): """Placeholder docstring""" + if not self.framework_version: + return False return version.Version(self.framework_version) > self._HIGHEST_PYTHON_2_VERSION @classmethod From 6bddb3db7dfcaf6b864b60dbb05293f14a269c5d Mon Sep 17 00:00:00 2001 From: Yeldos Balgabekov Date: Tue, 16 Aug 2022 01:54:20 +0200 Subject: [PATCH 165/526] feature: added _AnalysisConfigGenerator for clarify (#3271) * feature: extracted analysis config generation for explainability * feature: extracted analysis config generation for bias pre_training * feature: extracted analysis config generation for bias post_training * feature: extracted analysis config generation for bias * feature: simplified job_name creation * feature: extended analysis config generator methods with common logic * feature: refactored _AnalysisConfigGenerator methods * feature: added _last_analysis_config in SageMakerClarifyProcessor * added data types in _AnalysisConfigGenerator methods * applied style formatting to fix build issues Co-authored-by: Yeldos Balgabekov --- src/sagemaker/clarify.py | 202 +++++++++++++++++++++++++------------ tests/unit/test_clarify.py | 141 +++++++++++++++++++++++++- 2 files changed, 277 insertions(+), 66 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 6590d30514..3bc2071330 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -25,6 +25,8 @@ import tempfile from abc import ABC, abstractmethod +from typing import List, Union + from sagemaker import image_uris, s3, utils from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor @@ -922,6 +924,7 @@ def __init__( version (str): Clarify version to use. """ # noqa E501 # pylint: disable=c0301 container_uri = image_uris.retrieve("clarify", sagemaker_session.boto_region_name, version) + self._last_analysis_config = None self.job_name_prefix = job_name_prefix super(SageMakerClarifyProcessor, self).__init__( role, @@ -983,10 +986,10 @@ def _run( the Trial Component will be unassociated. * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. """ - analysis_config["methods"]["report"] = { - "name": "report", - "title": "Analysis Report", - } + # for debugging: to access locally, i.e. without a need to look for it in an S3 bucket + self._last_analysis_config = analysis_config + logger.info("Analysis Config: %s", analysis_config) + with tempfile.TemporaryDirectory() as tmpdirname: analysis_config_file = os.path.join(tmpdirname, "analysis_config.json") with open(analysis_config_file, "w") as f: @@ -1083,14 +1086,13 @@ def run_pre_training_bias( the Trial Component will be unassociated. * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. 
""" # noqa E501 # pylint: disable=c0301 - analysis_config = data_config.get_config() - analysis_config.update(data_bias_config.get_config()) - analysis_config["methods"] = {"pre_training_bias": {"methods": methods}} - if job_name is None: - if self.job_name_prefix: - job_name = utils.name_from_base(self.job_name_prefix) - else: - job_name = utils.name_from_base("Clarify-Pretraining-Bias") + analysis_config = _AnalysisConfigGenerator.bias_pre_training( + data_config, data_bias_config, methods + ) + # when name is either not provided (is None) or an empty string ("") + job_name = job_name or utils.name_from_base( + self.job_name_prefix or "Clarify-Pretraining-Bias" + ) return self._run( data_config, analysis_config, @@ -1165,21 +1167,13 @@ def run_post_training_bias( the Trial Component will be unassociated. * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. """ # noqa E501 # pylint: disable=c0301 - analysis_config = data_config.get_config() - analysis_config.update(data_bias_config.get_config()) - ( - probability_threshold, - predictor_config, - ) = model_predicted_label_config.get_predictor_config() - predictor_config.update(model_config.get_predictor_config()) - analysis_config["methods"] = {"post_training_bias": {"methods": methods}} - analysis_config["predictor"] = predictor_config - _set(probability_threshold, "probability_threshold", analysis_config) - if job_name is None: - if self.job_name_prefix: - job_name = utils.name_from_base(self.job_name_prefix) - else: - job_name = utils.name_from_base("Clarify-Posttraining-Bias") + analysis_config = _AnalysisConfigGenerator.bias_post_training( + data_config, data_bias_config, model_predicted_label_config, methods, model_config + ) + # when name is either not provided (is None) or an empty string ("") + job_name = job_name or utils.name_from_base( + self.job_name_prefix or "Clarify-Posttraining-Bias" + ) return self._run( data_config, analysis_config, @@ -1264,28 +1258,16 @@ def run_bias( the Trial Component will be unassociated. * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. """ # noqa E501 # pylint: disable=c0301 - analysis_config = data_config.get_config() - analysis_config.update(bias_config.get_config()) - analysis_config["predictor"] = model_config.get_predictor_config() - if model_predicted_label_config: - ( - probability_threshold, - predictor_config, - ) = model_predicted_label_config.get_predictor_config() - if predictor_config: - analysis_config["predictor"].update(predictor_config) - if probability_threshold is not None: - analysis_config["probability_threshold"] = probability_threshold - - analysis_config["methods"] = { - "pre_training_bias": {"methods": pre_training_methods}, - "post_training_bias": {"methods": post_training_methods}, - } - if job_name is None: - if self.job_name_prefix: - job_name = utils.name_from_base(self.job_name_prefix) - else: - job_name = utils.name_from_base("Clarify-Bias") + analysis_config = _AnalysisConfigGenerator.bias( + data_config, + bias_config, + model_config, + model_predicted_label_config, + pre_training_methods, + post_training_methods, + ) + # when name is either not provided (is None) or an empty string ("") + job_name = job_name or utils.name_from_base(self.job_name_prefix or "Clarify-Bias") return self._run( data_config, analysis_config, @@ -1370,6 +1352,36 @@ def run_explainability( the Trial Component will be unassociated. * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. 
""" # noqa E501 # pylint: disable=c0301 + analysis_config = _AnalysisConfigGenerator.explainability( + data_config, model_config, model_scores, explainability_config + ) + # when name is either not provided (is None) or an empty string ("") + job_name = job_name or utils.name_from_base( + self.job_name_prefix or "Clarify-Explainability" + ) + return self._run( + data_config, + analysis_config, + wait, + logs, + job_name, + kms_key, + experiment_config, + ) + + +class _AnalysisConfigGenerator: + """Creates analysis_config objects for different type of runs.""" + + @classmethod + def explainability( + cls, + data_config: DataConfig, + model_config: ModelConfig, + model_scores: ModelPredictedLabelConfig, + explainability_config: ExplainabilityConfig, + ): + """Generates a config for Explainability""" analysis_config = data_config.get_config() predictor_config = model_config.get_predictor_config() if isinstance(model_scores, ModelPredictedLabelConfig): @@ -1406,20 +1418,84 @@ def run_explainability( explainability_methods = explainability_config.get_explainability_config() analysis_config["methods"] = explainability_methods analysis_config["predictor"] = predictor_config - if job_name is None: - if self.job_name_prefix: - job_name = utils.name_from_base(self.job_name_prefix) - else: - job_name = utils.name_from_base("Clarify-Explainability") - return self._run( - data_config, - analysis_config, - wait, - logs, - job_name, - kms_key, - experiment_config, - ) + return cls._common(analysis_config) + + @classmethod + def bias_pre_training( + cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]] + ): + """Generates a config for Bias Pre Training""" + analysis_config = { + **data_config.get_config(), + **bias_config.get_config(), + "methods": {"pre_training_bias": {"methods": methods}}, + } + return cls._common(analysis_config) + + @classmethod + def bias_post_training( + cls, + data_config: DataConfig, + bias_config: BiasConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + methods: Union[str, List[str]], + model_config: ModelConfig, + ): + """Generates a config for Bias Post Training""" + analysis_config = { + **data_config.get_config(), + **bias_config.get_config(), + "predictor": {**model_config.get_predictor_config()}, + "methods": {"post_training_bias": {"methods": methods}}, + } + if model_predicted_label_config: + ( + probability_threshold, + predictor_config, + ) = model_predicted_label_config.get_predictor_config() + if predictor_config: + analysis_config["predictor"].update(predictor_config) + _set(probability_threshold, "probability_threshold", analysis_config) + return cls._common(analysis_config) + + @classmethod + def bias( + cls, + data_config: DataConfig, + bias_config: BiasConfig, + model_config: ModelConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + pre_training_methods: Union[str, List[str]] = "all", + post_training_methods: Union[str, List[str]] = "all", + ): + """Generates a config for Bias""" + analysis_config = { + **data_config.get_config(), + **bias_config.get_config(), + "predictor": model_config.get_predictor_config(), + "methods": { + "pre_training_bias": {"methods": pre_training_methods}, + "post_training_bias": {"methods": post_training_methods}, + }, + } + if model_predicted_label_config: + ( + probability_threshold, + predictor_config, + ) = model_predicted_label_config.get_predictor_config() + if predictor_config: + analysis_config["predictor"].update(predictor_config) + _set(probability_threshold, 
"probability_threshold", analysis_config) + return cls._common(analysis_config) + + @staticmethod + def _common(analysis_config): + """Extends analysis config with common values""" + analysis_config["methods"]["report"] = { + "name": "report", + "title": "Analysis Report", + } + return analysis_config def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key): diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index fa437573f0..7375657944 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -29,6 +29,7 @@ SHAPConfig, TextConfig, ImageConfig, + _AnalysisConfigGenerator, ) JOB_NAME_PREFIX = "my-prefix" @@ -764,7 +765,10 @@ def test_pre_training_bias( "label_values_or_threshold": [1], "facet": [{"name_or_index": "F1"}], "group_variable": "F2", - "methods": {"pre_training_bias": {"methods": "all"}}, + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "pre_training_bias": {"methods": "all"}, + }, } mock_method.assert_called_with( data_config, @@ -827,7 +831,10 @@ def test_post_training_bias( "joinsource_name_or_index": "F4", "facet": [{"name_or_index": "F1"}], "group_variable": "F2", - "methods": {"post_training_bias": {"methods": "all"}}, + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "post_training_bias": {"methods": "all"}, + }, "predictor": { "model_name": "xgboost-model", "instance_type": "ml.c5.xlarge", @@ -985,7 +992,10 @@ def _run_test_explain( "grid_resolution": 20, "top_k_features": 10, } - expected_analysis_config["methods"] = expected_explanation_configs + expected_analysis_config["methods"] = { + "report": {"name": "report", "title": "Analysis Report"}, + **expected_explanation_configs, + } mock_method.assert_called_with( data_config, expected_analysis_config, @@ -1277,3 +1287,128 @@ def test_shap_with_image_config( expected_predictor_config, expected_image_config=expected_image_config, ) + + +def test_analysis_config_generator_for_explainability(data_config, model_config): + model_scores = ModelPredictedLabelConfig( + probability="pr", + label_headers=["success"], + ) + actual = _AnalysisConfigGenerator.explainability( + data_config, + model_config, + model_scores, + SHAPConfig(), + ) + expected = { + "dataset_type": "text/csv", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "shap": {"save_local_shap_values": True, "use_logit": False}, + }, + "predictor": { + "initial_instance_count": 1, + "instance_type": "ml.c5.xlarge", + "label_headers": ["success"], + "model_name": "xgboost-model", + "probability": "pr", + }, + } + assert actual == expected + + +def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config): + actual = _AnalysisConfigGenerator.bias_pre_training( + data_config, data_bias_config, methods="all" + ) + expected = { + "dataset_type": "text/csv", + "facet": [{"name_or_index": "F1"}], + "group_variable": "F2", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "label_values_or_threshold": [1], + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "pre_training_bias": {"methods": "all"}, + }, + } + assert actual == expected + + +def test_analysis_config_generator_for_bias_post_training( + data_config, data_bias_config, model_config +): + model_predicted_label_config = ModelPredictedLabelConfig( + 
probability="pr", + label_headers=["success"], + ) + actual = _AnalysisConfigGenerator.bias_post_training( + data_config, + data_bias_config, + model_predicted_label_config, + methods="all", + model_config=model_config, + ) + expected = { + "dataset_type": "text/csv", + "facet": [{"name_or_index": "F1"}], + "group_variable": "F2", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "label_values_or_threshold": [1], + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "post_training_bias": {"methods": "all"}, + }, + "predictor": { + "initial_instance_count": 1, + "instance_type": "ml.c5.xlarge", + "label_headers": ["success"], + "model_name": "xgboost-model", + "probability": "pr", + }, + } + assert actual == expected + + +def test_analysis_config_generator_for_bias(data_config, data_bias_config, model_config): + model_predicted_label_config = ModelPredictedLabelConfig( + probability="pr", + label_headers=["success"], + ) + actual = _AnalysisConfigGenerator.bias( + data_config, + data_bias_config, + model_config, + model_predicted_label_config, + pre_training_methods="all", + post_training_methods="all", + ) + expected = { + "dataset_type": "text/csv", + "facet": [{"name_or_index": "F1"}], + "group_variable": "F2", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "label_values_or_threshold": [1], + "methods": { + "report": {"name": "report", "title": "Analysis Report"}, + "post_training_bias": {"methods": "all"}, + "pre_training_bias": {"methods": "all"}, + }, + "predictor": { + "initial_instance_count": 1, + "instance_type": "ml.c5.xlarge", + "label_headers": ["success"], + "model_name": "xgboost-model", + "probability": "pr", + }, + } + assert actual == expected From d1431f5430f1d41ff717f30562da7c1a4761c202 Mon Sep 17 00:00:00 2001 From: keerthanvasist Date: Mon, 15 Aug 2022 17:00:21 -0700 Subject: [PATCH 166/526] documentation: Correct documentation error (#3283) --- src/sagemaker/clarify.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 3bc2071330..7f00a78268 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -65,7 +65,6 @@ def __init__( label (str): Target attribute of the model required by bias metrics. Specified as column name or index for CSV dataset or as JSONPath for JSONLines. *Required parameter* except for when the input dataset does not contain the label. - Cannot be used at the same time as ``predicted_label``. features (str): JSONPath for locating the feature columns for bias metrics if the dataset format is JSONLines. dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV, @@ -105,7 +104,7 @@ def __init__( predicted_label (str or int): Predicted label of the target attribute of the model required for running bias analysis. Specified as column name or index for CSV data. Clarify uses the predicted labels directly instead of making model inference API - calls. Cannot be used at the same time as ``label``. + calls. excluded_columns (list[int] or list[str]): A list of names or indices of the columns which are to be excluded from making model inference API calls. 
From 7ceb79bacec9301202c1a8ef1c12682ab5bad4f3 Mon Sep 17 00:00:00 2001 From: Qingzi-Lan <83724147+Qingzi-Lan@users.noreply.github.com> Date: Tue, 16 Aug 2022 11:24:33 -0700 Subject: [PATCH 167/526] feature: Add PT 1.12 support (#3231) * feature: Add PT 1.12 support * update cgk region --- src/sagemaker/fw_utils.py | 2 + src/sagemaker/image_uri_config/pytorch.json | 78 +++++++++++++++++++-- tests/unit/test_fw_utils.py | 3 + 3 files changed, 79 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index ef99454a45..5b7b5da656 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -101,6 +101,8 @@ "1.10.2", "1.11", "1.11.0", + "1.12", + "1.12.0", ], } diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 5a2fb90202..a88f7f1c50 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -66,8 +66,9 @@ "1.7": "1.7.1", "1.8": "1.8.1", "1.9": "1.9.1", - "1.10": "1.10.0", - "1.11": "1.11.0" + "1.10": "1.10.2", + "1.11": "1.11.0", + "1.12": "1.12.0" }, "versions": { "0.4.0": { @@ -616,6 +617,40 @@ "us-west-2": "763104351884" }, "repository": "pytorch-inference" + }, + "1.12.0": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-inference" } } }, @@ -636,8 +671,9 @@ "1.7": "1.7.1", "1.8": "1.8.1", "1.9": "1.9.1", - "1.10": "1.10.0", - "1.11": "1.11.0" + "1.10": "1.10.2", + "1.11": "1.11.0", + "1.12": "1.12.0" }, "versions": { "0.4.0": { @@ -1187,6 +1223,40 @@ "us-west-2": "763104351884" }, "repository": "pytorch-training" + }, + "1.12.0": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" } } } diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 
018255cf47..fbbc27be37 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -827,6 +827,8 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "pytorch", "1.10", "py38", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.11.0", "py38", smdataparallel_enabled), ("ml.p3.16xlarge", "pytorch", "1.11", "py38", smdataparallel_enabled), + ("ml.p3.16xlarge", "pytorch", "1.12.0", "py38", smdataparallel_enabled), + ("ml.p3.16xlarge", "pytorch", "1.12", "py38", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi), @@ -842,6 +844,7 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "1.10.2", "py38", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "1.11.0", "py38", smdataparallel_enabled_custom_mpi), + ("ml.p3.16xlarge", "pytorch", "1.12.0", "py38", smdataparallel_enabled_custom_mpi), ] for instance_type, framework_name, framework_version, py_version, distribution in good_args: fw_utils._validate_smdataparallel_args( From cb014e612a1ee2674501d2b5c39ce0de0b4101c3 Mon Sep 17 00:00:00 2001 From: Namrata Madan Date: Thu, 7 Jul 2022 10:08:34 -0700 Subject: [PATCH 168/526] feature: Pipelines local mode setup Co-authored-by: Namrata Madan --- src/sagemaker/local/__init__.py | 1 + src/sagemaker/local/entities.py | 205 +++++++++++++++++ src/sagemaker/local/exceptions.py | 24 ++ src/sagemaker/local/local_session.py | 118 +++++++++- src/sagemaker/local/pipeline.py | 134 +++++++++++ src/sagemaker/local/utils.py | 34 +++ src/sagemaker/workflow/execution_variables.py | 6 + src/sagemaker/workflow/pipeline.py | 24 +- src/sagemaker/workflow/pipeline_context.py | 3 +- src/sagemaker/workflow/steps.py | 2 + tests/unit/sagemaker/local/__init__.py | 0 .../{ => sagemaker/local}/test_local_data.py | 0 .../local}/test_local_entities.py | 0 .../sagemaker/local/test_local_pipeline.py | 213 ++++++++++++++++++ .../local}/test_local_session.py | 0 .../{ => sagemaker/local}/test_local_utils.py | 0 .../unit/sagemaker/workflow/test_pipeline.py | 30 +++ .../sagemaker/workflow/test_pipeline_graph.py | 20 +- 18 files changed, 805 insertions(+), 9 deletions(-) create mode 100644 src/sagemaker/local/exceptions.py create mode 100644 src/sagemaker/local/pipeline.py create mode 100644 tests/unit/sagemaker/local/__init__.py rename tests/unit/{ => sagemaker/local}/test_local_data.py (100%) rename tests/unit/{ => sagemaker/local}/test_local_entities.py (100%) create mode 100644 tests/unit/sagemaker/local/test_local_pipeline.py rename tests/unit/{ => sagemaker/local}/test_local_session.py (100%) rename tests/unit/{ => sagemaker/local}/test_local_utils.py (100%) diff --git a/src/sagemaker/local/__init__.py b/src/sagemaker/local/__init__.py index 1cd1b222e3..7bb8cf224c 100644 --- a/src/sagemaker/local/__init__.py +++ b/src/sagemaker/local/__init__.py @@ -18,4 +18,5 @@ LocalSagemakerClient, LocalSagemakerRuntimeClient, LocalSession, + LocalPipelineSession, ) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 3b8da1b46b..f7974e31f8 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -13,17 +13,23 @@ """Placeholder docstring""" from __future__ import absolute_import +import 
enum import datetime import json import logging import os import tempfile import time +from uuid import uuid4 +from copy import deepcopy +from botocore.exceptions import ClientError import sagemaker.local.data + from sagemaker.local.image import _SageMakerContainer from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host from sagemaker.utils import DeferredError, get_config_value +from sagemaker.local.exceptions import StepExecutionException logger = logging.getLogger(__name__) @@ -618,6 +624,205 @@ def describe(self): return response +class _LocalPipeline(object): + """Placeholder docstring""" + + _executions = {} + + def __init__( + self, + pipeline, + pipeline_description=None, + local_session=None, + ): + from sagemaker.local import LocalSession + + self.local_session = local_session or LocalSession() + self.pipeline = pipeline + self.pipeline_description = pipeline_description + now_time = datetime.datetime.now() + self.creation_time = now_time + self.last_modified_time = now_time + + def describe(self): + """Placeholder docstring""" + response = { + "PipelineArn": self.pipeline.name, + "PipelineDefinition": self.pipeline.definition(), + "PipelineDescription": self.pipeline_description, + "PipelineName": self.pipeline.name, + "PipelineStatus": "Active", + "RoleArn": "", + "CreationTime": self.creation_time, + "LastModifiedTime": self.last_modified_time, + } + return response + + def start(self, **kwargs): + """Placeholder docstring""" + from sagemaker.local.pipeline import LocalPipelineExecutor + + execution_id = str(uuid4()) + execution = _LocalPipelineExecution(execution_id, self.pipeline, **kwargs) + + self._executions[execution_id] = execution + return LocalPipelineExecutor(execution, self.local_session).execute() + + +class _LocalPipelineExecution(object): + """Placeholder docstring""" + + def __init__( + self, + execution_id, + pipeline, + PipelineParameters=None, + PipelineExecutionDescription=None, + PipelineExecutionDisplayName=None, + ): + self.pipeline = pipeline + self.pipeline_execution_name = execution_id + self.pipeline_execution_description = PipelineExecutionDescription + self.pipeline_execution_display_name = PipelineExecutionDisplayName + self.status = _LocalExecutionStatus.EXECUTING.value + self.failure_reason = None + self.creation_time = datetime.datetime.now() + self.step_execution = self._initialize_step_execution() + self.pipeline_parameters = self._initialize_and_validate_parameters(PipelineParameters) + + def describe(self): + """Placeholder docstring""" + response = { + "CreationTime": self.creation_time, + "LastModifiedTime": self.creation_time, + "FailureReason": self.failure_reason, + "PipelineArn": self.pipeline.name, + "PipelineExecutionArn": self.pipeline_execution_name, + "PipelineExecutionDescription": self.pipeline_execution_description, + "PipelineExecutionDisplayName": self.pipeline_execution_display_name, + "PipelineExecutionStatus": self.status, + } + filtered_response = {k: v for k, v in response.items() if v is not None} + return filtered_response + + def list_steps(self): + """Placeholder docstring""" + # TODO + + def update_execution_failure(self, step_name, failure_message): + """Mark execution as failed.""" + self.status = _LocalExecutionStatus.FAILED.value + self.failure_reason = f"Step {step_name} failed with message: {failure_message}" + logger.error("Pipeline execution failed because step %s failed.", step_name) + + def update_step_failure(self, step_name, failure_message): + """Mark step_name 
as failed.""" + self.step_execution.get(step_name).update_step_failure(failure_message) + + def mark_step_starting(self, step_name): + """Update step's status to EXECUTING""" + self.step_execution.get(step_name).status = _LocalExecutionStatus.EXECUTING + + def _initialize_step_execution(self): + """Initialize step_execution dict.""" + from sagemaker.workflow.steps import StepTypeEnum + + supported_steps_types = ( + StepTypeEnum.TRAINING, + StepTypeEnum.PROCESSING, + StepTypeEnum.TRANSFORM, + StepTypeEnum.CONDITION, + StepTypeEnum.FAIL, + ) + + step_execution = {} + for step in self.pipeline.steps: + if step.step_type not in supported_steps_types: + error_msg = self._construct_validation_exception_message( + "Step type {} is not supported in local mode.".format(step.step_type.value) + ) + raise ClientError(error_msg, "start_pipeline_execution") + step_execution[step.name] = _LocalPipelineStepExecution(step.name, step.step_type) + return step_execution + + def _initialize_and_validate_parameters(self, overridden_parameters): + """Initialize and validate pipeline parameters.""" + merged_parameters = {} + default_parameters = {parameter.name: parameter for parameter in self.pipeline.parameters} + if overridden_parameters is not None: + for (param_name, param_value) in overridden_parameters.items(): + if param_name not in default_parameters: + error_msg = self._construct_validation_exception_message( + "Unknown parameter '{}'".format(param_name) + ) + raise ClientError(error_msg, "start_pipeline_execution") + parameter_type = default_parameters[param_name].parameter_type + if type(param_value) != parameter_type.python_type: # pylint: disable=C0123 + error_msg = self._construct_validation_exception_message( + "Unexpected type for parameter '{}'. Expected {} but found " + "{}.".format(param_name, parameter_type.python_type, type(param_value)) + ) + raise ClientError(error_msg, "start_pipeline_execution") + merged_parameters[param_name] = param_value + for param_name, default_parameter in default_parameters.items(): + if param_name not in merged_parameters: + if default_parameter.default_value is None: + error_msg = self._construct_validation_exception_message( + "Parameter '{}' is undefined.".format(param_name) + ) + raise ClientError(error_msg, "start_pipeline_execution") + merged_parameters[param_name] = default_parameter.default_value + return merged_parameters + + @staticmethod + def _construct_validation_exception_message(exception_msg): + """Construct error response for botocore.exceptions.ClientError""" + return {"Error": {"Code": "ValidationException", "Message": exception_msg}} + + +class _LocalPipelineStepExecution(object): + """Placeholder docstring""" + + def __init__( + self, + step_name, + step_type, + last_modified_time=None, + status=None, + properties=None, + failure_reason=None, + ): + self.step_name = step_name + self.step_type = step_type + self.status = status or _LocalExecutionStatus.STARTING + self.failure_reason = failure_reason + self.properties = properties or {} + self.creation_time = datetime.datetime.now() + self.last_modified_time = last_modified_time or self.creation_time + + def update_step_properties(self, properties): + """Update pipeline step execution output properties.""" + logger.info("Successfully completed step %s.", self.step_name) + self.properties = deepcopy(properties) + self.status = _LocalExecutionStatus.SUCCEEDED.value + + def update_step_failure(self, failure_message): + """Update pipeline step execution failure status and message.""" + 
logger.error(failure_message) + self.failure_reason = failure_message + self.status = _LocalExecutionStatus.FAILED.value + raise StepExecutionException(self.step_name, failure_message) + + +class _LocalExecutionStatus(enum.Enum): + """Placeholder docstring""" + + STARTING = "Starting" + EXECUTING = "Executing" + SUCCEEDED = "Succeeded" + FAILED = "Failed" + + def _wait_for_serving_container(serving_port): """Placeholder docstring.""" i = 0 diff --git a/src/sagemaker/local/exceptions.py b/src/sagemaker/local/exceptions.py new file mode 100644 index 0000000000..025c50a4e3 --- /dev/null +++ b/src/sagemaker/local/exceptions.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Custom Exceptions.""" +from __future__ import absolute_import + + +class StepExecutionException(Exception): + """Exception indicating a failure while execution pipeline steps.""" + + def __init__(self, step_name, message): + """Placeholder docstring""" + super(StepExecutionException, self).__init__(message) + self.message = message + self.step_name = step_name diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index f5f86f59d5..c9f6c910bd 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -16,6 +16,7 @@ import logging import os import platform +from datetime import datetime import boto3 from botocore.exceptions import ClientError @@ -29,6 +30,7 @@ _LocalProcessingJob, _LocalTrainingJob, _LocalTransformJob, + _LocalPipeline, ) from sagemaker.session import Session from sagemaker.utils import get_config_value, _module_import_error @@ -36,7 +38,7 @@ logger = logging.getLogger(__name__) -class LocalSagemakerClient(object): +class LocalSagemakerClient(object): # pylint: disable=too-many-public-methods """A SageMakerClient that implements the API calls locally. Used for doing local training and hosting local endpoints. It still needs access to @@ -56,6 +58,7 @@ class LocalSagemakerClient(object): _models = {} _endpoint_configs = {} _endpoints = {} + _pipelines = {} def __init__(self, sagemaker_session=None): """Initialize a LocalSageMakerClient. @@ -402,6 +405,107 @@ def delete_model(self, ModelName): if ModelName in LocalSagemakerClient._models: del LocalSagemakerClient._models[ModelName] + def create_pipeline( + self, pipeline, pipeline_description, **kwargs # pylint: disable=unused-argument + ): + """Create a local pipeline. + + Args: + pipeline (Pipeline): Pipeline object + pipeline_description (str): Description of the pipeline + + Returns: + Pipeline metadata (PipelineArn) + + """ + local_pipeline = _LocalPipeline( + pipeline=pipeline, + pipeline_description=pipeline_description, + local_session=self.sagemaker_session, + ) + LocalSagemakerClient._pipelines[pipeline.name] = local_pipeline + return {"PipelineArn": pipeline.name} + + def update_pipeline( + self, pipeline, pipeline_description, **kwargs # pylint: disable=unused-argument + ): + """Update a local pipeline. 
+ + Args: + pipeline (Pipeline): Pipeline object + pipeline_description (str): Description of the pipeline + + Returns: + Pipeline metadata (PipelineArn) + + """ + if pipeline.name not in LocalSagemakerClient._pipelines: + error_response = { + "Error": { + "Code": "ResourceNotFound", + "Message": "Pipeline {} does not exist".format(pipeline.name), + } + } + raise ClientError(error_response, "update_pipeline") + LocalSagemakerClient._pipelines[pipeline.name].pipeline_description = pipeline_description + LocalSagemakerClient._pipelines[pipeline.name].last_modified_time = datetime.now() + return {"PipelineArn": pipeline.name} + + def describe_pipeline(self, PipelineName): + """Describe the pipeline. + + Args: + PipelineName (str): + + Returns: + Pipeline metadata (PipelineArn, PipelineDefinition, LastModifiedTime, etc) + + """ + if PipelineName not in LocalSagemakerClient._pipelines: + error_response = { + "Error": { + "Code": "ResourceNotFound", + "Message": "Pipeline {} does not exist".format(PipelineName), + } + } + raise ClientError(error_response, "describe_pipeline") + return LocalSagemakerClient._pipelines[PipelineName].describe() + + def delete_pipeline(self, PipelineName): + """Delete the local pipeline. + + Args: + PipelineName (str): + + Returns: + Pipeline metadata (PipelineArn) + + """ + if PipelineName in LocalSagemakerClient._pipelines: + del LocalSagemakerClient._pipelines[PipelineName] + return {"PipelineArn": PipelineName} + + def start_pipeline_execution(self, PipelineName, **kwargs): + """Start the pipeline. + + Args: + PipelineName (str): + + Returns: _LocalPipelineExecution object + + """ + if "ParallelismConfiguration" in kwargs: + logger.warning("Parallelism configuration is not supported in local mode.") + if PipelineName not in LocalSagemakerClient._pipelines: + error_response = { + "Error": { + "Code": "ResourceNotFound", + "Message": "Pipeline {} does not exist".format(PipelineName), + } + } + raise ClientError(error_response, "start_pipeline_execution") + return LocalSagemakerClient._pipelines[PipelineName].start(**kwargs) + class LocalSagemakerRuntimeClient(object): """A SageMaker Runtime client that calls a local endpoint only.""" @@ -535,7 +639,6 @@ def _initialize( else: self.boto_session = boto_session - # self.boto_session = boto_session or boto3.Session() self._region_name = self.boto_session.region_name if self._region_name is None: @@ -610,3 +713,14 @@ def __init__(self, fileUri, content_type=None): if content_type is not None: self.config["ContentType"] = content_type + + +class LocalPipelineSession(LocalSession): + """Class representing a local session for SageMaker Pipelines executions.""" + + def __init__(self, boto_session=None, s3_endpoint_url=None, disable_local_code=False): + super().__init__( + boto_session=boto_session, + s3_endpoint_url=s3_endpoint_url, + disable_local_code=disable_local_code, + ) diff --git a/src/sagemaker/local/pipeline.py b/src/sagemaker/local/pipeline.py new file mode 100644 index 0000000000..6590ae7e58 --- /dev/null +++ b/src/sagemaker/local/pipeline.py @@ -0,0 +1,134 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Local Pipeline Executor""" +from __future__ import absolute_import + +import logging +from copy import deepcopy +from datetime import datetime + +from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.parameters import Parameter +from sagemaker.workflow.functions import Join, JsonGet +from sagemaker.workflow.properties import Properties +from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables +from sagemaker.workflow.pipeline import PipelineGraph +from sagemaker.local.exceptions import StepExecutionException +from sagemaker.local.utils import get_using_dot_notation + +logger = logging.getLogger(__name__) + +PRIMITIVES = (str, int, bool, float) + + +class LocalPipelineExecutor(object): + """An executor that executes SageMaker Pipelines locally.""" + + def __init__(self, execution, sagemaker_session): + """Initialize StepExecutor. + + Args: + sagemaker_session (sagemaker.session.Session): a session to use to read configurations + from, and use its boto client. + """ + self.sagemaker_session = sagemaker_session + self.execution = execution + self.pipeline_dag = PipelineGraph.from_pipeline(self.execution.pipeline) + + def execute(self): + """Execute a local pipeline.""" + try: + for step in self.pipeline_dag: + self.execute_step(step) + except StepExecutionException as e: + self.execution.update_execution_failure(e.step_name, e.message) + return self.execution + + def execute_step(self, step): + """Execute a local pipeline step.""" + self.execution.mark_step_starting(step.name) + step_arguments = self.evaluate_step_arguments(step) # noqa: F841; pylint: disable=W0612 + # TODO execute step + + def evaluate_step_arguments(self, step): + """Parses and evaluate step arguments.""" + + def _parse_arguments(obj): + if isinstance(obj, dict): + obj_copy = deepcopy(obj) + for k, v in obj.items(): + if isinstance(v, dict): + obj_copy[k] = _parse_arguments(v) + elif isinstance(v, list): + list_copy = [] + for item in v: + list_copy.append(_parse_arguments(item)) + obj_copy[k] = list_copy + elif isinstance(v, PipelineVariable): + obj_copy[k] = self.evaluate_pipeline_variable(v, step.name) + return obj_copy + return obj + + return _parse_arguments(step.arguments) + + def evaluate_pipeline_variable(self, pipeline_variable, step_name): + """Evaluate pipeline variable runtime value.""" + value = None + if isinstance(pipeline_variable, PRIMITIVES): + value = pipeline_variable + elif isinstance(pipeline_variable, Parameter): + value = self.execution.pipeline_parameters.get(pipeline_variable.name) + elif isinstance(pipeline_variable, Join): + evaluated = [ + self.evaluate_pipeline_variable(v, step_name) for v in pipeline_variable.values + ] + value = pipeline_variable.on.join(evaluated) + elif isinstance(pipeline_variable, Properties): + value = self._evaluate_property_reference(pipeline_variable, step_name) + elif isinstance(pipeline_variable, ExecutionVariable): + value = self._evaluate_execution_variable(pipeline_variable) + elif isinstance(pipeline_variable, JsonGet): + # TODO + raise NotImplementedError + else: + self.execution.update_step_failure( + step_name, f"Unrecognized pipeline variable {pipeline_variable.expr}." 
+ ) + + if value is None: + self.execution.update_step_failure(step_name, f"{pipeline_variable.expr} is undefined.") + return value + + def _evaluate_property_reference(self, pipeline_variable, step_name): + """Evaluate property reference runtime value.""" + try: + referenced_step_name = pipeline_variable.step_name + step_properties = self.execution.step_execution.get(referenced_step_name).properties + return get_using_dot_notation(step_properties, pipeline_variable.path) + except (KeyError, IndexError): + self.execution.update_step_failure(step_name, f"{pipeline_variable.expr} is undefined.") + + def _evaluate_execution_variable(self, pipeline_variable): + """Evaluate pipeline execution variable runtime value.""" + if pipeline_variable in (ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_ARN): + return self.execution.pipeline.name + if pipeline_variable in ( + ExecutionVariables.PIPELINE_EXECUTION_ID, + ExecutionVariables.PIPELINE_EXECUTION_ARN, + ): + return self.execution.pipeline_execution_name + if pipeline_variable == ExecutionVariables.START_DATETIME: + return self.execution.creation_time + if pipeline_variable == ExecutionVariables.CURRENT_DATETIME: + return datetime.now() + return None diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index 1b3ea155e1..3031a407fe 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -17,6 +17,7 @@ import shutil import subprocess import json +import re from distutils.dir_util import copy_tree from six.moves.urllib.parse import urlparse @@ -152,3 +153,36 @@ def get_docker_host(): if parsed_url.hostname and parsed_url.scheme == "tcp": return parsed_url.hostname return "localhost" + + +def get_using_dot_notation(dictionary, keys): + """Extract `keys` from dictionary where keys is a string in dot notation. 
+ + Args: + dictionary (Dict) + keys (str) + + Returns: + Nested object within dictionary as defined by "keys" + + Raises: + KeyError or IndexError if the provided key does not exist in input dictionary + """ + if keys is None: + return dictionary + split_keys = keys.split(".", 1) + key = split_keys[0] + rest = None + if len(split_keys) > 1: + rest = split_keys[1] + list_accessor = re.search(r"(\w+)\[(\d+)]", key) + if list_accessor: + key = list_accessor.group(1) + list_index = int(list_accessor.group(2)) + return get_using_dot_notation(dictionary[key][list_index], rest) + dict_accessor = re.search(r"(\w+)\[['\"](\S+)['\"]]", key) + if dict_accessor: + key = dict_accessor.group(1) + inner_key = dict_accessor.group(2) + return get_using_dot_notation(dictionary[key][inner_key], rest) + return get_using_dot_notation(dictionary[key], rest) diff --git a/src/sagemaker/workflow/execution_variables.py b/src/sagemaker/workflow/execution_variables.py index 59ad1733ad..53a970f628 100644 --- a/src/sagemaker/workflow/execution_variables.py +++ b/src/sagemaker/workflow/execution_variables.py @@ -31,6 +31,12 @@ def __init__(self, name: str): """ self.name = name + def __eq__(self, other): + """Override default equals method""" + if not isinstance(other, ExecutionVariable): + return NotImplemented + return self.name == other.name + def to_string(self) -> PipelineVariable: """Prompt the pipeline to convert the pipeline variable to String in runtime diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 275d952f81..95bff2957b 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -15,6 +15,7 @@ import json +import logging from copy import deepcopy from typing import Any, Dict, List, Sequence, Union, Optional @@ -25,6 +26,7 @@ from sagemaker import s3 from sagemaker._studio import _append_project_tags from sagemaker.session import Session +from sagemaker.local import LocalSession from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep from sagemaker.workflow.entities import ( @@ -41,6 +43,9 @@ from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.utilities import list_to_request +from sagemaker.workflow.pipeline_context import LocalPipelineSession + +logger = logging.getLogger(__name__) _DEFAULT_EXPERIMENT_CFG = PipelineExperimentConfig( ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_EXECUTION_ID @@ -123,6 +128,10 @@ def create( Returns: A response dict from the service. """ + if self.sagemaker_session.local_mode: + if parallelism_config: + logger.warning("Pipeline parallelism config is not supported in the local mode.") + return self.sagemaker_session.sagemaker_client.create_pipeline(self, description) tags = _append_project_tags(tags) kwargs = self._create_args(role_arn, description, parallelism_config) update_args( @@ -154,7 +163,9 @@ def _create_args( # If pipeline definition is large, upload to S3 bucket and # provide PipelineDefinitionS3Location to request instead. - if len(pipeline_definition.encode("utf-8")) < 1024 * 100: + if len(pipeline_definition.encode("utf-8")) < 1024 * 100 or isinstance( + self.sagemaker_session, (LocalSession, LocalPipelineSession) + ): kwargs["PipelineDefinition"] = pipeline_definition else: desired_s3_uri = s3.s3_path_join( @@ -203,8 +214,14 @@ def update( Returns: A response dict from the service. 
""" + if self.sagemaker_session.local_mode: + if parallelism_config: + logger.warning("Pipeline parallelism config is not supported in the local mode.") + return self.sagemaker_session.sagemaker_client.update_pipeline(self, description) + self._step_map = dict() _generate_step_map(self.steps, self._step_map) + kwargs = self._create_args(role_arn, description, parallelism_config) return self.sagemaker_session.sagemaker_client.update_pipeline(**kwargs) @@ -289,11 +306,14 @@ def start( kwargs = dict(PipelineName=self.name) update_args( kwargs, - PipelineParameters=format_start_parameters(parameters), PipelineExecutionDescription=execution_description, PipelineExecutionDisplayName=execution_display_name, ParallelismConfiguration=parallelism_config, ) + if self.sagemaker_session.local_mode: + update_args(kwargs, PipelineParameters=parameters) + return self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs) + update_args(kwargs, PipelineParameters=format_start_parameters(parameters)) response = self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs) return _PipelineExecution( arn=response["PipelineExecutionArn"], diff --git a/src/sagemaker/workflow/pipeline_context.py b/src/sagemaker/workflow/pipeline_context.py index 341e123be0..3a9feb65e4 100644 --- a/src/sagemaker/workflow/pipeline_context.py +++ b/src/sagemaker/workflow/pipeline_context.py @@ -19,6 +19,7 @@ from typing import Dict, Optional from sagemaker.session import Session, SessionSettings +from sagemaker.local import LocalPipelineSession class _StepArguments: @@ -170,7 +171,7 @@ def runnable_by_pipeline(run_func): @wraps(run_func) def wrapper(*args, **kwargs): self_instance = args[0] - if isinstance(self_instance.sagemaker_session, PipelineSession): + if isinstance(self_instance.sagemaker_session, (PipelineSession, LocalPipelineSession)): run_func_params = inspect.signature(run_func).parameters arg_list = list(args) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index e979657bd4..e04d51c946 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -219,6 +219,8 @@ def _get_step_name_from_str( """Convert a Step or StepCollection name input to step name.""" from sagemaker.workflow.step_collections import StepCollection + if str_input not in step_map: + raise ValueError(f"Step {str_input} is undefined.") if isinstance(step_map[str_input], StepCollection): return step_map[str_input].steps[-1].name return str_input diff --git a/tests/unit/sagemaker/local/__init__.py b/tests/unit/sagemaker/local/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/test_local_data.py b/tests/unit/sagemaker/local/test_local_data.py similarity index 100% rename from tests/unit/test_local_data.py rename to tests/unit/sagemaker/local/test_local_data.py diff --git a/tests/unit/test_local_entities.py b/tests/unit/sagemaker/local/test_local_entities.py similarity index 100% rename from tests/unit/test_local_entities.py rename to tests/unit/sagemaker/local/test_local_entities.py diff --git a/tests/unit/sagemaker/local/test_local_pipeline.py b/tests/unit/sagemaker/local/test_local_pipeline.py new file mode 100644 index 0000000000..3e4f6dae3f --- /dev/null +++ b/tests/unit/sagemaker/local/test_local_pipeline.py @@ -0,0 +1,213 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +from botocore.exceptions import ClientError + +from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.steps import CreateModelStep +from sagemaker.model import Model +from sagemaker.workflow.execution_variables import ExecutionVariables +from sagemaker.workflow.functions import Join +from sagemaker.local.local_session import LocalSession +from sagemaker.local.pipeline import LocalPipelineExecutor, StepExecutionException +from sagemaker.local.entities import _LocalPipelineExecution +from tests.unit.sagemaker.workflow.helpers import CustomStep + +STRING_PARAMETER = ParameterString("MyStr", "DefaultParameter") +INPUT_STEP = CustomStep(name="InputStep") + + +@pytest.fixture() +def local_sagemaker_session(): + return LocalSession() + + +@pytest.fixture +def role_arn(): + return "arn:role" + + +def test_evaluate_parameter(local_sagemaker_session): + step = CustomStep(name="MyStep", input_data=STRING_PARAMETER) + pipeline = Pipeline( + name="MyPipeline", + parameters=[STRING_PARAMETER], + steps=[step], + sagemaker_session=local_sagemaker_session, + ) + + execution = _LocalPipelineExecution("my-execution", pipeline, {"MyStr": "test_string"}) + evaluated_args = LocalPipelineExecutor( + execution, local_sagemaker_session + ).evaluate_step_arguments(step) + assert evaluated_args["input_data"] == "test_string" + + +def test_evaluate_parameter_undefined(local_sagemaker_session, role_arn): + parameter = ParameterString("MyStr") + step = CustomStep(name="MyStep", input_data=parameter) + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[step], + sagemaker_session=local_sagemaker_session, + ) + with pytest.raises(ClientError) as error: + pipeline.create(role_arn, "test pipeline") + pipeline.start() + assert f"Parameter '{parameter.name}' is undefined." in str(error.value) + + +def test_evaluate_parameter_unknown(local_sagemaker_session, role_arn): + parameter = ParameterString("MyStr") + step = CustomStep(name="MyStep", input_data=parameter) + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[step], + sagemaker_session=local_sagemaker_session, + ) + with pytest.raises(ClientError) as error: + pipeline.create(role_arn, "test pipeline") + pipeline.start({"MyStr": "test-test", "UnknownParameterFoo": "foo"}) + assert "Unknown parameter 'UnknownParameterFoo'" in str(error.value) + + +def test_evaluate_parameter_wrong_type(local_sagemaker_session, role_arn): + parameter = ParameterString("MyStr") + step = CustomStep(name="MyStep", input_data=parameter) + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[step], + sagemaker_session=local_sagemaker_session, + ) + with pytest.raises(ClientError) as error: + pipeline.create(role_arn, "test pipeline") + pipeline.start({"MyStr": True}) + assert ( + f"Unexpected type for parameter '{parameter.name}'. Expected " + f"{parameter.parameter_type.python_type} but found {type(True)}." 
in str(error.value) + ) + + +@pytest.mark.parametrize( + "property_reference, expected", + [ + (INPUT_STEP.properties.TrainingJobArn, "my-training-arn"), + (INPUT_STEP.properties.ExperimentConfig.TrialName, "trial-bar"), + (INPUT_STEP.properties.FinalMetricDataList[0].Value, 24), + (INPUT_STEP.properties.FailureReason, "Error: bad input!"), + (INPUT_STEP.properties.AlgorithmSpecification.AlgorithmName, "fooAlgorithm"), + (INPUT_STEP.properties.AlgorithmSpecification.MetricDefinitions[0].Name, "mse"), + (INPUT_STEP.properties.Environment["max-depth"], "10"), + ], +) +def test_evaluate_property_reference(local_sagemaker_session, property_reference, expected): + step = CustomStep(name="MyStep", input_data=property_reference) + pipeline = Pipeline( + name="MyPipeline", + parameters=[STRING_PARAMETER], + steps=[INPUT_STEP, step], + sagemaker_session=local_sagemaker_session, + ) + + execution = _LocalPipelineExecution("my-execution", pipeline) + execution.step_execution[INPUT_STEP.name].properties = { + "AlgorithmSpecification": { + "AlgorithmName": "fooAlgorithm", + "MetricDefinitions": [{"Name": "mse", "Regex": ".*MeanSquaredError.*"}], + }, + "TrainingJobArn": "my-training-arn", + "FinalMetricDataList": [{"MetricName": "mse", "Timestamp": 1656281030, "Value": 24}], + "ExperimentConfig": { + "ExperimentName": "my-exp", + "TrialComponentDisplayName": "trial-component-foo", + "TrialName": "trial-bar", + }, + "Environment": {"max-depth": "10"}, + "FailureReason": "Error: bad input!", + } + evaluated_args = LocalPipelineExecutor( + execution, local_sagemaker_session + ).evaluate_step_arguments(step) + assert evaluated_args["input_data"] == expected + + +def test_evaluate_property_reference_undefined(local_sagemaker_session): + step = CustomStep(name="MyStep", input_data=INPUT_STEP.properties.FailureReason) + pipeline = Pipeline( + name="MyPipeline", + parameters=[STRING_PARAMETER], + steps=[INPUT_STEP, step], + sagemaker_session=local_sagemaker_session, + ) + + execution = _LocalPipelineExecution("my-execution", pipeline) + execution.step_execution[INPUT_STEP.name].properties = {"TrainingJobArn": "my-training-arn"} + with pytest.raises(StepExecutionException) as e: + LocalPipelineExecutor(execution, local_sagemaker_session).evaluate_step_arguments(step) + assert f"{INPUT_STEP.properties.FailureReason.expr} is undefined." 
in str(e.value) + + +@pytest.mark.parametrize( + "join_value, expected", + [ + (ExecutionVariables.PIPELINE_NAME, "blah-MyPipeline-blah"), + (STRING_PARAMETER, "blah-DefaultParameter-blah"), + (INPUT_STEP.properties.TrainingJobArn, "blah-my-training-arn-blah"), + (Join(on=".", values=["test1", "test2", "test3"]), "blah-test1.test2.test3-blah"), + ( + Join(on=".", values=["test", ExecutionVariables.PIPELINE_NAME, "test"]), + "blah-test.MyPipeline.test-blah", + ), + ], +) +def test_evaluate_join_function(local_sagemaker_session, join_value, expected): + step = CustomStep(name="TestStep", input_data=Join(on="-", values=["blah", join_value, "blah"])) + pipeline = Pipeline( + name="MyPipeline", + parameters=[STRING_PARAMETER], + steps=[INPUT_STEP, step], + sagemaker_session=local_sagemaker_session, + ) + + execution = _LocalPipelineExecution("my-execution", pipeline) + execution.step_execution["InputStep"].properties = {"TrainingJobArn": "my-training-arn"} + evaluated_args = LocalPipelineExecutor( + execution, local_sagemaker_session + ).evaluate_step_arguments(step) + assert evaluated_args["input_data"] == expected + + +def test_execute_unsupported_step_type(role_arn, local_sagemaker_session): + step = CreateModelStep( + name="MyRegisterModelStep", + model=Model(image_uri="mock_image_uri"), + ) + pipeline = Pipeline( + name="MyPipeline", + parameters=[STRING_PARAMETER], + steps=[step], + sagemaker_session=local_sagemaker_session, + ) + create_pipeline_response = pipeline.create(role_arn, "test pipeline") + assert create_pipeline_response["PipelineArn"] == "MyPipeline" + with pytest.raises(ClientError) as e: + pipeline.start() + assert f"Step type {step.step_type.value} is not supported in local mode." in str(e.value) diff --git a/tests/unit/test_local_session.py b/tests/unit/sagemaker/local/test_local_session.py similarity index 100% rename from tests/unit/test_local_session.py rename to tests/unit/sagemaker/local/test_local_session.py diff --git a/tests/unit/test_local_utils.py b/tests/unit/sagemaker/local/test_local_utils.py similarity index 100% rename from tests/unit/test_local_utils.py rename to tests/unit/sagemaker/local/test_local_utils.py diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index 5cd94dd76a..459b12e157 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -32,6 +32,7 @@ ) from sagemaker.workflow.step_collections import StepCollection from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep +from sagemaker.local import LocalSession @pytest.fixture @@ -43,6 +44,7 @@ def role_arn(): def sagemaker_session_mock(): session_mock = Mock() session_mock.default_bucket = Mock(name="default_bucket", return_value="s3_bucket") + session_mock.local_mode = False return session_mock @@ -466,3 +468,31 @@ def _generate_large_pipeline_steps(input_data: object): for i in range(2000): steps.append(CustomStep(name=f"MyStep{i}", input_data=input_data)) return steps + + +def test_local_pipeline(): + parameter = ParameterString("MyStr", default_value="test") + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[CustomStep(name="MyStep", input_data=parameter)], + sagemaker_session=LocalSession(), + ) + pipeline.create("dummy-role", "pipeline-description") + + pipeline_describe_response1 = pipeline.describe() + assert pipeline_describe_response1["PipelineArn"] == "MyPipeline" + assert pipeline_describe_response1["PipelineDefinition"] == 
pipeline.definition() + assert pipeline_describe_response1["PipelineDescription"] == "pipeline-description" + + pipeline.update("dummy-role", "pipeline-description-2") + pipeline_describe_response2 = pipeline.describe() + assert pipeline_describe_response2["PipelineDescription"] == "pipeline-description-2" + assert ( + pipeline_describe_response2["CreationTime"] + != pipeline_describe_response2["LastModifiedTime"] + ) + + pipeline_execution_describe_response = pipeline.start().describe() + assert pipeline_execution_describe_response["PipelineArn"] == "MyPipeline" + assert pipeline_execution_describe_response["PipelineExecutionArn"] is not None diff --git a/tests/unit/sagemaker/workflow/test_pipeline_graph.py b/tests/unit/sagemaker/workflow/test_pipeline_graph.py index 003dd8d048..2450adfe8a 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_graph.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_graph.py @@ -84,6 +84,18 @@ def test_pipeline_duplicate_step_name_in_step_collection(sagemaker_session_mock) assert "Pipeline steps cannot have duplicate names." in str(error.value) +def test_pipeline_depends_on_undefined(sagemaker_session_mock): + custom_step = CustomStep(name="foo-1", depends_on=["undefined_step"]) + pipeline = Pipeline( + name="MyPipeline", + steps=[custom_step], + sagemaker_session=sagemaker_session_mock, + ) + with pytest.raises(ValueError) as error: + PipelineGraph.from_pipeline(pipeline) + assert "Step undefined_step is undefined." in str(error.value) + + def test_pipeline_graph_acyclic(sagemaker_session_mock): step_a = CustomStep(name="stepA") step_b = CustomStep(name="stepB") @@ -111,7 +123,7 @@ def test_pipeline_graph_acyclic(sagemaker_session_mock): _verify_pipeline_graph_traversal(pipeline_graph) -def test_pipeline_graph_acyclic_with_condition_step_explicit_dependency(sagemaker_session_mock): +def test_pipeline_graph_with_condition_step_explicit_dependency(sagemaker_session_mock): custom_step = CustomStep(name="TestStep") if_step = CustomStep(name="IfStep") else_step = CustomStep(name="ElseStep") @@ -138,7 +150,7 @@ def test_pipeline_graph_acyclic_with_condition_step_explicit_dependency(sagemake _verify_pipeline_graph_traversal(pipeline_graph) -def test_pipeline_graph_acyclic_with_condition_step_property_reference_dependency( +def test_pipeline_graph_with_condition_step_property_reference_dependency( sagemaker_session_mock, ): custom_step = CustomStep(name="TestStep") @@ -162,7 +174,7 @@ def test_pipeline_graph_acyclic_with_condition_step_property_reference_dependenc _verify_pipeline_graph_traversal(pipeline_graph) -def test_pipeline_graph_acyclic_with_step_collection_explicit_dependency(sagemaker_session_mock): +def test_pipeline_graph_with_step_collection_explicit_dependency(sagemaker_session_mock): custom_step1 = CustomStep(name="TestStep") custom_step_collection = CustomStepCollection( name="TestStepCollection", depends_on=[custom_step1] @@ -187,7 +199,7 @@ def test_pipeline_graph_acyclic_with_step_collection_explicit_dependency(sagemak _verify_pipeline_graph_traversal(pipeline_graph) -def test_pipeline_graph_acyclic_with_step_collection_property_reference_dependency( +def test_pipeline_graph_with_step_collection_property_reference_dependency( sagemaker_session_mock, ): custom_step_collection = CustomStepCollection(name="TestStepCollection") From b953d2fa7a767b5473dac0eaa047ed1b2894df1c Mon Sep 17 00:00:00 2001 From: Ao Guo <72373287+aoguo64@users.noreply.github.com> Date: Fri, 15 Jul 2022 16:39:51 -0700 Subject: [PATCH 169/526] feature: local mode executor 
implementation Co-authored-by: Ao Guo --- src/sagemaker/local/entities.py | 28 +- src/sagemaker/local/pipeline.py | 374 ++++++++- src/sagemaker/workflow/condition_step.py | 2 +- src/sagemaker/workflow/pipeline.py | 16 +- .../sagemaker/local/test_local_pipeline.py | 720 +++++++++++++++++- .../sagemaker/workflow/test_pipeline_graph.py | 118 ++- 6 files changed, 1211 insertions(+), 47 deletions(-) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index f7974e31f8..3d9ddbf77e 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -687,8 +687,10 @@ def __init__( self.status = _LocalExecutionStatus.EXECUTING.value self.failure_reason = None self.creation_time = datetime.datetime.now() - self.step_execution = self._initialize_step_execution() + self.step_execution = {} + self._initialize_step_execution(self.pipeline.steps) self.pipeline_parameters = self._initialize_and_validate_parameters(PipelineParameters) + self.blockout_steps = {} def describe(self): """Placeholder docstring""" @@ -709,21 +711,29 @@ def list_steps(self): """Placeholder docstring""" # TODO + def update_execution_success(self): + """Mark execution as succeeded.""" + self.status = _LocalExecutionStatus.SUCCEEDED.value + def update_execution_failure(self, step_name, failure_message): """Mark execution as failed.""" self.status = _LocalExecutionStatus.FAILED.value self.failure_reason = f"Step {step_name} failed with message: {failure_message}" logger.error("Pipeline execution failed because step %s failed.", step_name) + def update_step_properties(self, step_name, step_properties): + """Update pipeline step execution output properties.""" + self.step_execution.get(step_name).update_step_properties(step_properties) + def update_step_failure(self, step_name, failure_message): """Mark step_name as failed.""" self.step_execution.get(step_name).update_step_failure(failure_message) - def mark_step_starting(self, step_name): + def mark_step_executing(self, step_name): """Update step's status to EXECUTING""" - self.step_execution.get(step_name).status = _LocalExecutionStatus.EXECUTING + self.step_execution.get(step_name).status = _LocalExecutionStatus.EXECUTING.value - def _initialize_step_execution(self): + def _initialize_step_execution(self, steps): """Initialize step_execution dict.""" from sagemaker.workflow.steps import StepTypeEnum @@ -735,15 +745,15 @@ def _initialize_step_execution(self): StepTypeEnum.FAIL, ) - step_execution = {} - for step in self.pipeline.steps: + for step in steps: if step.step_type not in supported_steps_types: error_msg = self._construct_validation_exception_message( "Step type {} is not supported in local mode.".format(step.step_type.value) ) raise ClientError(error_msg, "start_pipeline_execution") - step_execution[step.name] = _LocalPipelineStepExecution(step.name, step.step_type) - return step_execution + self.step_execution[step.name] = _LocalPipelineStepExecution(step.name, step.step_type) + if step.step_type == StepTypeEnum.CONDITION: + self._initialize_step_execution(step.if_steps + step.else_steps) def _initialize_and_validate_parameters(self, overridden_parameters): """Initialize and validate pipeline parameters.""" @@ -794,7 +804,7 @@ def __init__( ): self.step_name = step_name self.step_type = step_type - self.status = status or _LocalExecutionStatus.STARTING + self.status = status or _LocalExecutionStatus.STARTING.value self.failure_reason = failure_reason self.properties = properties or {} self.creation_time = datetime.datetime.now() 
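Reviewer note: before diving into the executor below, here is a minimal, illustrative sketch of how this local execution path is driven end to end. It mirrors the unit tests added later in this patch; the pipeline name, step name, parameter name and the dummy role are placeholders, and a bare ConditionStep stands in for the training/processing/transform/fail steps the executor also supports.

    from sagemaker.local import LocalSession
    from sagemaker.workflow.condition_step import ConditionStep
    from sagemaker.workflow.conditions import ConditionEquals
    from sagemaker.workflow.parameters import ParameterInteger
    from sagemaker.workflow.pipeline import Pipeline

    # A single condition step; if_steps/else_steps would normally hold the branches
    # that get executed or pruned depending on the evaluated outcome.
    instance_count = ParameterInteger(name="InstanceCount", default_value=6)
    cond_step = ConditionStep(
        name="CheckInstanceCount",
        conditions=[ConditionEquals(left=instance_count, right=6)],
        if_steps=[],
        else_steps=[],
    )

    pipeline = Pipeline(
        name="MyLocalPipeline",
        parameters=[instance_count],
        steps=[cond_step],
        sagemaker_session=LocalSession(),
    )

    # In local mode, create()/start() short-circuit to the local SageMaker client
    # instead of calling the service; start() hands the execution to the
    # LocalPipelineExecutor defined below.
    pipeline.create("dummy-role", "local pipeline")
    execution = pipeline.start({"InstanceCount": 6})
    print(execution.describe()["PipelineExecutionArn"])

Unsupported step types are rejected up front with a ClientError when the execution is initialized, as the tests further down assert.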
diff --git a/src/sagemaker/local/pipeline.py b/src/sagemaker/local/pipeline.py index 6590ae7e58..1f8c3d86f9 100644 --- a/src/sagemaker/local/pipeline.py +++ b/src/sagemaker/local/pipeline.py @@ -12,11 +12,15 @@ # language governing permissions and limitations under the License. """Local Pipeline Executor""" from __future__ import absolute_import +from abc import ABC, abstractmethod import logging from copy import deepcopy from datetime import datetime +from typing import Dict, List +from sagemaker.workflow.conditions import ConditionTypeEnum +from sagemaker.workflow.steps import StepTypeEnum, Step from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.parameters import Parameter from sagemaker.workflow.functions import Join, JsonGet @@ -25,10 +29,18 @@ from sagemaker.workflow.pipeline import PipelineGraph from sagemaker.local.exceptions import StepExecutionException from sagemaker.local.utils import get_using_dot_notation +from sagemaker.utils import unique_name_from_base logger = logging.getLogger(__name__) PRIMITIVES = (str, int, bool, float) +BINARY_CONDITION_TYPES = ( + ConditionTypeEnum.EQ.value, + ConditionTypeEnum.GT.value, + ConditionTypeEnum.GTE.value, + ConditionTypeEnum.LT.value, + ConditionTypeEnum.LTE.value, +) class LocalPipelineExecutor(object): @@ -44,42 +56,48 @@ def __init__(self, execution, sagemaker_session): self.sagemaker_session = sagemaker_session self.execution = execution self.pipeline_dag = PipelineGraph.from_pipeline(self.execution.pipeline) + self.local_sagemaker_client = self.sagemaker_session.sagemaker_client + self.blockout_steps = set() + self._step_executor_factory = _StepExecutorFactory(self) def execute(self): """Execute a local pipeline.""" try: for step in self.pipeline_dag: - self.execute_step(step) + if step.name not in self.blockout_steps: + self._execute_step(step) except StepExecutionException as e: self.execution.update_execution_failure(e.step_name, e.message) + else: + self.execution.update_execution_success() return self.execution - def execute_step(self, step): + def _execute_step(self, step): """Execute a local pipeline step.""" - self.execution.mark_step_starting(step.name) - step_arguments = self.evaluate_step_arguments(step) # noqa: F841; pylint: disable=W0612 - # TODO execute step + self.execution.mark_step_executing(step.name) + step_properties = self._step_executor_factory.get(step).execute() + self.execution.update_step_properties(step.name, step_properties) def evaluate_step_arguments(self, step): """Parses and evaluate step arguments.""" + return self._parse_arguments(step.arguments, step.name) - def _parse_arguments(obj): - if isinstance(obj, dict): - obj_copy = deepcopy(obj) - for k, v in obj.items(): - if isinstance(v, dict): - obj_copy[k] = _parse_arguments(v) - elif isinstance(v, list): - list_copy = [] - for item in v: - list_copy.append(_parse_arguments(item)) - obj_copy[k] = list_copy - elif isinstance(v, PipelineVariable): - obj_copy[k] = self.evaluate_pipeline_variable(v, step.name) - return obj_copy - return obj - - return _parse_arguments(step.arguments) + def _parse_arguments(self, obj, step_name): + """Parse and evaluate arguments field""" + if isinstance(obj, dict): + obj_copy = deepcopy(obj) + for k, v in obj.items(): + if isinstance(v, dict): + obj_copy[k] = self._parse_arguments(v, step_name) + elif isinstance(v, list): + list_copy = [] + for item in v: + list_copy.append(self._parse_arguments(item, step_name)) + obj_copy[k] = list_copy + elif isinstance(v, PipelineVariable): + obj_copy[k] 
= self.evaluate_pipeline_variable(v, step_name) + return obj_copy + return obj def evaluate_pipeline_variable(self, pipeline_variable, step_name): """Evaluate pipeline variable runtime value.""" @@ -120,7 +138,10 @@ def _evaluate_property_reference(self, pipeline_variable, step_name): def _evaluate_execution_variable(self, pipeline_variable): """Evaluate pipeline execution variable runtime value.""" - if pipeline_variable in (ExecutionVariables.PIPELINE_NAME, ExecutionVariables.PIPELINE_ARN): + if pipeline_variable in ( + ExecutionVariables.PIPELINE_NAME, + ExecutionVariables.PIPELINE_ARN, + ): return self.execution.pipeline.name if pipeline_variable in ( ExecutionVariables.PIPELINE_EXECUTION_ID, @@ -132,3 +153,310 @@ def _evaluate_execution_variable(self, pipeline_variable): if pipeline_variable == ExecutionVariables.CURRENT_DATETIME: return datetime.now() return None + + +class _StepExecutor(ABC): + """An abstract base class for step executors running steps locally""" + + def __init__(self, pipeline_executor: LocalPipelineExecutor, step: Step): + self.pipline_executor = pipeline_executor + self.step = step + + @abstractmethod + def execute(self) -> Dict: + """Execute a pipeline step locally + + Returns: + A dictionary as properties of the current step + """ + + def _convert_list_to_dict(self, dictionary: dict, path_to_list: str, reducing_key: str): + """Convert list into dictionary using a field inside list elements as the keys. + + Raises RuntimeError if given list not able to be converted into a map based on given key. + """ + + try: + rest = get_using_dot_notation(dictionary, path_to_list) + except (KeyError, IndexError, TypeError): + raise RuntimeError("%s does not exist in %s" % path_to_list, dictionary) + if not isinstance(rest, list): + raise RuntimeError( + "%s of type %s is not a list to be converted into a dictionary!" 
% rest, + type(rest), + ) + converted_map = {} + for element in rest: + if not isinstance(element, dict): + raise RuntimeError( + "Cannot convert element of type %s into dictionary entry" % type(element) + ) + converted_map[element[reducing_key]] = element + return converted_map + + +class _TrainingStepExecutor(_StepExecutor): + """Executor class to execute TrainingStep locally""" + + def execute(self): + job_name = unique_name_from_base(self.step.name) + step_arguments = self.pipline_executor.evaluate_step_arguments(self.step) + try: + self.pipline_executor.local_sagemaker_client.create_training_job( + job_name, **step_arguments + ) + return self.pipline_executor.local_sagemaker_client.describe_training_job(job_name) + except Exception as e: # pylint: disable=W0703 + self.pipline_executor.execution.update_step_failure( + self.step.name, + f"Error when executing step {self.step.name} of type {type(self.step)}: {e}", + ) + + +class _ProcessingStepExecutor(_StepExecutor): + """Executor class to execute ProcessingStep locally""" + + def execute(self): + job_name = unique_name_from_base(self.step.name) + step_arguments = self.pipline_executor.evaluate_step_arguments(self.step) + try: + self.pipline_executor.local_sagemaker_client.create_processing_job( + job_name, **step_arguments + ) + job_describe_response = ( + self.pipline_executor.local_sagemaker_client.describe_processing_job(job_name) + ) + job_describe_response["ProcessingOutputConfig"]["Outputs"] = self._convert_list_to_dict( + job_describe_response, "ProcessingOutputConfig.Outputs", "OutputName" + ) + job_describe_response["ProcessingInputs"] = self._convert_list_to_dict( + job_describe_response, "ProcessingInputs", "InputName" + ) + return job_describe_response + + except Exception as e: # pylint: disable=W0703 + self.pipline_executor.execution.update_step_failure( + self.step.name, + f"Error when executing step {self.step.name} of type {type(self.step)}: {e}", + ) + + +class _ConditionStepExecutor(_StepExecutor): + """Executor class to execute ConditionStep locally""" + + def execute(self): + def _blockout_all_downstream_steps(steps: List[Step]): + step_to_blockout = set() + for step in steps: + step_to_blockout.update( + self.pipline_executor.pipeline_dag.get_steps_in_sub_dag(step.name) + ) + self.pipline_executor.blockout_steps.update(step_to_blockout) + + if_steps = self.step.if_steps + else_steps = self.step.else_steps + step_only_arguments = self.pipline_executor._parse_arguments( + self.step.step_only_arguments, self.step.name + ) + + outcome = self._evaluate_conjunction(step_only_arguments["Conditions"]) + + if not outcome: + _blockout_all_downstream_steps(if_steps) + else: + _blockout_all_downstream_steps(else_steps) + + return dict(Outcome=outcome) + + def _evaluate_conjunction(self, conditions: List[Dict]) -> bool: + """Evaluate conditions of current conditionStep. + + Args: + List of dictionaries representing conditions as request + + Returns: + True if the conjunction expression is true, + False otherwise. + """ + for condition in conditions: + if not self._resolve_condition(condition): + return False + return True + + def _resolve_condition(self, condition: dict) -> bool: + """Resolve given condition. + + Args: + Dictionary representing given condition as request + + Returns: + True if given condition evaluated as true, + False otherwise. 
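+
+        For illustration, an equality condition reaches this method in a request form
+        roughly like {"Type": "Equals", "LeftValue": 6, "RightValue": 1}; any pipeline
+        variables in the request have already been resolved to runtime values by
+        _parse_arguments before the conjunction is evaluated.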
+ """ + + condition_type = condition["Type"] + outcome = None + if condition_type in BINARY_CONDITION_TYPES: + outcome = self._resolve_binary_condition(condition, condition_type) + elif condition_type == ConditionTypeEnum.NOT.value: + outcome = self._resolve_not_condition(condition) + elif condition_type == ConditionTypeEnum.OR.value: + outcome = self._resolve_or_condition(condition) + elif condition_type == ConditionTypeEnum.IN.value: + outcome = self._resolve_in_condition(condition) + else: + raise NotImplementedError("Condition of type [%s] is not supported." % condition_type) + + return outcome + + def _resolve_binary_condition(self, binary_condition: dict, binary_condition_type: str): + """Resolve given binary condition. + + Args: + Dictionary representing given binary condition as request + + Returns: + True if given binary condition evaluated as true, + False otherwise. + """ + + left_value = binary_condition["LeftValue"] + right_value = binary_condition["RightValue"] + try: + outcome = None + if binary_condition_type == ConditionTypeEnum.EQ.value: + if not isinstance(left_value, type(right_value)) and not isinstance( + right_value, type(left_value) + ): + self.pipline_executor.execution.update_step_failure( + self.step.name, + f"LeftValue [{left_value}] of type [{type(left_value)}] and " + + f"RightValue [{right_value}] of type [{type(right_value)}] " + + "are not of the same type.", + ) + outcome = left_value == right_value + elif binary_condition_type == ConditionTypeEnum.GT.value: + outcome = left_value > right_value + elif binary_condition_type == ConditionTypeEnum.GTE.value: + outcome = left_value >= right_value + elif binary_condition_type == ConditionTypeEnum.LT.value: + outcome = left_value < right_value + elif binary_condition_type == ConditionTypeEnum.LTE.value: + outcome = left_value <= right_value + else: + raise NotImplementedError( + "Binary condition of type [%s] is not supported" % binary_condition_type + ) + return outcome + + except TypeError: + self.pipline_executor.execution.update_step_failure( + self.step.name, + f"Condition of type [{binary_condition_type}] not supported between " + + f"[{left_value}] of type [{type(left_value)}] and [{right_value}] " + + f"of type [{type(right_value)}]", + ) + + def _resolve_not_condition(self, not_condition: dict): + """Resolve given ConditionNot. + + Args: + Dictionary representing given ConditionNot as request + + Returns: + True if given ConditionNot evaluated as true, + False otherwise. + """ + return not self._resolve_condition(not_condition["Expression"]) + + def _resolve_or_condition(self, or_condition: dict): + """Resolve given ConditionOr. + + Args: + Dictionary representing given ConditionOr as request + + Returns: + True if given ConditionOr evaluated as true, + False otherwise. + """ + + for condition in or_condition["Conditions"]: + if self._resolve_condition(condition): + return True + return False + + def _resolve_in_condition(self, in_condition: dict): + """Resolve given ConditionIn. + + Args: + Dictionary representing given ConditionIn as request + + Returns: + True if given ConditionIn evaluated as true, + False otherwise. 
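+
+        For illustration, a request shaped like {"QueryValue": 6, "Values": [3, 6, 9]}
+        resolves to True: the already-evaluated query value is simply tested for
+        membership in the evaluated list of values.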
+ """ + + query_value = in_condition["QueryValue"] + values = in_condition["Values"] + return query_value in values + + +class _TransformStepExecutor(_StepExecutor): + """Executor class to execute TransformStep locally""" + + def execute(self): + job_name = unique_name_from_base(self.step.name) + step_arguments = self.pipline_executor.evaluate_step_arguments(self.step) + try: + self.pipline_executor.local_sagemaker_client.create_transform_job( + job_name, **step_arguments + ) + return self.pipline_executor.local_sagemaker_client.describe_transform_job(job_name) + except Exception as e: # pylint: disable=W0703 + self.pipline_executor.execution.update_step_failure( + self.step.name, + f"Error when executing step {self.step.name} of type {type(self.step)}: {e}", + ) + + +class _FailStepExecutor(_StepExecutor): + """Executor class to execute FailStep locally""" + + def execute(self): + step_arguments = self.pipline_executor.evaluate_step_arguments(self.step) + + error_message = step_arguments.get("ErrorMessage") + self.pipline_executor.execution.update_step_properties( + self.step.name, {"ErrorMessage": error_message} + ) + self.pipline_executor.execution.update_step_failure( + self.step.name, step_arguments.get("ErrorMessage") + ) + + +class _StepExecutorFactory: + """Factory class to generate executors for given step based on their types""" + + def __init__(self, pipeline_executor: LocalPipelineExecutor): + self.pipeline_executor = pipeline_executor + + def get(self, step: Step) -> _StepExecutor: + """Return corresponding step executor for given step""" + + step_type = step.step_type + step_executor = None + if step_type == StepTypeEnum.TRAINING: + step_executor = _TrainingStepExecutor(self.pipeline_executor, step) + elif step_type == StepTypeEnum.PROCESSING: + step_executor = _ProcessingStepExecutor(self.pipeline_executor, step) + elif step_type == StepTypeEnum.TRANSFORM: + step_executor = _TransformStepExecutor(self.pipeline_executor, step) + elif step_type == StepTypeEnum.FAIL: + step_executor = _FailStepExecutor(self.pipeline_executor, step) + elif step_type == StepTypeEnum.CONDITION: + step_executor = _ConditionStepExecutor(self.pipeline_executor, step) + else: + self.pipeline_executor.execution.update_step_failure( + step.name, f"Unsupported step type {step_type} to execute." 
+ ) + return step_executor diff --git a/src/sagemaker/workflow/condition_step.py b/src/sagemaker/workflow/condition_step.py index 624c42fa66..a93b48d533 100644 --- a/src/sagemaker/workflow/condition_step.py +++ b/src/sagemaker/workflow/condition_step.py @@ -90,7 +90,7 @@ def arguments(self) -> RequestType: @property def step_only_arguments(self): """Argument dict pertaining to the step only, and not the `if_steps` or `else_steps`.""" - return [condition.to_request() for condition in self.conditions] + return dict(Conditions=[condition.to_request() for condition in self.conditions]) @property def properties(self): diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 95bff2957b..2c4631e8f8 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -17,7 +17,7 @@ import logging from copy import deepcopy -from typing import Any, Dict, List, Sequence, Union, Optional +from typing import Any, Dict, List, Set, Sequence, Union, Optional import attr import botocore @@ -661,6 +661,20 @@ def is_cyclic_helper(current_step): return True return False + def get_steps_in_sub_dag(self, current_step: str, steps: Set[str] = None) -> Set[str]: + """Get names of all steps (including current step) in the sub dag of current step. + + Returns a set of step names in the sub dag. + """ + if steps is None: + steps = set() + if current_step not in self.adjacency_list: + raise ValueError("Step: %s does not exist in the pipeline." % current_step) + steps.add(current_step) + for step in self.adjacency_list[current_step]: + self.get_steps_in_sub_dag(step, steps) + return steps + def __iter__(self): """Perform topological sort traversal of the Pipeline Graph.""" diff --git a/tests/unit/sagemaker/local/test_local_pipeline.py b/tests/unit/sagemaker/local/test_local_pipeline.py index 3e4f6dae3f..31a653292b 100644 --- a/tests/unit/sagemaker/local/test_local_pipeline.py +++ b/tests/unit/sagemaker/local/test_local_pipeline.py @@ -10,26 +10,66 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -# language governing permissions and limitations under the License. 
from __future__ import absolute_import +from mock import Mock, PropertyMock, patch import pytest +from sagemaker.estimator import Estimator +from sagemaker.inputs import TrainingInput +from sagemaker.debugger import ProfilerConfig +from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor +from sagemaker.transformer import Transformer +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.conditions import ( + ConditionEquals, + ConditionGreaterThan, + ConditionGreaterThanOrEqualTo, + ConditionIn, + ConditionLessThan, + ConditionLessThanOrEqualTo, + ConditionNot, + ConditionOr, +) +from sagemaker.workflow.fail_step import FailStep from botocore.exceptions import ClientError -from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.parameters import ParameterInteger, ParameterString from sagemaker.workflow.pipeline import Pipeline -from sagemaker.workflow.steps import CreateModelStep +from sagemaker.workflow.steps import CreateModelStep, StepTypeEnum from sagemaker.model import Model +from sagemaker.workflow.pipeline_context import PipelineSession +from sagemaker.workflow.steps import ( + ProcessingStep, + TrainingStep, + TransformStep, + TransformInput, +) + +from sagemaker.local import LocalSession +from sagemaker.local.pipeline import ( + _ConditionStepExecutor, + _FailStepExecutor, + _ProcessingStepExecutor, + _StepExecutorFactory, + _TrainingStepExecutor, + _TransformStepExecutor, + LocalPipelineExecutor, + StepExecutionException, +) +from sagemaker.local.entities import _LocalExecutionStatus, _LocalPipelineExecution from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.functions import Join -from sagemaker.local.local_session import LocalSession -from sagemaker.local.pipeline import LocalPipelineExecutor, StepExecutionException -from sagemaker.local.entities import _LocalPipelineExecution from tests.unit.sagemaker.workflow.helpers import CustomStep STRING_PARAMETER = ParameterString("MyStr", "DefaultParameter") INPUT_STEP = CustomStep(name="InputStep") +IMAGE_URI = "fakeimage" +ROLE = "DummyRole" +BUCKET = "my-bucket" +REGION = "us-west-2" +INSTANCE_TYPE = "ml.m4.xlarge" +INSTANCE_COUNT_PIPELINE_PARAMETER = ParameterInteger(name="InstanceCount", default_value=6) @pytest.fixture() @@ -42,6 +82,133 @@ def role_arn(): return "arn:role" +@pytest.fixture +def boto_session(client): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value=ROLE) + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name=REGION) + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client + + return session_mock + + +@pytest.fixture +def client(): + """Mock client. 
+ + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + client_mock.describe_model.return_value = {"PrimaryContainer": {}, "Containers": {}} + return client_mock + + +@pytest.fixture +def pipeline_session(boto_session, client): + return PipelineSession( + boto_session=boto_session, + sagemaker_client=client, + default_bucket=BUCKET, + ) + + +@pytest.fixture +def training_step(pipeline_session): + estimator = Estimator( + image_uri=IMAGE_URI, + role=ROLE, + instance_count=INSTANCE_COUNT_PIPELINE_PARAMETER, + instance_type="c4.4xlarge", + profiler_config=ProfilerConfig(system_monitor_interval_millis=500), + hyperparameters={ + "batch-size": 500, + "epochs": 5, + }, + rules=[], + sagemaker_session=pipeline_session, + output_path="s3://a/b", + use_spot_instances=False, + ) + training_input = TrainingInput(s3_data=f"s3://{BUCKET}/train_manifest") + step_args = estimator.fit(inputs=training_input) + return TrainingStep( + name="MyTrainingStep", + description="TrainingStep description", + display_name="MyTrainingStep", + step_args=step_args, + ) + + +@pytest.fixture +def processing_step(pipeline_session): + processor = Processor( + image_uri=IMAGE_URI, + role=ROLE, + instance_count=1, + instance_type=INSTANCE_TYPE, + sagemaker_session=pipeline_session, + ) + processing_input = [ + ProcessingInput( + source=f"s3://{BUCKET}/processing_manifest", + destination="processing_manifest", + ) + ] + processing_output = [ + ProcessingOutput( + output_name="output1", + source="/opt/ml/processing/output/output1", + destination="s3://some-bucket/some-path/output1", + s3_upload_mode="EndOfJob", + ) + ] + step_args = processor.run(inputs=processing_input, outputs=processing_output) + return ProcessingStep( + name="MyProcessingStep", + step_args=step_args, + description="ProcessingStep description", + display_name="MyProcessingStep", + ) + + +@pytest.fixture +def transform_step(pipeline_session): + transformer = Transformer( + model_name="my-model", + instance_type="ml.m5.xlarge", + instance_count=1, + output_path="s3://my-bucket/my-output-path", + sagemaker_session=pipeline_session, + ) + transform_inputs = TransformInput(data="s3://my-bucket/my-data") + step_args = transformer.transform( + data=transform_inputs.data, + data_type=transform_inputs.data_type, + content_type=transform_inputs.content_type, + compression_type=transform_inputs.compression_type, + split_type=transform_inputs.split_type, + input_filter=transform_inputs.input_filter, + output_filter=transform_inputs.output_filter, + join_source=transform_inputs.join_source, + model_client_config=transform_inputs.model_client_config, + ) + return TransformStep( + name="MyTransformStep", + step_args=step_args, + ) + + def test_evaluate_parameter(local_sagemaker_session): step = CustomStep(name="MyStep", input_data=STRING_PARAMETER) pipeline = Pipeline( @@ -171,7 +338,10 @@ def test_evaluate_property_reference_undefined(local_sagemaker_session): (ExecutionVariables.PIPELINE_NAME, "blah-MyPipeline-blah"), (STRING_PARAMETER, "blah-DefaultParameter-blah"), (INPUT_STEP.properties.TrainingJobArn, "blah-my-training-arn-blah"), - (Join(on=".", values=["test1", "test2", "test3"]), "blah-test1.test2.test3-blah"), + ( + Join(on=".", values=["test1", "test2", "test3"]), + "blah-test1.test2.test3-blah", + ), ( Join(on=".", values=["test", 
ExecutionVariables.PIPELINE_NAME, "test"]), "blah-test.MyPipeline.test-blah", @@ -211,3 +381,539 @@ def test_execute_unsupported_step_type(role_arn, local_sagemaker_session): with pytest.raises(ClientError) as e: pipeline.start() assert f"Step type {step.step_type.value} is not supported in local mode." in str(e.value) + + +@pytest.mark.parametrize( + "step, step_executor_class", + [ + (Mock(step_type=StepTypeEnum.TRAINING), _TrainingStepExecutor), + (Mock(step_type=StepTypeEnum.PROCESSING), _ProcessingStepExecutor), + (Mock(step_type=StepTypeEnum.TRANSFORM), _TransformStepExecutor), + (Mock(step_type=StepTypeEnum.CONDITION), _ConditionStepExecutor), + (Mock(step_type=StepTypeEnum.FAIL), _FailStepExecutor), + ], +) +def test_step_executor_factory(step, step_executor_class): + local_pipeline_executor = Mock() + step_executor_factory = _StepExecutorFactory(local_pipeline_executor) + step_executor = step_executor_factory.get(step) + assert isinstance(step_executor, step_executor_class) + + +@patch( + "sagemaker.local.image._SageMakerContainer.train", + return_value="/some/path/to/model", +) +def test_execute_pipeline_training_step(train, local_sagemaker_session, training_step): + pipeline = Pipeline( + name="MyPipeline1", + parameters=[INSTANCE_COUNT_PIPELINE_PARAMETER], + steps=[training_step], + sagemaker_session=local_sagemaker_session, + ) + execution = LocalPipelineExecutor( + _LocalPipelineExecution("my-execution-1", pipeline), local_sagemaker_session + ).execute() + assert execution.status == _LocalExecutionStatus.SUCCEEDED.value + assert execution.pipeline_execution_name == "my-execution-1" + + step_execution = execution.step_execution + expected_must_have = { + "ResourceConfig": {"InstanceCount": 6}, + "TrainingJobStatus": "Completed", + "ModelArtifacts": {"S3ModelArtifacts": "/some/path/to/model"}, + } + assert step_execution["MyTrainingStep"].status == "Succeeded" + assert expected_must_have.items() <= step_execution["MyTrainingStep"].properties.items() + + +@patch("sagemaker.local.image._SageMakerContainer.process") +def test_execute_pipeline_processing_step(process, local_sagemaker_session, processing_step): + pipeline = Pipeline( + name="MyPipeline2", + steps=[processing_step], + sagemaker_session=local_sagemaker_session, + ) + execution = LocalPipelineExecutor( + _LocalPipelineExecution("my-execution-2", pipeline), local_sagemaker_session + ).execute() + assert execution.status == _LocalExecutionStatus.SUCCEEDED.value + assert execution.pipeline_execution_name == "my-execution-2" + + step_execution = execution.step_execution + step_properties = step_execution["MyProcessingStep"].properties + assert step_execution["MyProcessingStep"].status == "Succeeded" + assert "MyProcessingStep-" in step_properties["ProcessingJobArn"] + assert "MyProcessingStep-" in step_properties["ProcessingJobName"] + assert step_properties["AppSpecification"]["ImageUri"] == IMAGE_URI + s3_input = step_properties["ProcessingInputs"]["input-1"][ + "S3Input" + ] # input name "input-1" is auto-generated + assert s3_input["S3Uri"] == f"s3://{BUCKET}/processing_manifest" + assert s3_input["LocalPath"] == "processing_manifest" + cluster_config = step_properties["ProcessingResources"]["ClusterConfig"] + assert cluster_config["InstanceCount"] == 1 + assert cluster_config["InstanceType"] == INSTANCE_TYPE + assert step_properties["ProcessingJobStatus"] == "Completed" + expected_processing_output = { + "OutputName": "output1", + "AppManaged": False, + "S3Output": { + "S3Uri": "s3://some-bucket/some-path/output1", + 
"LocalPath": "/opt/ml/processing/output/output1", + "S3UploadMode": "EndOfJob", + }, + } + processing_output = step_properties["ProcessingOutputConfig"]["Outputs"]["output1"] + assert processing_output == expected_processing_output + + +@patch("sagemaker.local.local_session._LocalTransformJob") +def test_execute_pipeline_transform_step( + _LocalTransformJob, local_sagemaker_session, transform_step +): + pipeline = Pipeline( + name="MyPipeline3", + steps=[transform_step], + sagemaker_session=local_sagemaker_session, + ) + execution = LocalPipelineExecutor( + _LocalPipelineExecution("my-execution-3", pipeline), local_sagemaker_session + ).execute() + + _LocalTransformJob().start.assert_called_with( + { + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": "s3://my-bucket/my-data", + } + } + }, + {"S3OutputPath": "s3://my-bucket/my-output-path"}, + {"InstanceCount": 1, "InstanceType": "ml.m5.xlarge"}, + ) + + _LocalTransformJob().describe.assert_called() + + assert execution.status == _LocalExecutionStatus.SUCCEEDED.value + assert execution.pipeline_execution_name == "my-execution-3" + + step_execution = execution.step_execution + assert step_execution["MyTransformStep"].status == _LocalExecutionStatus.SUCCEEDED.value + + +def test_execute_pipeline_fail_step(local_sagemaker_session): + param = ParameterString(name="foo", default_value="bar") + step_fail = FailStep( + name="FailStep", + error_message=Join(on=": ", values=["Failed due to foo has value", param]), + ) + pipeline = Pipeline( + name="MyPipeline4", + steps=[step_fail], + parameters=[param], + sagemaker_session=local_sagemaker_session, + ) + + execution = LocalPipelineExecutor( + _LocalPipelineExecution("my-execution-4", pipeline), local_sagemaker_session + ).execute() + + assert execution.status == _LocalExecutionStatus.FAILED.value + assert execution.pipeline_execution_name == "my-execution-4" + + fail_step_execution = execution.step_execution.get(step_fail.name) + assert fail_step_execution.status == _LocalExecutionStatus.FAILED.value + assert fail_step_execution.properties == {"ErrorMessage": "Failed due to foo has value: bar"} + assert fail_step_execution.failure_reason == "Failed due to foo has value: bar" + + +@pytest.mark.parametrize( + "condition, condition_outcome, succeeded_steps, executing_steps", + [ + ( + ConditionEquals(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=1), + False, + ["MyProcessingStep"], + ["MyTrainingStep"], + ), + ( + ConditionGreaterThan(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=1), + True, + ["MyTrainingStep"], + ["MyProcessingStep"], + ), + ( + ConditionGreaterThanOrEqualTo(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=6), + True, + ["MyTrainingStep"], + ["MyProcessingStep"], + ), + ( + ConditionLessThan(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=1), + False, + ["MyProcessingStep"], + ["MyTrainingStep"], + ), + ( + ConditionLessThanOrEqualTo(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=6), + True, + ["MyTrainingStep"], + ["MyProcessingStep"], + ), + ( + ConditionIn(value=INSTANCE_COUNT_PIPELINE_PARAMETER, in_values=[3, 6, 9]), + True, + ["MyTrainingStep"], + ["MyProcessingStep"], + ), + ( + ConditionNot(ConditionEquals(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=1)), + True, + ["MyTrainingStep"], + ["MyProcessingStep"], + ), + ( + ConditionOr( + conditions=[ + ConditionEquals(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=3), + ConditionEquals(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=6), + ConditionEquals(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=9), + ] + ), 
+ True, + ["MyTrainingStep"], + ["MyProcessingStep"], + ), + ( + ConditionOr( + conditions=[ + ConditionEquals(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=3), + ConditionEquals(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=7), + ConditionEquals(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right=9), + ] + ), + False, + ["MyProcessingStep"], + ["MyTrainingStep"], + ), + ], +) +@patch( + "sagemaker.local.image._SageMakerContainer.train", + return_value="/some/path/to/model", +) +@patch("sagemaker.local.image._SageMakerContainer.process") +def test_execute_pipeline_condition_step_test_conditions( + process, + train, + local_sagemaker_session, + training_step, + processing_step, + condition, + condition_outcome, + succeeded_steps, + executing_steps, +): + condition_step = ConditionStep( + name="MyCondStep", + conditions=[condition], + if_steps=[training_step], + else_steps=[processing_step], + ) + pipeline = Pipeline( + name="MyPipeline5", + steps=[condition_step], + parameters=[INSTANCE_COUNT_PIPELINE_PARAMETER], + sagemaker_session=local_sagemaker_session, + ) + + execution = LocalPipelineExecutor( + _LocalPipelineExecution("my-execution-5", pipeline), local_sagemaker_session + ).execute() + + assert execution.status == _LocalExecutionStatus.SUCCEEDED.value + assert ( + execution.step_execution.get("MyCondStep").status == _LocalExecutionStatus.SUCCEEDED.value + ) + assert execution.step_execution.get("MyCondStep").properties == {"Outcome": condition_outcome} + + for succeeded_step in succeeded_steps: + assert ( + execution.step_execution.get(succeeded_step).status + == _LocalExecutionStatus.SUCCEEDED.value + ) + assert execution.step_execution.get(succeeded_step).properties != {} + assert execution.step_execution.get(succeeded_step).failure_reason is None + + for executing_step in executing_steps: + assert ( + execution.step_execution.get(executing_step).status + == _LocalExecutionStatus.STARTING.value + ) + assert execution.step_execution.get(executing_step).properties == {} + assert execution.step_execution.get(executing_step).failure_reason is None + + +# ┌──►F +# │ +# A──►B──►C──►E──►G──►H +# │ ▲ +# └──►D──►I───┘ +@pytest.mark.parametrize( + "left_value_1, left_value_2, expected_path", + [ + (2, 2, ["stepA", "stepB", "stepC", "stepE"]), + (2, 1, ["stepA", "stepB", "stepC", "stepF"]), + (1, 2, ["stepA", "stepB", "stepD", "stepI"]), + (1, 1, ["stepA", "stepB", "stepD", "stepI"]), + ], +) +@patch( + "sagemaker.local.local_session.LocalSagemakerClient.describe_training_job", + return_value={}, +) +@patch("sagemaker.local.local_session.LocalSagemakerClient.create_training_job") +def test_pipeline_execution_condition_step_execution_path( + create_training_job, + describe_training_job, + local_sagemaker_session, + left_value_1, + left_value_2, + expected_path, +): + condition_1 = ConditionEquals(left=left_value_1, right=2) + condition_2 = ConditionEquals(left=left_value_2, right=2) + step_a = CustomStep(name="stepA") + step_e = CustomStep(name="stepE") + step_f = CustomStep(name="stepF") + step_d = CustomStep(name="stepD") + step_i = CustomStep(name="stepI", depends_on=[step_d.name]) + step_c = ConditionStep( + name="stepC", + conditions=[condition_2], + if_steps=[step_e], + else_steps=[step_f], + ) + step_b = ConditionStep( + name="stepB", + depends_on=[step_a.name], + conditions=[condition_1], + if_steps=[step_c], + else_steps=[step_d], + ) + step_g = CustomStep(name="stepG", depends_on=[step_e.name, step_i.name]) + step_h = CustomStep(name="stepH", depends_on=[step_g.name]) + + pipeline = 
Pipeline( + name="MyPipeline5-1", + parameters=[INSTANCE_COUNT_PIPELINE_PARAMETER], + steps=[step_a, step_b, step_g, step_h, step_i], + sagemaker_session=local_sagemaker_session, + ) + + execution = LocalPipelineExecutor( + _LocalPipelineExecution("my-execution-5-1", pipeline), local_sagemaker_session + ).execute() + + actual_path = [] + for step_name, step_execution in execution.step_execution.items(): + if step_execution.status != _LocalExecutionStatus.STARTING.value: + actual_path.append(step_name) + assert actual_path == expected_path + + +def test_condition_step_incompatible_types(local_sagemaker_session): + + step_a = CustomStep(name="stepA") + step_b = CustomStep(name="stepB") + step_cond = ConditionStep( + name="stepCondition", + conditions=[ConditionEquals(left=INSTANCE_COUNT_PIPELINE_PARAMETER, right="some_string")], + if_steps=[step_a], + else_steps=[step_b], + ) + + pipeline = Pipeline( + name="MyPipeline5-2", + parameters=[INSTANCE_COUNT_PIPELINE_PARAMETER], + steps=[step_cond], + sagemaker_session=local_sagemaker_session, + ) + + execution = LocalPipelineExecutor( + _LocalPipelineExecution("my-execution-5-1", pipeline), local_sagemaker_session + ).execute() + + assert execution.status == _LocalExecutionStatus.FAILED.value + assert ( + "LeftValue [6] of type [] and RightValue [some_string] of " + + "type [] are not of the same type." + in execution.failure_reason + ) + assert execution.step_execution["stepA"].status == _LocalExecutionStatus.STARTING.value + assert execution.step_execution["stepB"].status == _LocalExecutionStatus.STARTING.value + execution.step_execution["stepCondition"].status == _LocalExecutionStatus.FAILED.value + + +@patch("sagemaker.local.local_session._LocalTrainingJob") +@patch("sagemaker.local.image._SageMakerContainer.process") +def test_processing_and_training_steps_with_data_dependency( + process, + _LocalTrainingJob, + pipeline_session, + local_sagemaker_session, + processing_step, +): + + estimator = Estimator( + image_uri=IMAGE_URI, + role=ROLE, + instance_count=INSTANCE_COUNT_PIPELINE_PARAMETER, + instance_type="c4.4xlarge", + profiler_config=ProfilerConfig(system_monitor_interval_millis=500), + hyperparameters={ + "batch-size": 500, + "epochs": 5, + }, + rules=[], + sagemaker_session=pipeline_session, + output_path="s3://a/b", + use_spot_instances=False, + ) + training_input = TrainingInput( + s3_data=processing_step.properties.ProcessingOutputConfig.Outputs["output1"].S3Output.S3Uri + ) + step_args = estimator.fit(inputs=training_input) + training_step = TrainingStep( + name="MyTrainingStep", + description="TrainingStep description", + display_name="MyTrainingStep", + step_args=step_args, + ) + + pipeline = Pipeline( + name="MyPipeline6", + parameters=[INSTANCE_COUNT_PIPELINE_PARAMETER], + steps=[processing_step, training_step], + sagemaker_session=local_sagemaker_session, + ) + + execution = LocalPipelineExecutor( + _LocalPipelineExecution("my-execution-6", pipeline), local_sagemaker_session + ).execute() + + args_called_with = _LocalTrainingJob().start.call_args.args + + # input_data_config + assert args_called_with[0] == [ + { + "DataSource": { + "S3DataSource": { + "S3DataType": "S3Prefix", + "S3Uri": "s3://some-bucket/some-path/output1", # from depended processing step + "S3DataDistributionType": "FullyReplicated", + } + }, + "ChannelName": "training", + } + ] + + # output_data_config + assert args_called_with[1] == {"S3OutputPath": "s3://a/b"} + + # hyperparameters + assert args_called_with[2] == {"batch-size": "500", "epochs": "5"} + + # 
environment + assert args_called_with[3] == {} + + # job_name + assert args_called_with[4].startswith("MyTrainingStep-") + + assert ( + execution.step_execution.get("MyProcessingStep").status + == _LocalExecutionStatus.SUCCEEDED.value + ) + assert ( + execution.step_execution.get("MyTrainingStep").status + == _LocalExecutionStatus.SUCCEEDED.value + ) + assert execution.status == _LocalExecutionStatus.SUCCEEDED.value + + +@patch( + "sagemaker.local.local_session.LocalSagemakerClient.create_training_job", + side_effect=RuntimeError("Dummy RuntimeError"), +) +def test_execute_pipeline_step_create_training_job_fail( + create_training_job, local_sagemaker_session, pipeline_session, training_step +): + pipeline = Pipeline( + name="MyPipelineX-" + training_step.name, + steps=[training_step], + parameters=[INSTANCE_COUNT_PIPELINE_PARAMETER], + sagemaker_session=local_sagemaker_session, + ) + execution = LocalPipelineExecutor( + _LocalPipelineExecution("my-execution-x-" + training_step.name, pipeline), + local_sagemaker_session, + ).execute() + + assert execution.status == _LocalExecutionStatus.FAILED.value + assert execution.pipeline_execution_name == "my-execution-x-" + training_step.name + + step_execution = execution.step_execution + assert step_execution[training_step.name].status == _LocalExecutionStatus.FAILED.value + assert "Dummy RuntimeError" in step_execution[training_step.name].failure_reason + + +@patch( + "sagemaker.local.local_session.LocalSagemakerClient.create_processing_job", + side_effect=RuntimeError("Dummy RuntimeError"), +) +def test_execute_pipeline_step_create_processing_job_fail( + create_processing_job, local_sagemaker_session, pipeline_session, processing_step +): + pipeline = Pipeline( + name="MyPipelineX-" + processing_step.name, + steps=[processing_step], + sagemaker_session=local_sagemaker_session, + ) + execution = LocalPipelineExecutor( + _LocalPipelineExecution("my-execution-x-" + processing_step.name, pipeline), + local_sagemaker_session, + ).execute() + + assert execution.status == _LocalExecutionStatus.FAILED.value + assert execution.pipeline_execution_name == "my-execution-x-" + processing_step.name + + step_execution = execution.step_execution + assert step_execution[processing_step.name].status == _LocalExecutionStatus.FAILED.value + assert "Dummy RuntimeError" in step_execution[processing_step.name].failure_reason + + +@patch( + "sagemaker.local.local_session.LocalSagemakerClient.create_transform_job", + side_effect=RuntimeError("Dummy RuntimeError"), +) +def test_execute_pipeline_step_create_transform_job_fail( + create_transform_job, local_sagemaker_session, pipeline_session, transform_step +): + pipeline = Pipeline( + name="MyPipelineX-" + transform_step.name, + steps=[transform_step], + sagemaker_session=local_sagemaker_session, + ) + execution = LocalPipelineExecutor( + _LocalPipelineExecution("my-execution-x-" + transform_step.name, pipeline), + local_sagemaker_session, + ).execute() + + assert execution.status == _LocalExecutionStatus.FAILED.value + assert execution.pipeline_execution_name == "my-execution-x-" + transform_step.name + + step_execution = execution.step_execution + assert step_execution[transform_step.name].status == _LocalExecutionStatus.FAILED.value + assert "Dummy RuntimeError" in step_execution[transform_step.name].failure_reason diff --git a/tests/unit/sagemaker/workflow/test_pipeline_graph.py b/tests/unit/sagemaker/workflow/test_pipeline_graph.py index 2450adfe8a..c0c27eebee 100644 --- 
a/tests/unit/sagemaker/workflow/test_pipeline_graph.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_graph.py @@ -27,7 +27,11 @@ ) from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.parameters import ParameterInteger, ParameterString -from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep, CustomStepCollection +from tests.unit.sagemaker.workflow.helpers import ( + ordered, + CustomStep, + CustomStepCollection, +) @pytest.fixture @@ -42,6 +46,44 @@ def role_arn(): return "arn:role" +# ┌──►F +# │ +# A──►B──►C──►E──►G──►H──►I +# │ ▲ │ +# └──►D───────┘ └──►J +@pytest.fixture +def pipeline_graph_get_sub_dag(sagemaker_session_mock): + step_a = CustomStep(name="stepA") + step_b = CustomStep(name="stepB", depends_on=[step_a]) + step_c = CustomStep(name="stepC", depends_on=[step_b]) + step_d = CustomStep(name="stepD", depends_on=[step_b]) + step_e = CustomStep(name="stepE", depends_on=[step_c]) + step_f = CustomStep(name="stepF", depends_on=[step_c]) + step_g = CustomStep(name="stepG", depends_on=[step_e, step_d]) + step_h = CustomStep(name="stepH", depends_on=[step_g]) + step_i = CustomStep(name="stepI", depends_on=[step_h]) + step_j = CustomStep(name="stepJ", depends_on=[step_h]) + + pipeline = Pipeline( + name="MyPipeline", + steps=[ + step_a, + step_b, + step_c, + step_d, + step_e, + step_f, + step_g, + step_h, + step_i, + step_j, + ], + sagemaker_session=sagemaker_session_mock, + ) + + return PipelineGraph.from_pipeline(pipeline) + + def test_pipeline_duplicate_step_name(sagemaker_session_mock): step1 = CustomStep(name="foo") step2 = CustomStep(name="foo") @@ -59,7 +101,10 @@ def test_pipeline_duplicate_step_name_in_condition_step(sagemaker_session_mock): custom_step = CustomStep(name="foo") custom_step2 = CustomStep(name="foo") condition_step = ConditionStep( - name="condStep", conditions=[cond], depends_on=[custom_step], if_steps=[custom_step2] + name="condStep", + conditions=[cond], + depends_on=[custom_step], + if_steps=[custom_step2], ) with pytest.raises(ValueError) as error: pipeline = Pipeline( @@ -145,7 +190,12 @@ def test_pipeline_graph_with_condition_step_explicit_dependency(sagemaker_sessio pipeline_graph = PipelineGraph.from_pipeline(pipeline) adjacency_list = pipeline_graph.adjacency_list assert ordered(adjacency_list) == ordered( - {"condStep": ["ElseStep", "IfStep"], "ElseStep": [], "IfStep": [], "TestStep": ["condStep"]} + { + "condStep": ["ElseStep", "IfStep"], + "ElseStep": [], + "IfStep": [], + "TestStep": ["condStep"], + } ) _verify_pipeline_graph_traversal(pipeline_graph) @@ -169,12 +219,19 @@ def test_pipeline_graph_with_condition_step_property_reference_dependency( pipeline_graph = PipelineGraph.from_pipeline(pipeline) adjacency_list = pipeline_graph.adjacency_list assert ordered(adjacency_list) == ordered( - {"condStep": ["ElseStep", "IfStep"], "ElseStep": [], "IfStep": [], "TestStep": ["condStep"]} + { + "condStep": ["ElseStep", "IfStep"], + "ElseStep": [], + "IfStep": [], + "TestStep": ["condStep"], + } ) _verify_pipeline_graph_traversal(pipeline_graph) -def test_pipeline_graph_with_step_collection_explicit_dependency(sagemaker_session_mock): +def test_pipeline_graph_with_step_collection_explicit_dependency( + sagemaker_session_mock, +): custom_step1 = CustomStep(name="TestStep") custom_step_collection = CustomStepCollection( name="TestStepCollection", depends_on=[custom_step1] @@ -231,7 +288,9 @@ def test_pipeline_graph_cyclic(sagemaker_session_mock): step_c = CustomStep(name="stepC", 
depends_on=["stepB"]) pipeline = Pipeline( - name="MyPipeline", steps=[step_a, step_b, step_c], sagemaker_session=sagemaker_session_mock + name="MyPipeline", + steps=[step_a, step_b, step_c], + sagemaker_session=sagemaker_session_mock, ) with pytest.raises(ValueError) as error: @@ -239,6 +298,53 @@ def test_pipeline_graph_cyclic(sagemaker_session_mock): assert "Cycle detected in pipeline step graph." in str(error.value) +@pytest.mark.parametrize( + "step_name, expected_steps", + [ + ( + "stepA", + { + "stepA", + "stepB", + "stepC", + "stepD", + "stepE", + "stepF", + "stepG", + "stepH", + "stepI", + "stepJ", + }, + ), + ( + "stepB", + { + "stepB", + "stepC", + "stepD", + "stepE", + "stepF", + "stepG", + "stepH", + "stepI", + "stepJ", + }, + ), + ("stepC", {"stepC", "stepE", "stepF", "stepG", "stepH", "stepI", "stepJ"}), + ("stepD", {"stepD", "stepG", "stepH", "stepI", "stepJ"}), + ("stepE", {"stepE", "stepG", "stepH", "stepI", "stepJ"}), + ("stepF", {"stepF"}), + ("stepG", {"stepG", "stepH", "stepI", "stepJ"}), + ("stepH", {"stepH", "stepI", "stepJ"}), + ("stepI", {"stepI"}), + ("stepJ", {"stepJ"}), + ], +) +def test_get_steps_in_sub_dag(pipeline_graph_get_sub_dag, step_name, expected_steps): + sub_steps = pipeline_graph_get_sub_dag.get_steps_in_sub_dag(step_name) + assert sub_steps == expected_steps + + def test_condition_comparison(sagemaker_session): param = ParameterInteger(name="MyInt") cond = ConditionEquals(left=param, right=1) From e6bd1fe6ae7f01a2d098682de0a2bc1f1cd38022 Mon Sep 17 00:00:00 2001 From: Namrata Madan Date: Mon, 18 Jul 2022 10:56:29 -0700 Subject: [PATCH 170/526] change: implement local JsonGet function Co-authored-by: Namrata Madan --- src/sagemaker/local/entities.py | 115 ++++-- src/sagemaker/local/local_session.py | 11 +- src/sagemaker/local/pipeline.py | 61 ++- src/sagemaker/local/utils.py | 33 +- src/sagemaker/workflow/steps.py | 55 ++- .../sagemaker/local/test_local_entities.py | 107 +++++ .../sagemaker/local/test_local_pipeline.py | 391 +++++++++++++----- .../sagemaker/local/test_local_session.py | 75 ++++ .../unit/sagemaker/local/test_local_utils.py | 49 +++ tests/unit/sagemaker/workflow/helpers.py | 16 +- .../unit/sagemaker/workflow/test_pipeline.py | 2 +- tests/unit/sagemaker/workflow/test_steps.py | 108 ++++- 12 files changed, 851 insertions(+), 172 deletions(-) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 3d9ddbf77e..cfc0307ff6 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -625,7 +625,7 @@ def describe(self): class _LocalPipeline(object): - """Placeholder docstring""" + """Class representing a local SageMaker Pipeline""" _executions = {} @@ -645,7 +645,7 @@ def __init__( self.last_modified_time = now_time def describe(self): - """Placeholder docstring""" + """Describe Pipeline""" response = { "PipelineArn": self.pipeline.name, "PipelineDefinition": self.pipeline.definition(), @@ -659,7 +659,7 @@ def describe(self): return response def start(self, **kwargs): - """Placeholder docstring""" + """Start a pipeline execution. 
Returns a _LocalPipelineExecution object.""" from sagemaker.local.pipeline import LocalPipelineExecutor execution_id = str(uuid4()) @@ -670,7 +670,7 @@ def start(self, **kwargs): class _LocalPipelineExecution(object): - """Placeholder docstring""" + """Class representing a local SageMaker pipeline execution.""" def __init__( self, @@ -693,7 +693,7 @@ def __init__( self.blockout_steps = {} def describe(self): - """Placeholder docstring""" + """Describe Pipeline Execution.""" response = { "CreationTime": self.creation_time, "LastModifiedTime": self.creation_time, @@ -708,8 +708,14 @@ def describe(self): return filtered_response def list_steps(self): - """Placeholder docstring""" - # TODO + """List pipeline execution steps.""" + return { + "PipelineExecutionSteps": [ + step.to_list_steps_response() + for step in self.step_execution.values() + if step.status is not None + ] + } def update_execution_success(self): """Mark execution as succeeded.""" @@ -730,8 +736,8 @@ def update_step_failure(self, step_name, failure_message): self.step_execution.get(step_name).update_step_failure(failure_message) def mark_step_executing(self, step_name): - """Update step's status to EXECUTING""" - self.step_execution.get(step_name).status = _LocalExecutionStatus.EXECUTING.value + """Update pipelines step's status to EXECUTING and start_time to now.""" + self.step_execution.get(step_name).mark_step_executing() def _initialize_step_execution(self, steps): """Initialize step_execution dict.""" @@ -751,7 +757,9 @@ def _initialize_step_execution(self, steps): "Step type {} is not supported in local mode.".format(step.step_type.value) ) raise ClientError(error_msg, "start_pipeline_execution") - self.step_execution[step.name] = _LocalPipelineStepExecution(step.name, step.step_type) + self.step_execution[step.name] = _LocalPipelineExecutionStep( + step.name, step.step_type, step.description, step.display_name + ) if step.step_type == StepTypeEnum.CONDITION: self._initialize_step_execution(step.if_steps + step.else_steps) @@ -790,44 +798,105 @@ def _construct_validation_exception_message(exception_msg): return {"Error": {"Code": "ValidationException", "Message": exception_msg}} -class _LocalPipelineStepExecution(object): - """Placeholder docstring""" +class _LocalPipelineExecutionStep(object): + """Class representing a local pipeline execution step.""" def __init__( self, - step_name, + name, step_type, - last_modified_time=None, + description, + display_name=None, + start_time=None, + end_time=None, status=None, properties=None, failure_reason=None, ): - self.step_name = step_name - self.step_type = step_type - self.status = status or _LocalExecutionStatus.STARTING.value + from sagemaker.workflow.steps import StepTypeEnum + + self.name = name + self.type = step_type + self.description = description + self.display_name = display_name + self.status = status self.failure_reason = failure_reason self.properties = properties or {} - self.creation_time = datetime.datetime.now() - self.last_modified_time = last_modified_time or self.creation_time + self.start_time = start_time + self.end_time = end_time + self._step_type_to_output_format_map = { + StepTypeEnum.TRAINING: self._construct_training_metadata, + StepTypeEnum.PROCESSING: self._construct_processing_metadata, + StepTypeEnum.TRANSFORM: self._construct_transform_metadata, + StepTypeEnum.CONDITION: self._construct_condition_metadata, + StepTypeEnum.FAIL: self._construct_fail_metadata, + } def update_step_properties(self, properties): """Update pipeline step execution 
output properties.""" - logger.info("Successfully completed step %s.", self.step_name) + logger.info("Successfully completed step %s.", self.name) self.properties = deepcopy(properties) self.status = _LocalExecutionStatus.SUCCEEDED.value + self.end_time = datetime.datetime.now() def update_step_failure(self, failure_message): """Update pipeline step execution failure status and message.""" logger.error(failure_message) self.failure_reason = failure_message self.status = _LocalExecutionStatus.FAILED.value - raise StepExecutionException(self.step_name, failure_message) + self.end_time = datetime.datetime.now() + raise StepExecutionException(self.name, failure_message) + + def mark_step_executing(self): + """Update pipelines step's status to EXECUTING and start_time to now""" + self.status = _LocalExecutionStatus.EXECUTING.value + self.start_time = datetime.datetime.now() + + def to_list_steps_response(self): + """Convert to response dict for list_steps calls.""" + response = { + "EndTime": self.end_time, + "FailureReason": self.failure_reason, + "Metadata": self._construct_metadata(), + "StartTime": self.start_time, + "StepDescription": self.description, + "StepDisplayName": self.display_name, + "StepName": self.name, + "StepStatus": self.status, + } + filtered_response = {k: v for k, v in response.items() if v is not None} + return filtered_response + + def _construct_metadata(self): + """Constructs the metadata shape for the list_steps_response.""" + if self.properties: + return self._step_type_to_output_format_map[self.type]() + return None + + def _construct_training_metadata(self): + """Construct training job metadata response.""" + return {"TrainingJob": {"Arn": self.properties.TrainingJobArn}} + + def _construct_processing_metadata(self): + """Construct processing job metadata response.""" + return {"ProcessingJob": {"Arn": self.properties.ProcessingJobArn}} + + def _construct_transform_metadata(self): + """Construct transform job metadata response.""" + return {"TransformJob": {"Arn": self.properties.TransformJobArn}} + + def _construct_condition_metadata(self): + """Construct condition step metadata response.""" + return {"Condition": {"Outcome": self.properties.Outcome}} + + def _construct_fail_metadata(self): + """Construct fail step metadata response.""" + return {"Fail": {"ErrorMessage": self.properties.ErrorMessage}} class _LocalExecutionStatus(enum.Enum): - """Placeholder docstring""" + """Pipeline execution status.""" - STARTING = "Starting" EXECUTING = "Executing" SUCCEEDED = "Succeeded" FAILED = "Failed" diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index c9f6c910bd..8d48bfb5ba 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -595,7 +595,9 @@ class LocalSession(Session): :class:`~sagemaker.session.Session`. """ - def __init__(self, boto_session=None, s3_endpoint_url=None, disable_local_code=False): + def __init__( + self, boto_session=None, default_bucket=None, s3_endpoint_url=None, disable_local_code=False + ): """Create a Local SageMaker Session. 
Args: @@ -614,7 +616,7 @@ def __init__(self, boto_session=None, s3_endpoint_url=None, disable_local_code=F # discourage external use: self._disable_local_code = disable_local_code - super(LocalSession, self).__init__(boto_session) + super(LocalSession, self).__init__(boto_session=boto_session, default_bucket=default_bucket) if platform.system() == "Windows": logger.warning("Windows Support for Local Mode is Experimental") @@ -718,9 +720,12 @@ def __init__(self, fileUri, content_type=None): class LocalPipelineSession(LocalSession): """Class representing a local session for SageMaker Pipelines executions.""" - def __init__(self, boto_session=None, s3_endpoint_url=None, disable_local_code=False): + def __init__( + self, boto_session=None, default_bucket=None, s3_endpoint_url=None, disable_local_code=False + ): super().__init__( boto_session=boto_session, + default_bucket=default_bucket, s3_endpoint_url=s3_endpoint_url, disable_local_code=disable_local_code, ) diff --git a/src/sagemaker/local/pipeline.py b/src/sagemaker/local/pipeline.py index 1f8c3d86f9..7da6e83165 100644 --- a/src/sagemaker/local/pipeline.py +++ b/src/sagemaker/local/pipeline.py @@ -15,15 +15,17 @@ from abc import ABC, abstractmethod import logging +import json from copy import deepcopy from datetime import datetime from typing import Dict, List -from sagemaker.workflow.conditions import ConditionTypeEnum +from botocore.exceptions import ClientError +from sagemaker.workflow.conditions import ConditionTypeEnum from sagemaker.workflow.steps import StepTypeEnum, Step from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.parameters import Parameter -from sagemaker.workflow.functions import Join, JsonGet +from sagemaker.workflow.functions import Join, JsonGet, PropertyFile from sagemaker.workflow.properties import Properties from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables from sagemaker.workflow.pipeline import PipelineGraph @@ -116,8 +118,7 @@ def evaluate_pipeline_variable(self, pipeline_variable, step_name): elif isinstance(pipeline_variable, ExecutionVariable): value = self._evaluate_execution_variable(pipeline_variable) elif isinstance(pipeline_variable, JsonGet): - # TODO - raise NotImplementedError + value = self._evaluate_json_get_function(pipeline_variable, step_name) else: self.execution.update_step_failure( step_name, f"Unrecognized pipeline variable {pipeline_variable.expr}." 
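Note (not part of the patch): a minimal usage sketch of the JsonGet expression that this executor now resolves locally. The step and file names below are illustrative; the call signatures follow the tests added later in this change.

from sagemaker.workflow.functions import JsonGet
from sagemaker.workflow.properties import PropertyFile

# Property file registered on a processing step's named output (illustrative names).
evaluation_report = PropertyFile(
    name="EvaluationReport", output_name="evaluation", path="evaluation.json"
)

# Locally this is resolved by reading <output S3 uri>/evaluation.json, parsing it as
# JSON, and walking the dot-notation path; paths may mix dots, list indices and quoted
# keys, e.g. "regression_metrics.errors[0]['mse']".
mse = JsonGet(
    step_name="EvalStep",
    property_file=evaluation_report,  # or the property file's name as a string
    json_path="regression_metrics.mse",
)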
@@ -133,7 +134,7 @@ def _evaluate_property_reference(self, pipeline_variable, step_name): referenced_step_name = pipeline_variable.step_name step_properties = self.execution.step_execution.get(referenced_step_name).properties return get_using_dot_notation(step_properties, pipeline_variable.path) - except (KeyError, IndexError): + except (KeyError, IndexError, TypeError): self.execution.update_step_failure(step_name, f"{pipeline_variable.expr} is undefined.") def _evaluate_execution_variable(self, pipeline_variable): @@ -154,6 +155,56 @@ def _evaluate_execution_variable(self, pipeline_variable): return datetime.now() return None + def _evaluate_json_get_function(self, pipeline_variable, step_name): + """Evaluate join function runtime value.""" + property_file_reference = pipeline_variable.property_file + property_file = None + if isinstance(property_file_reference, str): + processing_step = self.pipeline_dag.step_map[pipeline_variable.step_name] + for file in processing_step.property_files: + if file.name == property_file_reference: + property_file = file + break + elif isinstance(property_file_reference, PropertyFile): + property_file = property_file_reference + processing_step_response = self.execution.step_execution.get( + pipeline_variable.step_name + ).properties + if ( + "ProcessingOutputConfig" not in processing_step_response + or "Outputs" not in processing_step_response["ProcessingOutputConfig"] + ): + self.execution.update_step_failure( + step_name, + f"Step '{pipeline_variable.step_name}' does not yet contain processing outputs.", + ) + processing_output_s3_bucket = None + for output in processing_step_response["ProcessingOutputConfig"]["Outputs"]: + if output["OutputName"] == property_file.output_name: + processing_output_s3_bucket = output["S3Output"]["S3Uri"] + break + try: + file_content = self.sagemaker_session.read_s3_file( + processing_output_s3_bucket, property_file.path + ) + file_json = json.loads(file_content) + return get_using_dot_notation(file_json, pipeline_variable.json_path) + except ClientError as e: + self.execution.update_step_failure( + step_name, + f"Received an error while file reading file '{property_file.path}' from S3: " + f"{e.response.get('Code')}: {e.response.get('Message')}", + ) + except json.JSONDecodeError: + self.execution.update_step_failure( + step_name, + f"Contents of property file '{property_file.name}' are not in valid JSON format.", + ) + except (KeyError, IndexError, TypeError): + self.execution.update_step_failure( + step_name, f"Invalid json path '{pipeline_variable.json_path}'" + ) + class _StepExecutor(ABC): """An abstract base class for step executors running steps locally""" diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index 3031a407fe..cd0c45b2ea 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -166,7 +166,7 @@ def get_using_dot_notation(dictionary, keys): Nested object within dictionary as defined by "keys" Raises: - KeyError or IndexError if the provided key does not exist in input dictionary + KeyError/IndexError/TypeError if the provided key does not exist in input dictionary """ if keys is None: return dictionary @@ -175,14 +175,23 @@ def get_using_dot_notation(dictionary, keys): rest = None if len(split_keys) > 1: rest = split_keys[1] - list_accessor = re.search(r"(\w+)\[(\d+)]", key) - if list_accessor: - key = list_accessor.group(1) - list_index = int(list_accessor.group(2)) - return get_using_dot_notation(dictionary[key][list_index], rest) - dict_accessor = 
re.search(r"(\w+)\[['\"](\S+)['\"]]", key) - if dict_accessor: - key = dict_accessor.group(1) - inner_key = dict_accessor.group(2) - return get_using_dot_notation(dictionary[key][inner_key], rest) - return get_using_dot_notation(dictionary[key], rest) + bracket_accessors = re.findall(r"\[(.+?)]", key) + if bracket_accessors: + pre_bracket_key = key.split("[", 1)[0] + inner_dict = dictionary[pre_bracket_key] + else: + inner_dict = dictionary[key] + for bracket_accessor in bracket_accessors: + if ( + bracket_accessor.startswith("'") + and bracket_accessor.endswith("'") + or bracket_accessor.startswith('"') + and bracket_accessor.endswith('"') + ): + # key accessor + inner_key = bracket_accessor[1:-1] + else: + # list accessor + inner_key = int(bracket_accessor) + inner_dict = inner_dict[inner_key] + return get_using_dot_notation(inner_dict, rest) diff --git a/src/sagemaker/workflow/steps.py b/src/sagemaker/workflow/steps.py index e04d51c946..d54320b3cc 100644 --- a/src/sagemaker/workflow/steps.py +++ b/src/sagemaker/workflow/steps.py @@ -48,7 +48,7 @@ Properties, ) from sagemaker.workflow.entities import PipelineVariable -from sagemaker.workflow.functions import Join +from sagemaker.workflow.functions import Join, JsonGet from sagemaker.workflow.retry import RetryPolicy if TYPE_CHECKING: @@ -192,9 +192,8 @@ def _find_dependencies_in_depends_on_list( dependencies.add(self._get_step_name_from_str(step, step_map)) return dependencies - @staticmethod def _find_dependencies_in_step_arguments( - obj: Any, step_map: Dict[str, Union["Step", "StepCollection"]] + self, obj: Any, step_map: Dict[str, Union["Step", "StepCollection"]] ): """Find the step dependencies referenced in the arguments of this step.""" dependencies = set() @@ -202,16 +201,56 @@ def _find_dependencies_in_step_arguments( for value in obj.values(): if isinstance(value, (PipelineVariable, Condition)): for referenced_step in value._referenced_steps: - dependencies.add(Step._get_step_name_from_str(referenced_step, step_map)) - dependencies.update(Step._find_dependencies_in_step_arguments(value, step_map)) + dependencies.add(self._get_step_name_from_str(referenced_step, step_map)) + if isinstance(value, JsonGet): + self._validate_json_get_function(value, step_map) + dependencies.update(self._find_dependencies_in_step_arguments(value, step_map)) elif isinstance(obj, list): for item in obj: if isinstance(item, (PipelineVariable, Condition)): for referenced_step in item._referenced_steps: - dependencies.add(Step._get_step_name_from_str(referenced_step, step_map)) - dependencies.update(Step._find_dependencies_in_step_arguments(item, step_map)) + dependencies.add(self._get_step_name_from_str(referenced_step, step_map)) + if isinstance(item, JsonGet): + self._validate_json_get_function(item, step_map) + dependencies.update(self._find_dependencies_in_step_arguments(item, step_map)) return dependencies + def _validate_json_get_function( + self, json_get: JsonGet, step_map: Dict[str, Union["Step", "StepCollection"]] + ): + """Validate the JsonGet function inputs.""" + property_file_reference = json_get.property_file + processing_step = step_map[json_get.step_name] + property_file = None + if isinstance(property_file_reference, str): + if not isinstance(processing_step, ProcessingStep): + raise ValueError( + f"Invalid JsonGet function {json_get.expr} in step '{self.name}'. JsonGet " + f"function can only be evaluated on processing step outputs." 
+ ) + for file in processing_step.property_files: + if file.name == property_file_reference: + property_file = file + break + elif isinstance(property_file_reference, PropertyFile): + property_file = property_file_reference + if property_file is None: + raise ValueError( + f"Invalid JsonGet function {json_get.expr} in step '{self.name}'. Property file " + f"reference '{property_file_reference}' is undefined in step " + f"'{processing_step.name}'." + ) + property_file_output = None + if "ProcessingOutputConfig" in processing_step.arguments: + for output in processing_step.arguments["ProcessingOutputConfig"]["Outputs"]: + if output["OutputName"] == property_file.output_name: + property_file_output = output + if property_file_output is None: + raise ValueError( + f"Processing output name '{property_file.output_name}' defined in property file " + f"'{property_file.name}' not found in processing step '{processing_step.name}'." + ) + @staticmethod def _get_step_name_from_str( str_input: str, step_map: Dict[str, Union["Step", "StepCollection"]] @@ -766,7 +805,7 @@ def __init__( self.outputs = outputs self.job_arguments = job_arguments self.code = code - self.property_files = property_files + self.property_files = property_files or [] self.job_name = None self.kms_key = kms_key self.cache_config = cache_config diff --git a/tests/unit/sagemaker/local/test_local_entities.py b/tests/unit/sagemaker/local/test_local_entities.py index 6b62cd786b..4fdc589f8c 100644 --- a/tests/unit/sagemaker/local/test_local_entities.py +++ b/tests/unit/sagemaker/local/test_local_entities.py @@ -17,7 +17,14 @@ from mock import patch, Mock +from botocore.exceptions import ClientError + import sagemaker.local +from sagemaker.model import Model +from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.steps import CreateModelStep +from tests.unit.sagemaker.workflow.helpers import CustomStep @pytest.fixture(scope="session") @@ -188,3 +195,103 @@ def test_start_local_transform_job_from_remote_docker_host( calls = m_perform_request.call_args_list for call, endpoint in zip(calls, endpoints): assert call[0][0] == endpoint + + +@patch("sagemaker.local.pipeline.LocalPipelineExecutor.execute") +def test_start_local_pipeline(mock_local_pipeline_executor, sagemaker_local_session): + parameter = ParameterString("MyStr", default_value="test") + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[CustomStep(name="MyStep", input_data=parameter)], + sagemaker_session=sagemaker_local_session, + ) + local_pipeline = sagemaker.local.entities._LocalPipeline(pipeline) + + describe_pipeline_response = local_pipeline.describe() + assert describe_pipeline_response["PipelineArn"] == "MyPipeline" + assert describe_pipeline_response["CreationTime"] is not None + + mock_executor_return_value = sagemaker.local.entities._LocalPipelineExecution( + "execution-id", pipeline + ) + mock_executor_return_value.step_execution["MyStep"].status = "Executing" + mock_local_pipeline_executor.return_value = mock_executor_return_value + pipeline_execution = local_pipeline.start() + + describe_pipeline_execution_response = pipeline_execution.describe() + assert describe_pipeline_execution_response["PipelineArn"] == "MyPipeline" + assert describe_pipeline_execution_response["PipelineExecutionArn"] == "execution-id" + assert describe_pipeline_execution_response["CreationTime"] is not None + + list_steps_response = pipeline_execution.list_steps() + assert 
list_steps_response["PipelineExecutionSteps"][0]["StepName"] == "MyStep" + assert list_steps_response["PipelineExecutionSteps"][0]["StepStatus"] == "Executing" + + +def test_start_local_pipeline_with_unsupported_step_type(sagemaker_local_session): + step = CreateModelStep( + name="MyRegisterModelStep", + model=Model(image_uri="mock_image_uri"), + ) + pipeline = Pipeline( + name="MyPipeline", + parameters=[], + steps=[step], + sagemaker_session=sagemaker_local_session, + ) + local_pipeline = sagemaker.local.entities._LocalPipeline(pipeline) + + with pytest.raises(ClientError) as e: + local_pipeline.start() + assert f"Step type {step.step_type.value} is not supported in local mode." in str(e.value) + + +def test_start_local_pipeline_with_undefined_parameter(sagemaker_local_session): + parameter = ParameterString("MyStr") + step = CustomStep(name="MyStep", input_data=parameter) + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[step], + sagemaker_session=sagemaker_local_session, + ) + local_pipeline = sagemaker.local.entities._LocalPipeline(pipeline) + with pytest.raises(ClientError) as error: + local_pipeline.start() + assert f"Parameter '{parameter.name}' is undefined." in str(error.value) + + +def test_start_local_pipeline_with_unknown_parameter(sagemaker_local_session): + parameter = ParameterString("MyStr") + step = CustomStep(name="MyStep", input_data=parameter) + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[step], + sagemaker_session=sagemaker_local_session, + ) + local_pipeline = sagemaker.local.entities._LocalPipeline(pipeline) + with pytest.raises(ClientError) as error: + local_pipeline.start( + PipelineParameters={"MyStr": "test-test", "UnknownParameterFoo": "foo"} + ) + assert "Unknown parameter 'UnknownParameterFoo'" in str(error.value) + + +def test_start_local_pipeline_with_wrong_parameter_type(sagemaker_local_session): + parameter = ParameterString("MyStr") + step = CustomStep(name="MyStep", input_data=parameter) + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[step], + sagemaker_session=sagemaker_local_session, + ) + local_pipeline = sagemaker.local.entities._LocalPipeline(pipeline) + with pytest.raises(ClientError) as error: + local_pipeline.start(PipelineParameters={"MyStr": True}) + assert ( + f"Unexpected type for parameter '{parameter.name}'. Expected " + f"{parameter.parameter_type.python_type} but found {type(True)}." in str(error.value) + ) diff --git a/tests/unit/sagemaker/local/test_local_pipeline.py b/tests/unit/sagemaker/local/test_local_pipeline.py index 31a653292b..7ffa17d774 100644 --- a/tests/unit/sagemaker/local/test_local_pipeline.py +++ b/tests/unit/sagemaker/local/test_local_pipeline.py @@ -11,9 +11,12 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. 
from __future__ import absolute_import -from mock import Mock, PropertyMock, patch import pytest +from mock import Mock, patch, PropertyMock + +from botocore.exceptions import ClientError + from sagemaker.estimator import Estimator from sagemaker.inputs import TrainingInput from sagemaker.debugger import ProfilerConfig @@ -31,22 +34,16 @@ ConditionOr, ) from sagemaker.workflow.fail_step import FailStep - -from botocore.exceptions import ClientError - from sagemaker.workflow.parameters import ParameterInteger, ParameterString from sagemaker.workflow.pipeline import Pipeline -from sagemaker.workflow.steps import CreateModelStep, StepTypeEnum -from sagemaker.model import Model from sagemaker.workflow.pipeline_context import PipelineSession from sagemaker.workflow.steps import ( ProcessingStep, TrainingStep, TransformStep, TransformInput, + StepTypeEnum, ) - -from sagemaker.local import LocalSession from sagemaker.local.pipeline import ( _ConditionStepExecutor, _FailStepExecutor, @@ -59,22 +56,43 @@ ) from sagemaker.local.entities import _LocalExecutionStatus, _LocalPipelineExecution from sagemaker.workflow.execution_variables import ExecutionVariables -from sagemaker.workflow.functions import Join +from sagemaker.workflow.functions import Join, JsonGet, PropertyFile +from sagemaker.local.local_session import LocalSession from tests.unit.sagemaker.workflow.helpers import CustomStep STRING_PARAMETER = ParameterString("MyStr", "DefaultParameter") +INSTANCE_COUNT_PIPELINE_PARAMETER = ParameterInteger(name="InstanceCount", default_value=6) INPUT_STEP = CustomStep(name="InputStep") IMAGE_URI = "fakeimage" ROLE = "DummyRole" BUCKET = "my-bucket" REGION = "us-west-2" INSTANCE_TYPE = "ml.m4.xlarge" -INSTANCE_COUNT_PIPELINE_PARAMETER = ParameterInteger(name="InstanceCount", default_value=6) - - -@pytest.fixture() -def local_sagemaker_session(): - return LocalSession() +PROPERTY_FILE_CONTENT = ( + "{" + ' "my-processing-output": {' + ' "nested_object1": {' + ' "metric1": 45.22,' + ' "metric2": 76' + " }," + ' "nested_object2": {' + ' "nested_list": [' + " {" + ' "list_object1": {' + ' "metric1": 55,' + ' "metric2": 66.34' + " }" + " }," + " {" + ' "list_object2": {' + ' "metric1": 33' + " }" + " }" + " ]" + " }" + " }" + "}" +) @pytest.fixture @@ -82,21 +100,6 @@ def role_arn(): return "arn:role" -@pytest.fixture -def boto_session(client): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=REGION) - session_mock.resource.return_value = resource_mock - session_mock.client.return_value = client - - return session_mock - - @pytest.fixture def client(): """Mock client. 
@@ -114,6 +117,21 @@ def client(): return client_mock +@pytest.fixture +def boto_session(client): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value=ROLE) + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name=REGION) + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client + + return session_mock + + @pytest.fixture def pipeline_session(boto_session, client): return PipelineSession( @@ -123,6 +141,11 @@ def pipeline_session(boto_session, client): ) +@pytest.fixture() +def local_sagemaker_session(boto_session): + return LocalSession(boto_session=boto_session, default_bucket="my-bucket") + + @pytest.fixture def training_step(pipeline_session): estimator = Estimator( @@ -225,54 +248,6 @@ def test_evaluate_parameter(local_sagemaker_session): assert evaluated_args["input_data"] == "test_string" -def test_evaluate_parameter_undefined(local_sagemaker_session, role_arn): - parameter = ParameterString("MyStr") - step = CustomStep(name="MyStep", input_data=parameter) - pipeline = Pipeline( - name="MyPipeline", - parameters=[parameter], - steps=[step], - sagemaker_session=local_sagemaker_session, - ) - with pytest.raises(ClientError) as error: - pipeline.create(role_arn, "test pipeline") - pipeline.start() - assert f"Parameter '{parameter.name}' is undefined." in str(error.value) - - -def test_evaluate_parameter_unknown(local_sagemaker_session, role_arn): - parameter = ParameterString("MyStr") - step = CustomStep(name="MyStep", input_data=parameter) - pipeline = Pipeline( - name="MyPipeline", - parameters=[parameter], - steps=[step], - sagemaker_session=local_sagemaker_session, - ) - with pytest.raises(ClientError) as error: - pipeline.create(role_arn, "test pipeline") - pipeline.start({"MyStr": "test-test", "UnknownParameterFoo": "foo"}) - assert "Unknown parameter 'UnknownParameterFoo'" in str(error.value) - - -def test_evaluate_parameter_wrong_type(local_sagemaker_session, role_arn): - parameter = ParameterString("MyStr") - step = CustomStep(name="MyStep", input_data=parameter) - pipeline = Pipeline( - name="MyPipeline", - parameters=[parameter], - steps=[step], - sagemaker_session=local_sagemaker_session, - ) - with pytest.raises(ClientError) as error: - pipeline.create(role_arn, "test pipeline") - pipeline.start({"MyStr": True}) - assert ( - f"Unexpected type for parameter '{parameter.name}'. Expected " - f"{parameter.parameter_type.python_type} but found {type(True)}." 
in str(error.value) - ) - - @pytest.mark.parametrize( "property_reference, expected", [ @@ -365,22 +340,250 @@ def test_evaluate_join_function(local_sagemaker_session, join_value, expected): assert evaluated_args["input_data"] == expected -def test_execute_unsupported_step_type(role_arn, local_sagemaker_session): - step = CreateModelStep( - name="MyRegisterModelStep", - model=Model(image_uri="mock_image_uri"), +@pytest.mark.parametrize( + "json_path_value, expected", + [ + ("my-processing-output.nested_object1.metric1", 45.22), + ("my-processing-output.nested_object1['metric2']", 76), + ("my-processing-output.nested_object2.nested_list[0].list_object1.metric1", 55), + ("my-processing-output.nested_object2.nested_list[0].list_object1['metric2']", 66.34), + ("my-processing-output.nested_object2.nested_list[1].list_object2.metric1", 33), + ], +) +@patch("sagemaker.session.Session.read_s3_file", return_value=PROPERTY_FILE_CONTENT) +def test_evaluate_json_get_function( + read_s3_file, local_sagemaker_session, json_path_value, expected +): + property_file = PropertyFile( + name="my-property-file", output_name="TestOutputName", path="processing_output.json" + ) + processor = Processor( + image_uri="some_image_uri", + role="DummyRole", + instance_count=1, + instance_type="c4.4xlarge", + sagemaker_session=local_sagemaker_session, + ) + processing_step = ProcessingStep( + name="inputProcessingStep", + processor=processor, + outputs=[ProcessingOutput(output_name="TestOutputName")], + property_files=[property_file], + ) + + step = CustomStep( + name="TestStep", + input_data=JsonGet( + step_name=processing_step.name, property_file=property_file, json_path=json_path_value + ), ) pipeline = Pipeline( name="MyPipeline", parameters=[STRING_PARAMETER], - steps=[step], + steps=[processing_step, step], + sagemaker_session=local_sagemaker_session, + ) + + execution = _LocalPipelineExecution("my-execution", pipeline) + execution.step_execution["inputProcessingStep"].properties = { + "ProcessingOutputConfig": { + "Outputs": [ + { + "OutputName": "TestOutputName", + "S3Output": {"S3Uri": "s3://my-bucket/processing/output"}, + } + ] + } + } + evaluated_args = LocalPipelineExecutor( + execution, local_sagemaker_session + ).evaluate_step_arguments(step) + assert evaluated_args["input_data"] == expected + + +def test_evaluate_json_get_function_processing_output_not_available(local_sagemaker_session): + property_file = PropertyFile( + name="my-property-file", output_name="TestOutputName", path="processing_output.json" + ) + processor = Processor( + image_uri="some_image_uri", + role="DummyRole", + instance_count=1, + instance_type="c4.4xlarge", + sagemaker_session=local_sagemaker_session, + ) + processing_step = ProcessingStep( + name="inputProcessingStep", + processor=processor, + outputs=[ProcessingOutput(output_name="TestOutputName")], + property_files=[property_file], + ) + step = CustomStep( + name="TestStep", + input_data=JsonGet( + step_name=processing_step.name, property_file=property_file, json_path="mse" + ), + ) + pipeline = Pipeline( + name="MyPipeline", + parameters=[STRING_PARAMETER], + steps=[processing_step, step], sagemaker_session=local_sagemaker_session, ) - create_pipeline_response = pipeline.create(role_arn, "test pipeline") - assert create_pipeline_response["PipelineArn"] == "MyPipeline" - with pytest.raises(ClientError) as e: - pipeline.start() - assert f"Step type {step.step_type.value} is not supported in local mode." 
in str(e.value) + execution = _LocalPipelineExecution("my-execution", pipeline) + with pytest.raises(StepExecutionException) as e: + LocalPipelineExecutor(execution, local_sagemaker_session).evaluate_step_arguments(step) + assert f"Step '{processing_step.name}' does not yet contain processing outputs." in str(e.value) + + +@patch( + "sagemaker.session.Session.read_s3_file", + side_effect=ClientError({"Code": "NoSuchKey", "Message": "bad key"}, "GetObject"), +) +def test_evaluate_json_get_function_s3_client_error(read_s3_file, local_sagemaker_session): + property_file = PropertyFile( + name="my-property-file", output_name="TestOutputName", path="processing_output.json" + ) + processor = Processor( + image_uri="some_image_uri", + role="DummyRole", + instance_count=1, + instance_type="c4.4xlarge", + sagemaker_session=local_sagemaker_session, + ) + processing_step = ProcessingStep( + name="inputProcessingStep", + processor=processor, + outputs=[ProcessingOutput(output_name="TestOutputName")], + property_files=[property_file], + ) + step = CustomStep( + name="TestStep", + input_data=JsonGet( + step_name=processing_step.name, property_file=property_file, json_path="mse" + ), + ) + pipeline = Pipeline( + name="MyPipeline", + parameters=[STRING_PARAMETER], + steps=[processing_step, step], + sagemaker_session=local_sagemaker_session, + ) + execution = _LocalPipelineExecution("my-execution", pipeline) + execution.step_execution["inputProcessingStep"].properties = { + "ProcessingOutputConfig": { + "Outputs": [ + { + "OutputName": "TestOutputName", + "S3Output": {"S3Uri": "s3://my-bucket/processing/output"}, + } + ] + } + } + with pytest.raises(StepExecutionException) as e: + LocalPipelineExecutor(execution, local_sagemaker_session).evaluate_step_arguments(step) + assert f"Received an error while file reading file '{property_file.path}' from S3" in str( + e.value + ) + + +@patch("sagemaker.session.Session.read_s3_file", return_value="['invalid_json']") +def test_evaluate_json_get_function_bad_json_in_property_file( + read_s3_file, local_sagemaker_session +): + property_file = PropertyFile( + name="my-property-file", output_name="TestOutputName", path="processing_output.json" + ) + processor = Processor( + image_uri="some_image_uri", + role="DummyRole", + instance_count=1, + instance_type="c4.4xlarge", + sagemaker_session=local_sagemaker_session, + ) + processing_step = ProcessingStep( + name="inputProcessingStep", + processor=processor, + outputs=[ProcessingOutput(output_name="TestOutputName")], + property_files=[property_file], + ) + step = CustomStep( + name="TestStep", + input_data=JsonGet( + step_name=processing_step.name, property_file=property_file, json_path="mse" + ), + ) + pipeline = Pipeline( + name="MyPipeline", + parameters=[STRING_PARAMETER], + steps=[processing_step, step], + sagemaker_session=local_sagemaker_session, + ) + + execution = _LocalPipelineExecution("my-execution", pipeline) + execution.step_execution["inputProcessingStep"].properties = { + "ProcessingOutputConfig": { + "Outputs": [ + { + "OutputName": "TestOutputName", + "S3Output": {"S3Uri": "s3://my-bucket/processing/output"}, + } + ] + } + } + with pytest.raises(StepExecutionException) as e: + LocalPipelineExecutor(execution, local_sagemaker_session).evaluate_step_arguments(step) + assert f"Contents of property file '{property_file.name}' are not in valid JSON format." 
in str( + e.value + ) + + +@patch("sagemaker.session.Session.read_s3_file", return_value=PROPERTY_FILE_CONTENT) +def test_evaluate_json_get_function_invalid_json_path(read_s3_file, local_sagemaker_session): + property_file = PropertyFile( + name="my-property-file", output_name="TestOutputName", path="processing_output.json" + ) + processor = Processor( + image_uri="some_image_uri", + role="DummyRole", + instance_count=1, + instance_type="c4.4xlarge", + sagemaker_session=local_sagemaker_session, + ) + processing_step = ProcessingStep( + name="inputProcessingStep", + processor=processor, + outputs=[ProcessingOutput(output_name="TestOutputName")], + property_files=[property_file], + ) + step = CustomStep( + name="TestStep", + input_data=JsonGet( + step_name=processing_step.name, + property_file=property_file, + json_path="some.json.path[1].does.not['exist']", + ), + ) + pipeline = Pipeline( + name="MyPipeline", + parameters=[STRING_PARAMETER], + steps=[processing_step, step], + sagemaker_session=local_sagemaker_session, + ) + execution = _LocalPipelineExecution("my-execution", pipeline) + execution.step_execution["inputProcessingStep"].properties = { + "ProcessingOutputConfig": { + "Outputs": [ + { + "OutputName": "TestOutputName", + "S3Output": {"S3Uri": "s3://my-bucket/processing/output"}, + } + ] + } + } + with pytest.raises(StepExecutionException) as e: + LocalPipelineExecutor(execution, local_sagemaker_session).evaluate_step_arguments(step) + assert "Invalid json path 'some.json.path[1].does.not['exist']'" in str(e.value) @pytest.mark.parametrize( @@ -644,14 +847,12 @@ def test_execute_pipeline_condition_step_test_conditions( execution.step_execution.get(succeeded_step).status == _LocalExecutionStatus.SUCCEEDED.value ) + assert execution.step_execution.get(succeeded_step).name == succeeded_step assert execution.step_execution.get(succeeded_step).properties != {} assert execution.step_execution.get(succeeded_step).failure_reason is None for executing_step in executing_steps: - assert ( - execution.step_execution.get(executing_step).status - == _LocalExecutionStatus.STARTING.value - ) + assert execution.step_execution.get(executing_step).name == executing_step assert execution.step_execution.get(executing_step).properties == {} assert execution.step_execution.get(executing_step).failure_reason is None @@ -719,7 +920,7 @@ def test_pipeline_execution_condition_step_execution_path( actual_path = [] for step_name, step_execution in execution.step_execution.items(): - if step_execution.status != _LocalExecutionStatus.STARTING.value: + if step_execution.status is not None: actual_path.append(step_name) assert actual_path == expected_path @@ -752,9 +953,7 @@ def test_condition_step_incompatible_types(local_sagemaker_session): + "type [] are not of the same type." 
in execution.failure_reason ) - assert execution.step_execution["stepA"].status == _LocalExecutionStatus.STARTING.value - assert execution.step_execution["stepB"].status == _LocalExecutionStatus.STARTING.value - execution.step_execution["stepCondition"].status == _LocalExecutionStatus.FAILED.value + assert execution.step_execution["stepCondition"].status == _LocalExecutionStatus.FAILED.value @patch("sagemaker.local.local_session._LocalTrainingJob") diff --git a/tests/unit/sagemaker/local/test_local_session.py b/tests/unit/sagemaker/local/test_local_session.py index 4b5801d971..cdae087c00 100644 --- a/tests/unit/sagemaker/local/test_local_session.py +++ b/tests/unit/sagemaker/local/test_local_session.py @@ -20,6 +20,11 @@ from tests.unit import DATA_DIR import sagemaker +from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.pipeline import Pipeline +from tests.unit.sagemaker.workflow.helpers import CustomStep +from sagemaker.local.local_session import LocalSession +from sagemaker.local.entities import _LocalPipelineExecution OK_RESPONSE = urllib3.HTTPResponse() @@ -872,3 +877,73 @@ def test_invoke_local_endpoint_with_remote_docker_host( Body, "local_endpoint" ) m_request.assert_called_with("POST", url, body=Body, preload_content=False, headers={}) + + +def test_create_describe_update_pipeline(): + parameter = ParameterString("MyStr", default_value="test") + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[CustomStep(name="MyStep", input_data=parameter)], + sagemaker_session=LocalSession(), + ) + pipeline.create("dummy-role", "pipeline-description") + + pipeline_describe_response1 = pipeline.describe() + assert pipeline_describe_response1["PipelineArn"] == "MyPipeline" + assert pipeline_describe_response1["PipelineDefinition"] == pipeline.definition() + assert pipeline_describe_response1["PipelineDescription"] == "pipeline-description" + + pipeline.update("dummy-role", "pipeline-description-2") + pipeline_describe_response2 = pipeline.describe() + assert pipeline_describe_response2["PipelineDescription"] == "pipeline-description-2" + assert ( + pipeline_describe_response2["CreationTime"] + != pipeline_describe_response2["LastModifiedTime"] + ) + + +@patch("sagemaker.local.pipeline.LocalPipelineExecutor.execute") +def test_start_pipeline(mock_local_pipeline_executor): + parameter = ParameterString("MyStr", default_value="test") + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[CustomStep(name="MyStep", input_data=parameter)], + sagemaker_session=LocalSession(), + ) + pipeline.create("dummy-role", "pipeline-description") + mock_local_pipeline_executor.return_value = _LocalPipelineExecution("execution-id", pipeline) + + pipeline_execution = pipeline.start() + pipeline_execution_describe_response = pipeline_execution.describe() + assert pipeline_execution_describe_response["PipelineArn"] == "MyPipeline" + assert pipeline_execution_describe_response["PipelineExecutionArn"] == "execution-id" + assert pipeline_execution_describe_response["CreationTime"] is not None + + +def test_update_undefined_pipeline(): + session = LocalSession() + parameter = ParameterString("MyStr", default_value="test") + pipeline = Pipeline( + name="UndefinedPipeline", + parameters=[parameter], + steps=[CustomStep(name="MyStep", input_data=parameter)], + sagemaker_session=session, + ) + + with pytest.raises(ClientError) as e: + session.sagemaker_client.update_pipeline(pipeline, "some_description") + assert "Pipeline {} does not 
exist".format(pipeline.name) in str(e.value) + + +def test_describe_undefined_pipeline(): + with pytest.raises(ClientError) as e: + LocalSession().sagemaker_client.describe_pipeline("UndefinedPipeline") + assert "Pipeline UndefinedPipeline does not exist" in str(e.value) + + +def test_start_undefined_pipeline(): + with pytest.raises(ClientError) as e: + LocalSession().sagemaker_client.start_pipeline_execution("UndefinedPipeline") + assert "Pipeline UndefinedPipeline does not exist" in str(e.value) diff --git a/tests/unit/sagemaker/local/test_local_utils.py b/tests/unit/sagemaker/local/test_local_utils.py index 4bce43704e..b2a94740b4 100644 --- a/tests/unit/sagemaker/local/test_local_utils.py +++ b/tests/unit/sagemaker/local/test_local_utils.py @@ -116,3 +116,52 @@ def test_get_docker_host(m_subprocess): cmd, stdout=m_subprocess.PIPE, stderr=m_subprocess.PIPE ) assert host == endpoint["result"] + + +@pytest.mark.parametrize( + "json_path, expected", + [ + ("Name", "John Doe"), + ("Age", 31), + ("Experiences[0].Company", "Foo Inc."), + ("Experiences[0].Tenure", 5), + ("Experiences[0].Projects[0]['XYZ project']", "Backend Rest Api development"), + ("Experiences[0].Projects[1]['ABC project']", "Data migration"), + ("Experiences[1].Company", "Bar Ltd."), + ("Experiences[1].Tenure", 2), + ], +) +def test_get_using_dot_notation(json_path, expected): + resume = { + "Name": "John Doe", + "Age": 31, + "Experiences": [ + { + "Company": "Foo Inc.", + "Role": "SDE", + "Tenure": 5, + "Projects": [ + {"XYZ project": "Backend Rest Api development"}, + {"ABC project": "Data migration"}, + ], + }, + {"Company": "Bar Ltd.", "Role": "Web developer", "Tenure": 2}, + ], + } + actual = sagemaker.local.utils.get_using_dot_notation(resume, json_path) + assert actual == expected + + +def test_get_using_dot_notation_type_error(): + with pytest.raises(TypeError): + sagemaker.local.utils.get_using_dot_notation({"foo": "bar"}, "foo.test") + + +def test_get_using_dot_notation_key_error(): + with pytest.raises(KeyError): + sagemaker.local.utils.get_using_dot_notation({"foo": {"bar": 1}}, "foo.test") + + +def test_get_using_dot_notation_index_error(): + with pytest.raises(IndexError): + sagemaker.local.utils.get_using_dot_notation({"foo": ["bar"]}, "foo[1]") diff --git a/tests/unit/sagemaker/workflow/helpers.py b/tests/unit/sagemaker/workflow/helpers.py index ebc3bbd959..aa665c6da3 100644 --- a/tests/unit/sagemaker/workflow/helpers.py +++ b/tests/unit/sagemaker/workflow/helpers.py @@ -14,7 +14,7 @@ from __future__ import absolute_import from sagemaker.workflow.properties import Properties -from sagemaker.workflow.steps import Step, StepTypeEnum +from sagemaker.workflow.steps import ConfigurableRetryStep, StepTypeEnum from sagemaker.workflow.step_collections import StepCollection @@ -37,11 +37,19 @@ def ordered(obj): return obj -class CustomStep(Step): - def __init__(self, name, input_data=None, display_name=None, description=None, depends_on=None): +class CustomStep(ConfigurableRetryStep): + def __init__( + self, + name, + input_data=None, + display_name=None, + description=None, + depends_on=None, + retry_policies=None, + ): self.input_data = input_data super(CustomStep, self).__init__( - name, display_name, description, StepTypeEnum.TRAINING, depends_on + name, StepTypeEnum.TRAINING, display_name, description, depends_on, retry_policies ) # for testing property reference, we just use DescribeTrainingJobResponse shape here. 
self._properties = Properties(name, shape_name="DescribeTrainingJobResponse") diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index 459b12e157..327443aee7 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -32,7 +32,7 @@ ) from sagemaker.workflow.step_collections import StepCollection from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep -from sagemaker.local import LocalSession +from sagemaker.local.local_session import LocalSession @pytest.fixture diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 1a61d2088b..cd8b7522d1 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -45,7 +45,7 @@ ) from sagemaker.network import NetworkConfig from sagemaker.transformer import Transformer -from sagemaker.workflow.functions import Join +from sagemaker.workflow.functions import Join, JsonGet from sagemaker.workflow.pipeline import Pipeline, PipelineGraph from sagemaker.workflow.properties import Properties, PropertyFile from sagemaker.workflow.parameters import ParameterString, ParameterInteger, ParameterBoolean @@ -57,20 +57,19 @@ ) from sagemaker.workflow.steps import ( ProcessingStep, - ConfigurableRetryStep, - StepTypeEnum, TrainingStep, TuningStep, TransformStep, CreateModelStep, CacheConfig, ) +from sagemaker.workflow.pipeline_context import _JobStepArguments from sagemaker.pipeline import PipelineModel from sagemaker.sparkml import SparkMLModel from sagemaker.predictor import Predictor from sagemaker.model import FrameworkModel from tests.unit import DATA_DIR -from tests.unit.sagemaker.workflow.helpers import ordered +from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep DUMMY_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") @@ -81,22 +80,6 @@ MODEL_NAME = "gisele" -class CustomStep(ConfigurableRetryStep): - def __init__(self, name, display_name=None, description=None, retry_policies=None): - super(CustomStep, self).__init__( - name, StepTypeEnum.TRAINING, display_name, description, None, retry_policies - ) - self._properties = Properties(name) - - @property - def arguments(self): - return dict() - - @property - def properties(self): - return self._properties - - class DummyFrameworkModel(FrameworkModel): def __init__(self, sagemaker_session, **kwargs): super(DummyFrameworkModel, self).__init__( @@ -1466,3 +1449,88 @@ def test_multi_algo_tuning_step(sagemaker_session): ], }, } + + +def test_pipeline_dag_json_get_bad_step_type(sagemaker_session): + training_step = TrainingStep( + name="inputTrainingStep", + step_args=_JobStepArguments(sagemaker_session.train.__name__, {"arg1": "value"}), + ) + json_get_function = JsonGet( + step_name=training_step.name, property_file="my-property-file", json_path="mse" + ) + custom_step = CustomStep(name="TestStep", input_data=json_get_function) + pipeline = Pipeline( + name="MyPipeline", + parameters=[], + steps=[training_step, custom_step], + sagemaker_session=sagemaker_session, + ) + with pytest.raises(ValueError) as e: + PipelineGraph.from_pipeline(pipeline) + assert ( + f"Invalid JsonGet function {json_get_function.expr} in step '{custom_step.name}'. " + f"JsonGet function can only be evaluated on processing step outputs." 
in str(e.value) + ) + + +def test_pipeline_dag_json_get_undefined_property_file(sagemaker_session): + processing_step = ProcessingStep( + name="inputProcessingStep", + step_args=_JobStepArguments(sagemaker_session.process.__name__, {"arg1": "value"}), + ) + + json_get_function = JsonGet( + step_name=processing_step.name, property_file="undefined-property-file", json_path="mse" + ) + custom_step = CustomStep(name="TestStep", input_data=json_get_function) + pipeline = Pipeline( + name="MyPipeline", + parameters=[], + steps=[processing_step, custom_step], + sagemaker_session=sagemaker_session, + ) + with pytest.raises(ValueError) as e: + PipelineGraph.from_pipeline(pipeline) + assert ( + f"Invalid JsonGet function {json_get_function.expr} in step '{custom_step.name}'. Property " + f"file reference '{json_get_function.property_file}' is undefined in step " + f"'{processing_step.name}'." in str(e.value) + ) + + +def test_pipeline_dag_json_get_wrong_processing_output_name(sagemaker_session): + property_file = PropertyFile( + name="my-property-file", output_name="TestOutputName", path="processing_output.json" + ) + processor = Processor( + image_uri=IMAGE_URI, + role=ROLE, + instance_count=1, + instance_type="c4.4xlarge", + sagemaker_session=sagemaker_session, + ) + processing_step = ProcessingStep( + name="inputProcessingStep", + processor=processor, + outputs=[ProcessingOutput(output_name="SomeOtherTestOutputName")], + property_files=[property_file], + ) + + json_get_function = JsonGet( + step_name=processing_step.name, property_file=property_file.name, json_path="mse" + ) + custom_step = CustomStep(name="TestStep", input_data=json_get_function) + pipeline = Pipeline( + name="MyPipeline", + parameters=[], + steps=[processing_step, custom_step], + sagemaker_session=sagemaker_session, + ) + with pytest.raises(ValueError) as e: + PipelineGraph.from_pipeline(pipeline) + assert ( + f"Processing output name '{property_file.output_name}' defined in property file " + f"'{property_file.name}' not found in processing step '{processing_step.name}'." 
+ in str(e.value) + ) From e0b1e3f839cbc67736251f1a440b72831758fbdf Mon Sep 17 00:00:00 2001 From: Namrata Madan Date: Tue, 26 Jul 2022 15:15:59 -0700 Subject: [PATCH 171/526] change: add local mode integ tests Co-authored-by: Namrata Madan --- src/sagemaker/local/__init__.py | 1 - src/sagemaker/local/entities.py | 59 ++-- src/sagemaker/local/local_session.py | 18 +- src/sagemaker/local/pipeline.py | 92 +++--- src/sagemaker/local/utils.py | 57 ++-- src/sagemaker/workflow/pipeline.py | 3 +- src/sagemaker/workflow/pipeline_context.py | 18 +- tests/conftest.py | 7 +- tests/data/mxnet_mnist/code/evaluation.py | 13 + tests/integ/test_local_mode.py | 267 +++++++++++++++++- .../sagemaker/local/test_local_entities.py | 8 +- .../sagemaker/local/test_local_pipeline.py | 26 +- .../unit/sagemaker/local/test_local_utils.py | 6 +- 13 files changed, 452 insertions(+), 123 deletions(-) create mode 100644 tests/data/mxnet_mnist/code/evaluation.py diff --git a/src/sagemaker/local/__init__.py b/src/sagemaker/local/__init__.py index 7bb8cf224c..1cd1b222e3 100644 --- a/src/sagemaker/local/__init__.py +++ b/src/sagemaker/local/__init__.py @@ -18,5 +18,4 @@ LocalSagemakerClient, LocalSagemakerRuntimeClient, LocalSession, - LocalPipelineSession, ) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index cfc0307ff6..3ee0c41e28 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -200,6 +200,7 @@ def __init__(self, container): self.start_time = None self.end_time = None self.environment = None + self.training_job_name = "" def start(self, input_data_config, output_data_config, hyperparameters, environment, job_name): """Starts a local training job. @@ -244,10 +245,13 @@ def start(self, input_data_config, output_data_config, hyperparameters, environm ) self.end_time = datetime.datetime.now() self.state = self._COMPLETED + self.training_job_name = job_name def describe(self): """Placeholder docstring""" response = { + "TrainingJobName": self.training_job_name, + "TrainingJobArn": _UNUSED_ARN, "ResourceConfig": {"InstanceCount": self.container.instance_count}, "TrainingJobStatus": self.state, "TrainingStartTime": self.start_time, @@ -640,9 +644,8 @@ def __init__( self.local_session = local_session or LocalSession() self.pipeline = pipeline self.pipeline_description = pipeline_description - now_time = datetime.datetime.now() - self.creation_time = now_time - self.last_modified_time = now_time + self.creation_time = datetime.datetime.now().timestamp() + self.last_modified_time = self.creation_time def describe(self): """Describe Pipeline""" @@ -666,6 +669,13 @@ def start(self, **kwargs): execution = _LocalPipelineExecution(execution_id, self.pipeline, **kwargs) self._executions[execution_id] = execution + logger.info( + "Starting execution for pipeline %s. 
Execution ID is %s", + self.pipeline.name, + execution_id, + ) + self.last_modified_time = datetime.datetime.now().timestamp() + return LocalPipelineExecutor(execution, self.local_session).execute() @@ -686,17 +696,18 @@ def __init__( self.pipeline_execution_display_name = PipelineExecutionDisplayName self.status = _LocalExecutionStatus.EXECUTING.value self.failure_reason = None - self.creation_time = datetime.datetime.now() + self.creation_time = datetime.datetime.now().timestamp() + self.last_modified_time = self.creation_time self.step_execution = {} self._initialize_step_execution(self.pipeline.steps) self.pipeline_parameters = self._initialize_and_validate_parameters(PipelineParameters) - self.blockout_steps = {} + self._blocked_steps = {} def describe(self): """Describe Pipeline Execution.""" response = { "CreationTime": self.creation_time, - "LastModifiedTime": self.creation_time, + "LastModifiedTime": self.last_modified_time, "FailureReason": self.failure_reason, "PipelineArn": self.pipeline.name, "PipelineExecutionArn": self.pipeline_execution_name, @@ -720,23 +731,33 @@ def list_steps(self): def update_execution_success(self): """Mark execution as succeeded.""" self.status = _LocalExecutionStatus.SUCCEEDED.value + self.last_modified_time = datetime.datetime.now().timestamp() + logger.info("Pipeline execution %s SUCCEEDED", self.pipeline_execution_name) def update_execution_failure(self, step_name, failure_message): """Mark execution as failed.""" self.status = _LocalExecutionStatus.FAILED.value self.failure_reason = f"Step {step_name} failed with message: {failure_message}" - logger.error("Pipeline execution failed because step %s failed.", step_name) + self.last_modified_time = datetime.datetime.now().timestamp() + logger.info( + "Pipeline execution %s FAILED because step %s failed.", + self.pipeline_execution_name, + step_name, + ) def update_step_properties(self, step_name, step_properties): """Update pipeline step execution output properties.""" self.step_execution.get(step_name).update_step_properties(step_properties) + logger.info("Pipeline step %s SUCCEEDED.", step_name) def update_step_failure(self, step_name, failure_message): """Mark step_name as failed.""" self.step_execution.get(step_name).update_step_failure(failure_message) + logger.info("Pipeline step %s FAILED. 
Failure message is: %s", step_name, failure_message) def mark_step_executing(self, step_name): """Update pipelines step's status to EXECUTING and start_time to now.""" + logger.info("Starting pipeline step: %s", step_name) self.step_execution.get(step_name).mark_step_executing() def _initialize_step_execution(self, steps): @@ -749,6 +770,7 @@ def _initialize_step_execution(self, steps): StepTypeEnum.TRANSFORM, StepTypeEnum.CONDITION, StepTypeEnum.FAIL, + StepTypeEnum.CREATE_MODEL, ) for step in steps: @@ -828,29 +850,28 @@ def __init__( StepTypeEnum.TRAINING: self._construct_training_metadata, StepTypeEnum.PROCESSING: self._construct_processing_metadata, StepTypeEnum.TRANSFORM: self._construct_transform_metadata, + StepTypeEnum.CREATE_MODEL: self._construct_model_metadata, StepTypeEnum.CONDITION: self._construct_condition_metadata, StepTypeEnum.FAIL: self._construct_fail_metadata, } def update_step_properties(self, properties): """Update pipeline step execution output properties.""" - logger.info("Successfully completed step %s.", self.name) self.properties = deepcopy(properties) self.status = _LocalExecutionStatus.SUCCEEDED.value - self.end_time = datetime.datetime.now() + self.end_time = datetime.datetime.now().timestamp() def update_step_failure(self, failure_message): """Update pipeline step execution failure status and message.""" - logger.error(failure_message) self.failure_reason = failure_message self.status = _LocalExecutionStatus.FAILED.value - self.end_time = datetime.datetime.now() + self.end_time = datetime.datetime.now().timestamp() raise StepExecutionException(self.name, failure_message) def mark_step_executing(self): """Update pipelines step's status to EXECUTING and start_time to now""" self.status = _LocalExecutionStatus.EXECUTING.value - self.start_time = datetime.datetime.now() + self.start_time = datetime.datetime.now().timestamp() def to_list_steps_response(self): """Convert to response dict for list_steps calls.""" @@ -875,23 +896,27 @@ def _construct_metadata(self): def _construct_training_metadata(self): """Construct training job metadata response.""" - return {"TrainingJob": {"Arn": self.properties.TrainingJobArn}} + return {"TrainingJob": {"Arn": self.properties["TrainingJobName"]}} def _construct_processing_metadata(self): """Construct processing job metadata response.""" - return {"ProcessingJob": {"Arn": self.properties.ProcessingJobArn}} + return {"ProcessingJob": {"Arn": self.properties["ProcessingJobName"]}} def _construct_transform_metadata(self): """Construct transform job metadata response.""" - return {"TransformJob": {"Arn": self.properties.TransformJobArn}} + return {"TransformJob": {"Arn": self.properties["TransformJobName"]}} + + def _construct_model_metadata(self): + """Construct create model step metadata response.""" + return {"Model": {"Arn": self.properties["ModelName"]}} def _construct_condition_metadata(self): """Construct condition step metadata response.""" - return {"Condition": {"Outcome": self.properties.Outcome}} + return {"Condition": {"Outcome": self.properties["Outcome"]}} def _construct_fail_metadata(self): """Construct fail step metadata response.""" - return {"Fail": {"ErrorMessage": self.properties.ErrorMessage}} + return {"Fail": {"ErrorMessage": self.properties["ErrorMessage"]}} class _LocalExecutionStatus(enum.Enum): diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 8d48bfb5ba..d785abb154 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ 
-448,7 +448,9 @@ def update_pipeline( } raise ClientError(error_response, "update_pipeline") LocalSagemakerClient._pipelines[pipeline.name].pipeline_description = pipeline_description - LocalSagemakerClient._pipelines[pipeline.name].last_modified_time = datetime.now() + LocalSagemakerClient._pipelines[ + pipeline.name + ].last_modified_time = datetime.now().timestamp() return {"PipelineArn": pipeline.name} def describe_pipeline(self, PipelineName): @@ -715,17 +717,3 @@ def __init__(self, fileUri, content_type=None): if content_type is not None: self.config["ContentType"] = content_type - - -class LocalPipelineSession(LocalSession): - """Class representing a local session for SageMaker Pipelines executions.""" - - def __init__( - self, boto_session=None, default_bucket=None, s3_endpoint_url=None, disable_local_code=False - ): - super().__init__( - boto_session=boto_session, - default_bucket=default_bucket, - s3_endpoint_url=s3_endpoint_url, - disable_local_code=disable_local_code, - ) diff --git a/src/sagemaker/local/pipeline.py b/src/sagemaker/local/pipeline.py index 7da6e83165..0c77f2b967 100644 --- a/src/sagemaker/local/pipeline.py +++ b/src/sagemaker/local/pipeline.py @@ -14,7 +14,6 @@ from __future__ import absolute_import from abc import ABC, abstractmethod -import logging import json from copy import deepcopy from datetime import datetime @@ -32,8 +31,8 @@ from sagemaker.local.exceptions import StepExecutionException from sagemaker.local.utils import get_using_dot_notation from sagemaker.utils import unique_name_from_base +from sagemaker.s3 import parse_s3_url, s3_path_join -logger = logging.getLogger(__name__) PRIMITIVES = (str, int, bool, float) BINARY_CONDITION_TYPES = ( @@ -59,14 +58,14 @@ def __init__(self, execution, sagemaker_session): self.execution = execution self.pipeline_dag = PipelineGraph.from_pipeline(self.execution.pipeline) self.local_sagemaker_client = self.sagemaker_session.sagemaker_client - self.blockout_steps = set() + self._blocked_steps = set() self._step_executor_factory = _StepExecutorFactory(self) def execute(self): """Execute a local pipeline.""" try: for step in self.pipeline_dag: - if step.name not in self.blockout_steps: + if step.name not in self._blocked_steps: self._execute_step(step) except StepExecutionException as e: self.execution.update_execution_failure(e.step_name, e.message) @@ -110,7 +109,7 @@ def evaluate_pipeline_variable(self, pipeline_variable, step_name): value = self.execution.pipeline_parameters.get(pipeline_variable.name) elif isinstance(pipeline_variable, Join): evaluated = [ - self.evaluate_pipeline_variable(v, step_name) for v in pipeline_variable.values + str(self.evaluate_pipeline_variable(v, step_name)) for v in pipeline_variable.values ] value = pipeline_variable.on.join(evaluated) elif isinstance(pipeline_variable, Properties): @@ -134,7 +133,7 @@ def _evaluate_property_reference(self, pipeline_variable, step_name): referenced_step_name = pipeline_variable.step_name step_properties = self.execution.step_execution.get(referenced_step_name).properties return get_using_dot_notation(step_properties, pipeline_variable.path) - except (KeyError, IndexError, TypeError): + except ValueError: self.execution.update_step_failure(step_name, f"{pipeline_variable.expr} is undefined.") def _evaluate_execution_variable(self, pipeline_variable): @@ -178,14 +177,13 @@ def _evaluate_json_get_function(self, pipeline_variable, step_name): step_name, f"Step '{pipeline_variable.step_name}' does not yet contain processing outputs.", ) - 
processing_output_s3_bucket = None - for output in processing_step_response["ProcessingOutputConfig"]["Outputs"]: - if output["OutputName"] == property_file.output_name: - processing_output_s3_bucket = output["S3Output"]["S3Uri"] - break + processing_output_s3_bucket = processing_step_response["ProcessingOutputConfig"]["Outputs"][ + property_file.output_name + ]["S3Output"]["S3Uri"] try: + s3_bucket, s3_key_prefix = parse_s3_url(processing_output_s3_bucket) file_content = self.sagemaker_session.read_s3_file( - processing_output_s3_bucket, property_file.path + s3_bucket, s3_path_join(s3_key_prefix, property_file.path) ) file_json = json.loads(file_content) return get_using_dot_notation(file_json, pipeline_variable.json_path) @@ -200,7 +198,7 @@ def _evaluate_json_get_function(self, pipeline_variable, step_name): step_name, f"Contents of property file '{property_file.name}' are not in valid JSON format.", ) - except (KeyError, IndexError, TypeError): + except ValueError: self.execution.update_step_failure( step_name, f"Invalid json path '{pipeline_variable.json_path}'" ) @@ -228,19 +226,18 @@ def _convert_list_to_dict(self, dictionary: dict, path_to_list: str, reducing_ke """ try: - rest = get_using_dot_notation(dictionary, path_to_list) - except (KeyError, IndexError, TypeError): - raise RuntimeError("%s does not exist in %s" % path_to_list, dictionary) - if not isinstance(rest, list): + list_to_convert = get_using_dot_notation(dictionary, path_to_list) + except ValueError: + raise RuntimeError(f"{path_to_list} does not exist in {dictionary}") + if not isinstance(list_to_convert, list): raise RuntimeError( - "%s of type %s is not a list to be converted into a dictionary!" % rest, - type(rest), + f"Element at path {path_to_list} is not a list. Actual type {type(list_to_convert)}" ) converted_map = {} - for element in rest: + for element in list_to_convert: if not isinstance(element, dict): raise RuntimeError( - "Cannot convert element of type %s into dictionary entry" % type(element) + f"Cannot convert element of type {type(element)} into dictionary entry" ) converted_map[element[reducing_key]] = element return converted_map @@ -277,12 +274,19 @@ def execute(self): job_describe_response = ( self.pipline_executor.local_sagemaker_client.describe_processing_job(job_name) ) - job_describe_response["ProcessingOutputConfig"]["Outputs"] = self._convert_list_to_dict( - job_describe_response, "ProcessingOutputConfig.Outputs", "OutputName" - ) - job_describe_response["ProcessingInputs"] = self._convert_list_to_dict( - job_describe_response, "ProcessingInputs", "InputName" - ) + if ( + "ProcessingOutputConfig" in job_describe_response + and "Outputs" in job_describe_response["ProcessingOutputConfig"] + ): + job_describe_response["ProcessingOutputConfig"][ + "Outputs" + ] = self._convert_list_to_dict( + job_describe_response, "ProcessingOutputConfig.Outputs", "OutputName" + ) + if "ProcessingInputs" in job_describe_response: + job_describe_response["ProcessingInputs"] = self._convert_list_to_dict( + job_describe_response, "ProcessingInputs", "InputName" + ) return job_describe_response except Exception as e: # pylint: disable=W0703 @@ -296,13 +300,13 @@ class _ConditionStepExecutor(_StepExecutor): """Executor class to execute ConditionStep locally""" def execute(self): - def _blockout_all_downstream_steps(steps: List[Step]): - step_to_blockout = set() + def _block_all_downstream_steps(steps: List[Step]): + steps_to_block = set() for step in steps: - step_to_blockout.update( + steps_to_block.update( 
self.pipline_executor.pipeline_dag.get_steps_in_sub_dag(step.name) ) - self.pipline_executor.blockout_steps.update(step_to_blockout) + self.pipline_executor._blocked_steps.update(steps_to_block) if_steps = self.step.if_steps else_steps = self.step.else_steps @@ -313,9 +317,9 @@ def _blockout_all_downstream_steps(steps: List[Step]): outcome = self._evaluate_conjunction(step_only_arguments["Conditions"]) if not outcome: - _blockout_all_downstream_steps(if_steps) + _block_all_downstream_steps(if_steps) else: - _blockout_all_downstream_steps(else_steps) + _block_all_downstream_steps(else_steps) return dict(Outcome=outcome) @@ -356,7 +360,7 @@ def _resolve_condition(self, condition: dict) -> bool: elif condition_type == ConditionTypeEnum.IN.value: outcome = self._resolve_in_condition(condition) else: - raise NotImplementedError("Condition of type [%s] is not supported." % condition_type) + raise NotImplementedError(f"Condition of type [{condition_type}] is not supported.") return outcome @@ -396,7 +400,7 @@ def _resolve_binary_condition(self, binary_condition: dict, binary_condition_typ outcome = left_value <= right_value else: raise NotImplementedError( - "Binary condition of type [%s] is not supported" % binary_condition_type + f"Binary condition of type [{binary_condition_type}] is not supported" ) return outcome @@ -470,6 +474,22 @@ def execute(self): ) +class _CreateModelStepExecutor(_StepExecutor): + """Executor class to execute CreateModelStep locally""" + + def execute(self): + model_name = unique_name_from_base(self.step.name) + step_arguments = self.pipline_executor.evaluate_step_arguments(self.step) + try: + self.pipline_executor.local_sagemaker_client.create_model(model_name, **step_arguments) + return self.pipline_executor.local_sagemaker_client.describe_model(model_name) + except Exception as e: # pylint: disable=W0703 + self.pipline_executor.execution.update_step_failure( + self.step.name, + f"Error when executing step {self.step.name} of type {type(self.step)}: {e}", + ) + + class _FailStepExecutor(_StepExecutor): """Executor class to execute FailStep locally""" @@ -502,6 +522,8 @@ def get(self, step: Step) -> _StepExecutor: step_executor = _ProcessingStepExecutor(self.pipeline_executor, step) elif step_type == StepTypeEnum.TRANSFORM: step_executor = _TransformStepExecutor(self.pipeline_executor, step) + elif step_type == StepTypeEnum.CREATE_MODEL: + step_executor = _CreateModelStepExecutor(self.pipeline_executor, step) elif step_type == StepTypeEnum.FAIL: step_executor = _FailStepExecutor(self.pipeline_executor, step) elif step_type == StepTypeEnum.CONDITION: diff --git a/src/sagemaker/local/utils.py b/src/sagemaker/local/utils.py index cd0c45b2ea..686a7a1481 100644 --- a/src/sagemaker/local/utils.py +++ b/src/sagemaker/local/utils.py @@ -166,32 +166,35 @@ def get_using_dot_notation(dictionary, keys): Nested object within dictionary as defined by "keys" Raises: - KeyError/IndexError/TypeError if the provided key does not exist in input dictionary + ValueError if the provided key does not exist in input dictionary """ - if keys is None: - return dictionary - split_keys = keys.split(".", 1) - key = split_keys[0] - rest = None - if len(split_keys) > 1: - rest = split_keys[1] - bracket_accessors = re.findall(r"\[(.+?)]", key) - if bracket_accessors: - pre_bracket_key = key.split("[", 1)[0] - inner_dict = dictionary[pre_bracket_key] - else: - inner_dict = dictionary[key] - for bracket_accessor in bracket_accessors: - if ( - bracket_accessor.startswith("'") - and 
bracket_accessor.endswith("'") - or bracket_accessor.startswith('"') - and bracket_accessor.endswith('"') - ): - # key accessor - inner_key = bracket_accessor[1:-1] + try: + if keys is None: + return dictionary + split_keys = keys.split(".", 1) + key = split_keys[0] + rest = None + if len(split_keys) > 1: + rest = split_keys[1] + bracket_accessors = re.findall(r"\[(.+?)]", key) + if bracket_accessors: + pre_bracket_key = key.split("[", 1)[0] + inner_dict = dictionary[pre_bracket_key] else: - # list accessor - inner_key = int(bracket_accessor) - inner_dict = inner_dict[inner_key] - return get_using_dot_notation(inner_dict, rest) + inner_dict = dictionary[key] + for bracket_accessor in bracket_accessors: + if ( + bracket_accessor.startswith("'") + and bracket_accessor.endswith("'") + or bracket_accessor.startswith('"') + and bracket_accessor.endswith('"') + ): + # key accessor + inner_key = bracket_accessor[1:-1] + else: + # list accessor + inner_key = int(bracket_accessor) + inner_dict = inner_dict[inner_key] + return get_using_dot_notation(inner_dict, rest) + except (KeyError, IndexError, TypeError): + raise ValueError(f"{keys} does not exist in input dictionary.") diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 2c4631e8f8..51848f0386 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -43,7 +43,6 @@ from sagemaker.workflow.step_collections import StepCollection from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.utilities import list_to_request -from sagemaker.workflow.pipeline_context import LocalPipelineSession logger = logging.getLogger(__name__) @@ -164,7 +163,7 @@ def _create_args( # If pipeline definition is large, upload to S3 bucket and # provide PipelineDefinitionS3Location to request instead. 
if len(pipeline_definition.encode("utf-8")) < 1024 * 100 or isinstance( - self.sagemaker_session, (LocalSession, LocalPipelineSession) + self.sagemaker_session, LocalSession ): kwargs["PipelineDefinition"] = pipeline_definition else: diff --git a/src/sagemaker/workflow/pipeline_context.py b/src/sagemaker/workflow/pipeline_context.py index 3a9feb65e4..bcf4cbe2d5 100644 --- a/src/sagemaker/workflow/pipeline_context.py +++ b/src/sagemaker/workflow/pipeline_context.py @@ -19,7 +19,7 @@ from typing import Dict, Optional from sagemaker.session import Session, SessionSettings -from sagemaker.local import LocalPipelineSession +from sagemaker.local import LocalSession class _StepArguments: @@ -152,6 +152,20 @@ def init_model_step_arguments(self, model): self._context = _ModelStepArguments(model) +class LocalPipelineSession(LocalSession, PipelineSession): + """Class representing a local session for SageMaker Pipelines executions.""" + + def __init__( + self, boto_session=None, default_bucket=None, s3_endpoint_url=None, disable_local_code=False + ): + super().__init__( + boto_session=boto_session, + default_bucket=default_bucket, + s3_endpoint_url=s3_endpoint_url, + disable_local_code=disable_local_code, + ) + + def runnable_by_pipeline(run_func): """A convenient Decorator @@ -171,7 +185,7 @@ def runnable_by_pipeline(run_func): @wraps(run_func) def wrapper(*args, **kwargs): self_instance = args[0] - if isinstance(self_instance.sagemaker_session, (PipelineSession, LocalPipelineSession)): + if isinstance(self_instance.sagemaker_session, PipelineSession): run_func_params = inspect.signature(run_func).parameters arg_list = list(args) diff --git a/tests/conftest.py b/tests/conftest.py index 25f594a74b..011937f027 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,7 +24,7 @@ from sagemaker import Session, image_uris, utils from sagemaker.local import LocalSession -from sagemaker.workflow.pipeline_context import PipelineSession +from sagemaker.workflow.pipeline_context import PipelineSession, LocalPipelineSession DEFAULT_REGION = "us-west-2" CUSTOM_BUCKET_NAME_PREFIX = "sagemaker-custom-bucket" @@ -162,6 +162,11 @@ def pipeline_session(boto_session): return PipelineSession(boto_session=boto_session) +@pytest.fixture(scope="session") +def local_pipeline_session(boto_session): + return LocalPipelineSession(boto_session=boto_session) + + @pytest.fixture(scope="module") def custom_bucket_name(boto_session): region = boto_session.region_name diff --git a/tests/data/mxnet_mnist/code/evaluation.py b/tests/data/mxnet_mnist/code/evaluation.py new file mode 100644 index 0000000000..2e6790e433 --- /dev/null +++ b/tests/data/mxnet_mnist/code/evaluation.py @@ -0,0 +1,13 @@ +import json +import pathlib + +if __name__ == "__main__": + # use static value for model evaluation metrics + report_dict = {"metrics": {"f1": {"value": 0.7}, "mse": {"value": 5.8}}} + + output_dir = "/opt/ml/processing/evaluation" + pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True) + + evaluation_path = f"{output_dir}/evaluation.json" + with open(evaluation_path, "w") as f: + f.write(json.dumps(report_dict)) diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index 0c3bf140c3..7a8b23df42 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -28,9 +28,19 @@ from sagemaker import image_uris +from sagemaker.model import Model +from sagemaker.transformer import Transformer +from sagemaker.inputs import CreateModelInput from sagemaker.processing import ProcessingInput, 
ProcessingOutput, ScriptProcessor from sagemaker.sklearn.processing import SKLearnProcessor - +from sagemaker.workflow.pipeline import Pipeline +from sagemaker.workflow.steps import TrainingStep, ProcessingStep, TransformStep, CreateModelStep +from sagemaker.workflow.parameters import ParameterInteger +from sagemaker.workflow.condition_step import ConditionStep +from sagemaker.workflow.fail_step import FailStep +from sagemaker.workflow.conditions import ConditionLessThanOrEqualTo +from sagemaker.workflow.functions import JsonGet, PropertyFile, Join +from sagemaker.workflow.pipeline_context import LocalPipelineSession from sagemaker.local import LocalSession, LocalSagemakerRuntimeClient, LocalSagemakerClient from sagemaker.mxnet import MXNet @@ -59,6 +69,25 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, self.local_mode = True +class LocalPipelineNoS3Session(LocalPipelineSession): + """ + This Session sets local_code: True regardless of any config file settings + """ + + def __init__(self): + super(LocalPipelineSession, self).__init__() + + def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, **kwargs): + self.boto_session = boto3.Session(region_name=DEFAULT_REGION) + if self.config is None: + self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}} + + self._region_name = DEFAULT_REGION + self.sagemaker_client = LocalSagemakerClient(self) + self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config) + self.local_mode = True + + @pytest.fixture(scope="module") def sagemaker_local_session_no_local_code(boto_session): return LocalSession(boto_session=boto_session, disable_local_code=True) @@ -316,7 +345,7 @@ def test_local_transform_mxnet( cpu_instance_type, ): data_path = os.path.join(DATA_DIR, "mxnet_mnist") - script_path = os.path.join(data_path, "mnist.py") + script_path = os.path.join(data_path, "check_env.py") mx = MXNet( entry_point=script_path, @@ -326,6 +355,7 @@ def test_local_transform_mxnet( framework_version=mxnet_inference_latest_version, py_version=mxnet_inference_latest_py_version, sagemaker_session=sagemaker_local_session, + environment={"MYVAR": "HELLO_WORLD"}, ) train_input = mx.sagemaker_session.upload_data( @@ -462,3 +492,236 @@ def test_local_processing_script_processor(sagemaker_local_session, sklearn_imag assert job_description["AppSpecification"]["ImageUri"] == sklearn_image_uri assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"} + + +@pytest.mark.local_mode +def test_local_pipeline_with_processing_step(sklearn_latest_version, local_pipeline_session): + sklearn_processor = SKLearnProcessor( + framework_version=sklearn_latest_version, + role="SageMakerRole", + instance_type="local", + instance_count=1, + command=["python3"], + sagemaker_session=local_pipeline_session, + ) + script_path = os.path.join(DATA_DIR, "dummy_script.py") + input_file_path = os.path.join(DATA_DIR, "dummy_input.txt") + processing_args = sklearn_processor.run( + code=script_path, + inputs=[ProcessingInput(source=input_file_path, destination="/opt/ml/processing/inputs/")], + ) + processing_step = ProcessingStep( + name="sklearn_processor_local_pipeline", step_args=processing_args + ) + pipeline = Pipeline( + name="local_pipeline_processing", + steps=[processing_step], + sagemaker_session=local_pipeline_session, + ) + pipeline.create("SageMakerRole", "pipeline for sdk integ testing") + + with lock.lock(LOCK_PATH): + execution = pipeline.start() + + 
pipeline_execution_describe_result = execution.describe() + assert pipeline_execution_describe_result["PipelineArn"] == "local_pipeline_processing" + assert pipeline_execution_describe_result["PipelineExecutionStatus"] == "Succeeded" + + pipeline_execution_list_steps_result = execution.list_steps() + assert len(pipeline_execution_list_steps_result["PipelineExecutionSteps"]) == 1 + assert ( + pipeline_execution_list_steps_result["PipelineExecutionSteps"][0]["StepName"] + == "sklearn_processor_local_pipeline" + ) + assert ( + pipeline_execution_list_steps_result["PipelineExecutionSteps"][0]["StepStatus"] + == "Succeeded" + ) + + +@pytest.mark.local_mode +def test_local_pipeline_with_training_and_transform_steps( + mxnet_training_latest_version, + mxnet_inference_latest_version, + mxnet_training_latest_py_version, + tmpdir, +): + instance_count = ParameterInteger(name="InstanceCountParam") + session = LocalPipelineNoS3Session() + data_path = os.path.join(DATA_DIR, "mxnet_mnist") + script_path = os.path.join(data_path, "check_env.py") + output_path = "file://%s" % (str(tmpdir)) + + # define Estimator + mx = MXNet( + entry_point=script_path, + role="SageMakerRole", + instance_count=instance_count, + instance_type="local", + framework_version=mxnet_training_latest_version, + py_version=mxnet_training_latest_py_version, + sagemaker_session=session, + output_path=output_path, + environment={"MYVAR": "HELLO_WORLD"}, + ) + + # define training step + train_input = "file://" + os.path.join(data_path, "train") + test_input = "file://" + os.path.join(data_path, "test") + training_args = mx.fit({"train": train_input, "test": test_input}) + training_step = TrainingStep(name="mxnet_mnist_training", step_args=training_args) + + # define model + inference_image_uri = image_uris.retrieve( + framework="mxnet", + region=DEFAULT_REGION, + version=mxnet_inference_latest_version, + instance_type="local", + image_scope="inference", + ) + model = Model( + image_uri=inference_image_uri, + model_data=training_step.properties.ModelArtifacts.S3ModelArtifacts, + sagemaker_session=session, + role="SageMakerRole", + ) + + # define create model step + inputs = CreateModelInput( + instance_type="local", + accelerator_type="local", + ) + create_model_step = CreateModelStep( + name="mxnet_mnist_model", + model=model, + inputs=inputs, + ) + + # define transformer + transformer = Transformer( + model_name=create_model_step.properties.ModelName, + instance_type="local", + instance_count=instance_count, + output_path=output_path, + assemble_with="Line", + max_payload=1, + strategy="SingleRecord", + sagemaker_session=session, + ) + + # define transform step + transform_input = "file://" + os.path.join(data_path, "transform") + transform_args = transformer.transform( + transform_input, content_type="text/csv", split_type="Line" + ) + transform_step = TransformStep(name="mxnet_mnist_transform", step_args=transform_args) + + pipeline = Pipeline( + name="local_pipeline_training_transform", + parameters=[instance_count], + steps=[training_step, create_model_step, transform_step], + sagemaker_session=session, + ) + + pipeline.create("SageMakerRole", "pipeline for sdk integ testing") + + with lock.lock(LOCK_PATH): + execution = pipeline.start(parameters={"InstanceCountParam": 1}) + + assert os.path.exists(os.path.join(str(tmpdir), "model.tar.gz")) + assert os.path.exists(os.path.join(str(tmpdir), "data.csv.out")) + + pipeline_execution_describe_result = execution.describe() + assert pipeline_execution_describe_result["PipelineArn"] == 
"local_pipeline_training_transform" + assert pipeline_execution_describe_result["PipelineExecutionStatus"] == "Succeeded" + + pipeline_execution_list_steps_result = execution.list_steps() + assert len(pipeline_execution_list_steps_result["PipelineExecutionSteps"]) == 3 + + +@pytest.mark.local_mode +def test_local_pipeline_with_eval_cond_fail_steps(sklearn_image_uri, local_pipeline_session): + processor = ScriptProcessor( + image_uri=sklearn_image_uri, + role="SageMakerRole", + instance_count=1, + instance_type="local", + sagemaker_session=local_pipeline_session, + command=["python3"], + ) + + evaluation_report = PropertyFile( + name="EvaluationReport", output_name="evaluation", path="evaluation.json" + ) + + base_dir = os.path.join(DATA_DIR, "mxnet_mnist") + mx_mnist_model_data = os.path.join(base_dir, "model.tar.gz") + test_input = os.path.join(base_dir, "test") + + eval_step = ProcessingStep( + name="mxnet_mnist_eval", + processor=processor, + inputs=[ + ProcessingInput( + source=mx_mnist_model_data, + destination="/opt/ml/processing/model", + ), + ProcessingInput( + source=test_input, + destination="/opt/ml/processing/test", + ), + ], + outputs=[ + ProcessingOutput(output_name="evaluation", source="/opt/ml/processing/evaluation"), + ], + code=os.path.join(base_dir, "code/evaluation.py"), + property_files=[evaluation_report], + ) + + f1_score = JsonGet( + step_name=eval_step.name, + property_file=evaluation_report, + json_path="metrics.f1.value", + ) + + fail_step = FailStep( + name="mxnet_mnist_fail", error_message=Join(on=":", values=["F1 score too low", f1_score]) + ) + + cond_lte = ConditionLessThanOrEqualTo( + left=f1_score, + right=0.8, + ) + cond_step = ConditionStep( + name="mxnet_mnist_condition", + conditions=[cond_lte], + if_steps=[fail_step], + else_steps=[], + ) + + pipeline = Pipeline( + name="local_pipeline_training_transform", + steps=[eval_step, cond_step], + sagemaker_session=local_pipeline_session, + ) + + pipeline.create("SageMakerRole", "pipeline for sdk integ testing") + + with lock.lock(LOCK_PATH): + execution = pipeline.start() + + pipeline_execution_describe_result = execution.describe() + assert pipeline_execution_describe_result["PipelineArn"] == "local_pipeline_training_transform" + assert pipeline_execution_describe_result["PipelineExecutionStatus"] == "Failed" + + pipeline_execution_list_steps_result = execution.list_steps() + assert len(pipeline_execution_list_steps_result["PipelineExecutionSteps"]) == 3 + for step in pipeline_execution_list_steps_result["PipelineExecutionSteps"]: + if step["StepName"] == "mxnet_mnist_eval": + assert step["StepStatus"] == "Succeeded" + elif step["StepName"] == "mxnet_mnist_condition": + assert step["StepStatus"] == "Succeeded" + assert step["Metadata"]["Condition"]["Outcome"] is True + else: + assert step["StepStatus"] == "Failed" + assert step["FailureReason"] == "F1 score too low:0.7" diff --git a/tests/unit/sagemaker/local/test_local_entities.py b/tests/unit/sagemaker/local/test_local_entities.py index 4fdc589f8c..9ba32e1785 100644 --- a/tests/unit/sagemaker/local/test_local_entities.py +++ b/tests/unit/sagemaker/local/test_local_entities.py @@ -20,10 +20,9 @@ from botocore.exceptions import ClientError import sagemaker.local -from sagemaker.model import Model from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.pipeline import Pipeline -from sagemaker.workflow.steps import CreateModelStep +from sagemaker.workflow.lambda_step import LambdaStep from tests.unit.sagemaker.workflow.helpers import 
CustomStep @@ -230,10 +229,7 @@ def test_start_local_pipeline(mock_local_pipeline_executor, sagemaker_local_sess def test_start_local_pipeline_with_unsupported_step_type(sagemaker_local_session): - step = CreateModelStep( - name="MyRegisterModelStep", - model=Model(image_uri="mock_image_uri"), - ) + step = LambdaStep(name="MyLambdaStep", lambda_func=Mock()) pipeline = Pipeline( name="MyPipeline", parameters=[], diff --git a/tests/unit/sagemaker/local/test_local_pipeline.py b/tests/unit/sagemaker/local/test_local_pipeline.py index 7ffa17d774..5ff050d0db 100644 --- a/tests/unit/sagemaker/local/test_local_pipeline.py +++ b/tests/unit/sagemaker/local/test_local_pipeline.py @@ -51,6 +51,7 @@ _StepExecutorFactory, _TrainingStepExecutor, _TransformStepExecutor, + _CreateModelStepExecutor, LocalPipelineExecutor, StepExecutionException, ) @@ -387,12 +388,12 @@ def test_evaluate_json_get_function( execution = _LocalPipelineExecution("my-execution", pipeline) execution.step_execution["inputProcessingStep"].properties = { "ProcessingOutputConfig": { - "Outputs": [ - { + "Outputs": { + "TestOutputName": { "OutputName": "TestOutputName", "S3Output": {"S3Uri": "s3://my-bucket/processing/output"}, } - ] + } } } evaluated_args = LocalPipelineExecutor( @@ -472,12 +473,12 @@ def test_evaluate_json_get_function_s3_client_error(read_s3_file, local_sagemake execution = _LocalPipelineExecution("my-execution", pipeline) execution.step_execution["inputProcessingStep"].properties = { "ProcessingOutputConfig": { - "Outputs": [ - { + "Outputs": { + "TestOutputName": { "OutputName": "TestOutputName", "S3Output": {"S3Uri": "s3://my-bucket/processing/output"}, } - ] + } } } with pytest.raises(StepExecutionException) as e: @@ -523,12 +524,12 @@ def test_evaluate_json_get_function_bad_json_in_property_file( execution = _LocalPipelineExecution("my-execution", pipeline) execution.step_execution["inputProcessingStep"].properties = { "ProcessingOutputConfig": { - "Outputs": [ - { + "Outputs": { + "TestOutputName": { "OutputName": "TestOutputName", "S3Output": {"S3Uri": "s3://my-bucket/processing/output"}, } - ] + } } } with pytest.raises(StepExecutionException) as e: @@ -573,12 +574,12 @@ def test_evaluate_json_get_function_invalid_json_path(read_s3_file, local_sagema execution = _LocalPipelineExecution("my-execution", pipeline) execution.step_execution["inputProcessingStep"].properties = { "ProcessingOutputConfig": { - "Outputs": [ - { + "Outputs": { + "TestOutputName": { "OutputName": "TestOutputName", "S3Output": {"S3Uri": "s3://my-bucket/processing/output"}, } - ] + } } } with pytest.raises(StepExecutionException) as e: @@ -594,6 +595,7 @@ def test_evaluate_json_get_function_invalid_json_path(read_s3_file, local_sagema (Mock(step_type=StepTypeEnum.TRANSFORM), _TransformStepExecutor), (Mock(step_type=StepTypeEnum.CONDITION), _ConditionStepExecutor), (Mock(step_type=StepTypeEnum.FAIL), _FailStepExecutor), + (Mock(step_type=StepTypeEnum.CREATE_MODEL), _CreateModelStepExecutor), ], ) def test_step_executor_factory(step, step_executor_class): diff --git a/tests/unit/sagemaker/local/test_local_utils.py b/tests/unit/sagemaker/local/test_local_utils.py index b2a94740b4..0129e574ea 100644 --- a/tests/unit/sagemaker/local/test_local_utils.py +++ b/tests/unit/sagemaker/local/test_local_utils.py @@ -153,15 +153,15 @@ def test_get_using_dot_notation(json_path, expected): def test_get_using_dot_notation_type_error(): - with pytest.raises(TypeError): + with pytest.raises(ValueError): sagemaker.local.utils.get_using_dot_notation({"foo": 
"bar"}, "foo.test") def test_get_using_dot_notation_key_error(): - with pytest.raises(KeyError): + with pytest.raises(ValueError): sagemaker.local.utils.get_using_dot_notation({"foo": {"bar": 1}}, "foo.test") def test_get_using_dot_notation_index_error(): - with pytest.raises(IndexError): + with pytest.raises(ValueError): sagemaker.local.utils.get_using_dot_notation({"foo": ["bar"]}, "foo[1]") From 9bd44812efa43f423f6fb47ac10256fe0df7fe53 Mon Sep 17 00:00:00 2001 From: Namrata Madan Date: Mon, 1 Aug 2022 14:17:08 -0700 Subject: [PATCH 172/526] fix: pipelines local mode minor bug fixes Co-authored-by: Namrata Madan --- src/sagemaker/local/entities.py | 49 ++++++++++--------- src/sagemaker/local/local_session.py | 1 + src/sagemaker/local/pipeline.py | 21 +++----- src/sagemaker/workflow/pipeline.py | 32 +++++++----- tests/integ/test_local_mode.py | 21 +++----- .../sagemaker/local/test_local_session.py | 11 ++++- .../sagemaker/workflow/test_pipeline_graph.py | 23 +++++---- 7 files changed, 83 insertions(+), 75 deletions(-) diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 3ee0c41e28..8229a7fbac 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -669,10 +669,8 @@ def start(self, **kwargs): execution = _LocalPipelineExecution(execution_id, self.pipeline, **kwargs) self._executions[execution_id] = execution - logger.info( - "Starting execution for pipeline %s. Execution ID is %s", - self.pipeline.name, - execution_id, + print( + f"Starting execution for pipeline {self.pipeline.name}. Execution ID is {execution_id}" ) self.last_modified_time = datetime.datetime.now().timestamp() @@ -690,6 +688,8 @@ def __init__( PipelineExecutionDescription=None, PipelineExecutionDisplayName=None, ): + from sagemaker.workflow.pipeline import PipelineGraph + self.pipeline = pipeline self.pipeline_execution_name = execution_id self.pipeline_execution_description = PipelineExecutionDescription @@ -699,7 +699,8 @@ def __init__( self.creation_time = datetime.datetime.now().timestamp() self.last_modified_time = self.creation_time self.step_execution = {} - self._initialize_step_execution(self.pipeline.steps) + self.pipeline_dag = PipelineGraph.from_pipeline(self.pipeline) + self._initialize_step_execution(self.pipeline_dag.step_map.values()) self.pipeline_parameters = self._initialize_and_validate_parameters(PipelineParameters) self._blocked_steps = {} @@ -732,37 +733,36 @@ def update_execution_success(self): """Mark execution as succeeded.""" self.status = _LocalExecutionStatus.SUCCEEDED.value self.last_modified_time = datetime.datetime.now().timestamp() - logger.info("Pipeline execution %s SUCCEEDED", self.pipeline_execution_name) + print(f"Pipeline execution {self.pipeline_execution_name} SUCCEEDED") def update_execution_failure(self, step_name, failure_message): """Mark execution as failed.""" self.status = _LocalExecutionStatus.FAILED.value - self.failure_reason = f"Step {step_name} failed with message: {failure_message}" + self.failure_reason = f"Step '{step_name}' failed with message: {failure_message}" self.last_modified_time = datetime.datetime.now().timestamp() - logger.info( - "Pipeline execution %s FAILED because step %s failed.", - self.pipeline_execution_name, - step_name, + print( + f"Pipeline execution {self.pipeline_execution_name} FAILED because step " + f"'{step_name}' failed." 
) def update_step_properties(self, step_name, step_properties): """Update pipeline step execution output properties.""" self.step_execution.get(step_name).update_step_properties(step_properties) - logger.info("Pipeline step %s SUCCEEDED.", step_name) + print(f"Pipeline step '{step_name}' SUCCEEDED.") def update_step_failure(self, step_name, failure_message): """Mark step_name as failed.""" + print(f"Pipeline step '{step_name}' FAILED. Failure message is: {failure_message}") self.step_execution.get(step_name).update_step_failure(failure_message) - logger.info("Pipeline step %s FAILED. Failure message is: %s", step_name, failure_message) def mark_step_executing(self, step_name): """Update pipelines step's status to EXECUTING and start_time to now.""" - logger.info("Starting pipeline step: %s", step_name) + print(f"Starting pipeline step: '{step_name}'") self.step_execution.get(step_name).mark_step_executing() def _initialize_step_execution(self, steps): """Initialize step_execution dict.""" - from sagemaker.workflow.steps import StepTypeEnum + from sagemaker.workflow.steps import StepTypeEnum, Step supported_steps_types = ( StepTypeEnum.TRAINING, @@ -774,16 +774,17 @@ def _initialize_step_execution(self, steps): ) for step in steps: - if step.step_type not in supported_steps_types: - error_msg = self._construct_validation_exception_message( - "Step type {} is not supported in local mode.".format(step.step_type.value) + if isinstance(step, Step): + if step.step_type not in supported_steps_types: + error_msg = self._construct_validation_exception_message( + "Step type {} is not supported in local mode.".format(step.step_type.value) + ) + raise ClientError(error_msg, "start_pipeline_execution") + self.step_execution[step.name] = _LocalPipelineExecutionStep( + step.name, step.step_type, step.description, step.display_name ) - raise ClientError(error_msg, "start_pipeline_execution") - self.step_execution[step.name] = _LocalPipelineExecutionStep( - step.name, step.step_type, step.description, step.display_name - ) - if step.step_type == StepTypeEnum.CONDITION: - self._initialize_step_execution(step.if_steps + step.else_steps) + if step.step_type == StepTypeEnum.CONDITION: + self._initialize_step_execution(step.if_steps + step.else_steps) def _initialize_and_validate_parameters(self, overridden_parameters): """Initialize and validate pipeline parameters.""" diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index d785abb154..2168e90357 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -448,6 +448,7 @@ def update_pipeline( } raise ClientError(error_response, "update_pipeline") LocalSagemakerClient._pipelines[pipeline.name].pipeline_description = pipeline_description + LocalSagemakerClient._pipelines[pipeline.name].pipeline = pipeline LocalSagemakerClient._pipelines[ pipeline.name ].last_modified_time = datetime.now().timestamp() diff --git a/src/sagemaker/local/pipeline.py b/src/sagemaker/local/pipeline.py index 0c77f2b967..c9305b795c 100644 --- a/src/sagemaker/local/pipeline.py +++ b/src/sagemaker/local/pipeline.py @@ -17,11 +17,12 @@ import json from copy import deepcopy from datetime import datetime -from typing import Dict, List +from typing import Dict, List, Union from botocore.exceptions import ClientError from sagemaker.workflow.conditions import ConditionTypeEnum from sagemaker.workflow.steps import StepTypeEnum, Step +from sagemaker.workflow.step_collections import StepCollection from 
sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.parameters import Parameter from sagemaker.workflow.functions import Join, JsonGet, PropertyFile @@ -256,8 +257,7 @@ def execute(self): return self.pipline_executor.local_sagemaker_client.describe_training_job(job_name) except Exception as e: # pylint: disable=W0703 self.pipline_executor.execution.update_step_failure( - self.step.name, - f"Error when executing step {self.step.name} of type {type(self.step)}: {e}", + self.step.name, f"{type(e).__name__}: {str(e)}" ) @@ -291,8 +291,7 @@ def execute(self): except Exception as e: # pylint: disable=W0703 self.pipline_executor.execution.update_step_failure( - self.step.name, - f"Error when executing step {self.step.name} of type {type(self.step)}: {e}", + self.step.name, f"{type(e).__name__}: {str(e)}" ) @@ -300,12 +299,10 @@ class _ConditionStepExecutor(_StepExecutor): """Executor class to execute ConditionStep locally""" def execute(self): - def _block_all_downstream_steps(steps: List[Step]): + def _block_all_downstream_steps(steps: List[Union[Step, StepCollection]]): steps_to_block = set() for step in steps: - steps_to_block.update( - self.pipline_executor.pipeline_dag.get_steps_in_sub_dag(step.name) - ) + steps_to_block.update(self.pipline_executor.pipeline_dag.get_steps_in_sub_dag(step)) self.pipline_executor._blocked_steps.update(steps_to_block) if_steps = self.step.if_steps @@ -469,8 +466,7 @@ def execute(self): return self.pipline_executor.local_sagemaker_client.describe_transform_job(job_name) except Exception as e: # pylint: disable=W0703 self.pipline_executor.execution.update_step_failure( - self.step.name, - f"Error when executing step {self.step.name} of type {type(self.step)}: {e}", + self.step.name, f"{type(e).__name__}: {str(e)}" ) @@ -485,8 +481,7 @@ def execute(self): return self.pipline_executor.local_sagemaker_client.describe_model(model_name) except Exception as e: # pylint: disable=W0703 self.pipline_executor.execution.update_step_failure( - self.step.name, - f"Error when executing step {self.step.name} of type {type(self.step)}: {e}", + self.step.name, f"{type(e).__name__}: {str(e)}" ) diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index 51848f0386..1b3a74c467 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -26,7 +26,6 @@ from sagemaker import s3 from sagemaker._studio import _append_project_tags from sagemaker.session import Session -from sagemaker.local import LocalSession from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep from sagemaker.workflow.entities import ( @@ -162,9 +161,7 @@ def _create_args( # If pipeline definition is large, upload to S3 bucket and # provide PipelineDefinitionS3Location to request instead. - if len(pipeline_definition.encode("utf-8")) < 1024 * 100 or isinstance( - self.sagemaker_session, LocalSession - ): + if len(pipeline_definition.encode("utf-8")) < 1024 * 100: kwargs["PipelineDefinition"] = pipeline_definition else: desired_s3_uri = s3.s3_path_join( @@ -660,19 +657,28 @@ def is_cyclic_helper(current_step): return True return False - def get_steps_in_sub_dag(self, current_step: str, steps: Set[str] = None) -> Set[str]: + def get_steps_in_sub_dag( + self, current_step: Union[Step, StepCollection], sub_dag_steps: Set[str] = None + ) -> Set[str]: """Get names of all steps (including current step) in the sub dag of current step. 
Returns a set of step names in the sub dag. """ - if steps is None: - steps = set() - if current_step not in self.adjacency_list: - raise ValueError("Step: %s does not exist in the pipeline." % current_step) - steps.add(current_step) - for step in self.adjacency_list[current_step]: - self.get_steps_in_sub_dag(step, steps) - return steps + if sub_dag_steps is None: + sub_dag_steps = set() + + if isinstance(current_step, StepCollection): + current_steps = current_step.steps + else: + current_steps = [current_step] + + for step in current_steps: + if step.name not in self.adjacency_list: + raise ValueError("Step: %s does not exist in the pipeline." % step.name) + sub_dag_steps.add(step.name) + for sub_step in self.adjacency_list[step.name]: + self.get_steps_in_sub_dag(self.step_map.get(sub_step), sub_dag_steps) + return sub_dag_steps def __iter__(self): """Perform topological sort traversal of the Pipeline Graph.""" diff --git a/tests/integ/test_local_mode.py b/tests/integ/test_local_mode.py index 7a8b23df42..e772091896 100644 --- a/tests/integ/test_local_mode.py +++ b/tests/integ/test_local_mode.py @@ -30,11 +30,11 @@ from sagemaker.model import Model from sagemaker.transformer import Transformer -from sagemaker.inputs import CreateModelInput from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor from sagemaker.sklearn.processing import SKLearnProcessor from sagemaker.workflow.pipeline import Pipeline -from sagemaker.workflow.steps import TrainingStep, ProcessingStep, TransformStep, CreateModelStep +from sagemaker.workflow.steps import TrainingStep, ProcessingStep, TransformStep +from sagemaker.workflow.model_step import ModelStep from sagemaker.workflow.parameters import ParameterInteger from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.fail_step import FailStep @@ -546,8 +546,8 @@ def test_local_pipeline_with_training_and_transform_steps( mxnet_training_latest_py_version, tmpdir, ): - instance_count = ParameterInteger(name="InstanceCountParam") session = LocalPipelineNoS3Session() + instance_count = ParameterInteger(name="InstanceCountParam") data_path = os.path.join(DATA_DIR, "mxnet_mnist") script_path = os.path.join(data_path, "check_env.py") output_path = "file://%s" % (str(tmpdir)) @@ -587,19 +587,12 @@ def test_local_pipeline_with_training_and_transform_steps( ) # define create model step - inputs = CreateModelInput( - instance_type="local", - accelerator_type="local", - ) - create_model_step = CreateModelStep( - name="mxnet_mnist_model", - model=model, - inputs=inputs, - ) + model_step_args = model.create(instance_type="local", accelerator_type="local") + model_step = ModelStep(name="mxnet_mnist_model", step_args=model_step_args) # define transformer transformer = Transformer( - model_name=create_model_step.properties.ModelName, + model_name=model_step.properties.ModelName, instance_type="local", instance_count=instance_count, output_path=output_path, @@ -619,7 +612,7 @@ def test_local_pipeline_with_training_and_transform_steps( pipeline = Pipeline( name="local_pipeline_training_transform", parameters=[instance_count], - steps=[training_step, create_model_step, transform_step], + steps=[training_step, model_step, transform_step], sagemaker_session=session, ) diff --git a/tests/unit/sagemaker/local/test_local_session.py b/tests/unit/sagemaker/local/test_local_session.py index cdae087c00..728c7e0c06 100644 --- a/tests/unit/sagemaker/local/test_local_session.py +++ b/tests/unit/sagemaker/local/test_local_session.py @@ 
-887,16 +887,25 @@ def test_create_describe_update_pipeline(): steps=[CustomStep(name="MyStep", input_data=parameter)], sagemaker_session=LocalSession(), ) + definition = pipeline.definition() pipeline.create("dummy-role", "pipeline-description") pipeline_describe_response1 = pipeline.describe() assert pipeline_describe_response1["PipelineArn"] == "MyPipeline" - assert pipeline_describe_response1["PipelineDefinition"] == pipeline.definition() + assert pipeline_describe_response1["PipelineDefinition"] == definition assert pipeline_describe_response1["PipelineDescription"] == "pipeline-description" + pipeline = Pipeline( + name="MyPipeline", + parameters=[parameter], + steps=[CustomStep(name="MyStepUpdated", input_data=parameter)], + sagemaker_session=LocalSession(), + ) + updated_definition = pipeline.definition() pipeline.update("dummy-role", "pipeline-description-2") pipeline_describe_response2 = pipeline.describe() assert pipeline_describe_response2["PipelineDescription"] == "pipeline-description-2" + assert pipeline_describe_response2["PipelineDefinition"] == updated_definition assert ( pipeline_describe_response2["CreationTime"] != pipeline_describe_response2["LastModifiedTime"] diff --git a/tests/unit/sagemaker/workflow/test_pipeline_graph.py b/tests/unit/sagemaker/workflow/test_pipeline_graph.py index c0c27eebee..d1a749b783 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline_graph.py +++ b/tests/unit/sagemaker/workflow/test_pipeline_graph.py @@ -61,7 +61,7 @@ def pipeline_graph_get_sub_dag(sagemaker_session_mock): step_f = CustomStep(name="stepF", depends_on=[step_c]) step_g = CustomStep(name="stepG", depends_on=[step_e, step_d]) step_h = CustomStep(name="stepH", depends_on=[step_g]) - step_i = CustomStep(name="stepI", depends_on=[step_h]) + step_i = CustomStepCollection(name="stepI", depends_on=[step_h]) step_j = CustomStep(name="stepJ", depends_on=[step_h]) pipeline = Pipeline( @@ -312,7 +312,8 @@ def test_pipeline_graph_cyclic(sagemaker_session_mock): "stepF", "stepG", "stepH", - "stepI", + "stepI-0", + "stepI-1", "stepJ", }, ), @@ -326,22 +327,24 @@ def test_pipeline_graph_cyclic(sagemaker_session_mock): "stepF", "stepG", "stepH", - "stepI", + "stepI-0", + "stepI-1", "stepJ", }, ), - ("stepC", {"stepC", "stepE", "stepF", "stepG", "stepH", "stepI", "stepJ"}), - ("stepD", {"stepD", "stepG", "stepH", "stepI", "stepJ"}), - ("stepE", {"stepE", "stepG", "stepH", "stepI", "stepJ"}), + ("stepC", {"stepC", "stepE", "stepF", "stepG", "stepH", "stepI-0", "stepI-1", "stepJ"}), + ("stepD", {"stepD", "stepG", "stepH", "stepI-0", "stepI-1", "stepJ"}), + ("stepE", {"stepE", "stepG", "stepH", "stepI-0", "stepI-1", "stepJ"}), ("stepF", {"stepF"}), - ("stepG", {"stepG", "stepH", "stepI", "stepJ"}), - ("stepH", {"stepH", "stepI", "stepJ"}), - ("stepI", {"stepI"}), + ("stepG", {"stepG", "stepH", "stepI-0", "stepI-1", "stepJ"}), + ("stepH", {"stepH", "stepI-0", "stepI-1", "stepJ"}), + ("stepI", {"stepI-0", "stepI-1"}), ("stepJ", {"stepJ"}), ], ) def test_get_steps_in_sub_dag(pipeline_graph_get_sub_dag, step_name, expected_steps): - sub_steps = pipeline_graph_get_sub_dag.get_steps_in_sub_dag(step_name) + step = pipeline_graph_get_sub_dag.step_map.get(step_name) + sub_steps = pipeline_graph_get_sub_dag.get_steps_in_sub_dag(step) assert sub_steps == expected_steps From cf044a96202ec09929fb352ef3d1805958a0793c Mon Sep 17 00:00:00 2001 From: Namrata Madan Date: Mon, 1 Aug 2022 17:46:04 -0700 Subject: [PATCH 173/526] fix: yaml safe_load sagemaker config Co-authored-by: Namrata Madan --- 
src/sagemaker/local/local_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 2168e90357..e791e03b4e 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -667,7 +667,7 @@ def _initialize( logger.error(_module_import_error("yaml", "Local mode", "local")) raise e - self.config = yaml.load(open(sagemaker_config_file, "r")) + self.config = yaml.safe_load(open(sagemaker_config_file, "r")) if self._disable_local_code and "local" in self.config: self.config["local"]["local_code"] = False From 6a27b01937dcef2c08c0e2fa48c3e5f929ebb209 Mon Sep 17 00:00:00 2001 From: stacicho Date: Mon, 8 Aug 2022 10:36:30 -0700 Subject: [PATCH 174/526] documentation: New content for Pipelines local mode --- doc/Makefile | 2 +- ...azon_sagemaker_model_building_pipeline.rst | 48 + doc/doc_utils/pretrainedmodels.rst | 2076 +++++++++++++++++ doc/overview.rst | 82 +- doc/requirements.txt | 1 + .../sagemaker.workflow.pipelines.rst | 3 + src/sagemaker/workflow/pipeline_context.py | 25 +- 7 files changed, 2229 insertions(+), 8 deletions(-) diff --git a/doc/Makefile b/doc/Makefile index af378c2e0f..d64cda3268 100644 --- a/doc/Makefile +++ b/doc/Makefile @@ -3,7 +3,7 @@ # You can set these variables from the command line. SPHINXOPTS = -W -SPHINXBUILD = python -msphinx +SPHINXBUILD = python3 -msphinx SPHINXPROJ = sagemaker SOURCEDIR = . BUILDDIR = _build diff --git a/doc/amazon_sagemaker_model_building_pipeline.rst b/doc/amazon_sagemaker_model_building_pipeline.rst index b85a9d9251..0be93d869f 100644 --- a/doc/amazon_sagemaker_model_building_pipeline.rst +++ b/doc/amazon_sagemaker_model_building_pipeline.rst @@ -97,6 +97,54 @@ When you use :class:`sagemaker.workflow.pipeline_context.PipelineSession` rather .. warning:: A :class:`sagemaker.workflow.pipeline_context.PipelineSession` must be given in order to start the job during pipeline execution time. Otherwise, a training job will get started immediately. +Local Pipeline Session +====================== + +Like Pipeline Session, Local Pipeline Session provides a convenient way to capture input job arguments without starting the job. These input arguments can be provided in the :code:`step_args` parameter to their corresponding `Pipelines step type `__. The difference between :class:`sagemaker.workflow.pipeline_context.PipelineSession` and :class:`sagemaker.workflow.pipeline_context.LocalPipelineSession` is that :class:`sagemaker.workflow.pipeline_context.LocalPipelineSession` is used to run SageMaker pipelines locally (in local mode) whereas using :class:`sagemaker.workflow.pipeline_context.PipelineSession` runs the job on the managed service. + +.. 
code-block:: python + + import sagemaker + from sagemaker.inputs import TrainingInput + from sagemaker.pytorch import PyTorch + from sagemaker.workflow.pipeline import Pipeline + from sagemaker.workflow.pipeline_context import LocalPipelineSession + from sagemaker.workflow.steps import TrainingStep + + local_pipeline_session = LocalPipelineSession() + + pytorch_estimator = PyTorch( + sagemaker_session=local_pipeline_session, + role=sagemaker.get_execution_role(), + instance_type="ml.c5.xlarge", + instance_count=1, + framework_version="1.8.0", + py_version="py36", + entry_point="./entry_point.py", + ) + + step = TrainingStep( + name="MyTrainingStep", + step_args=pytorch_estimator.fit( + inputs=TrainingInput(s3_data="s3://my-bucket/my-data/train"), + ) + ) + + pipeline = Pipeline( + name="MyPipeline", + steps=[step], + sagemaker_session=local_pipeline_session + ) + + pipeline.create( + role_arn=sagemaker.get_execution_role(), + description="local pipeline example" + ) + + # pipeline will execute locally + pipeline.start() + + steps = pipeline.list_steps() + + training_job_name = steps['PipelineExecutionSteps'][0]['Metadata']['TrainingJob']['Arn'] + + step_outputs = local_pipeline_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name) + Pipeline Parameters ====================== diff --git a/doc/doc_utils/pretrainedmodels.rst b/doc/doc_utils/pretrainedmodels.rst index e69de29bb2..3aee847787 100644 --- a/doc/doc_utils/pretrainedmodels.rst +++ b/doc/doc_utils/pretrainedmodels.rst @@ -0,0 +1,2076 @@ +.. |external-link| raw:: html + + + +================================================ +Built-in Algorithms with pre-trained Model Table +================================================ + + The SageMaker Python SDK uses model IDs and model versions to access the necessary + utilities for pre-trained models. This table provides the core material plus + some extra information that can be useful in selecting the correct model ID and + corresponding parameters. + + If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute. + However, we highly suggest pinning an exact model version. + + These models are also available through the + `JumpStart UI in SageMaker Studio `__ + +.. list-table:: Available Models + :widths: 50 20 20 20 30 20 + :header-rows: 1 + :class: datatable + + * - Model ID + - Fine Tunable? 
+ - Latest Version + - Min SDK Version + - Problem Type + - Source + * - autogluon-classification-ensemble + - True + - 1.0.1 + - 2.80.0 + - Classification + - `GluonCV `__ |external-link| + * - autogluon-regression-ensemble + - True + - 1.0.1 + - 2.80.0 + - Regression + - `GluonCV `__ |external-link| + * - catboost-classification-model + - True + - 1.2.4 + - 2.75.0 + - Classification + - `Catboost `__ |external-link| + * - catboost-regression-model + - True + - 1.2.4 + - 2.75.0 + - Regression + - `Catboost `__ |external-link| + * - huggingface-eqa-bert-base-cased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-base-multilingual-cased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-base-multilingual-uncased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-base-uncased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-large-cased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-large-cased-whole-word-masking + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-large-uncased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-bert-large-uncased-whole-word-masking + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-distilbert-base-cased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-distilbert-base-multilingual-cased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-distilbert-base-uncased + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-distilroberta-base + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-roberta-base + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-roberta-base-openai-detector + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-eqa-roberta-large + - True + - 1.0.2 + - 2.75.0 + - Question Answering + - `HuggingFace `__ |external-link| + * - huggingface-ner-distilbert-base-cased-finetuned-conll03-english + - False + - 1.1.0 + - 2.75.0 + - Named Entity Recognition + - `HuggingFace `__ |external-link| + * - huggingface-ner-distilbert-base-uncased-finetuned-conll03-english + - False + - 1.1.0 + - 2.75.0 + - Named Entity Recognition + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-base-cased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-base-multilingual-cased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-base-multilingual-uncased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-base-uncased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-large-cased + - True + - 1.2.3 + - 
2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-large-cased-whole-word-masking + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-large-uncased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-bert-large-uncased-whole-word-masking + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-distilbert-base-cased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-distilbert-base-multilingual-cased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-distilbert-base-uncased + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-distilroberta-base + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-roberta-base + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-roberta-base-openai-detector + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-roberta-large + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-roberta-large-openai-detector + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-xlm-clm-ende-1024 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-xlm-mlm-ende-1024 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-xlm-mlm-enro-1024 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-xlm-mlm-tlm-xnli15-1024 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-spc-xlm-mlm-xnli15-1024 + - True + - 1.2.3 + - 2.75.0 + - Sentence Pair Classification + - `HuggingFace `__ |external-link| + * - huggingface-summarization-bart-large-cnn-samsum + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-bert-small2bert-small-finetuned-cnn-daily-mail-summarization + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-bigbird-pegasus-large-arxiv + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-bigbird-pegasus-large-pubmed + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-distilbart-cnn-12-6 + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-distilbart-cnn-6-6 + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-distilbart-xsum-1-1 + - False + - 1.1.0 + - 2.75.0 + - Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-summarization-distilbart-xsum-12-3 + - False + - 1.1.0 + - 2.75.0 + - 
Text Summarization + - `HuggingFace `__ |external-link| + * - huggingface-textgeneration-distilgpt2 + - False + - 1.1.0 + - 2.75.0 + - Text Generation + - `HuggingFace `__ |external-link| + * - huggingface-textgeneration-gpt2 + - False + - 1.1.0 + - 2.75.0 + - Text Generation + - `HuggingFace `__ |external-link| + * - huggingface-translation-opus-mt-en-es + - False + - 1.1.0 + - 2.75.0 + - Machine Translation + - `HuggingFace `__ |external-link| + * - huggingface-translation-opus-mt-en-vi + - False + - 1.1.0 + - 2.75.0 + - Machine Translation + - `HuggingFace `__ |external-link| + * - huggingface-translation-t5-base + - False + - 1.1.0 + - 2.75.0 + - Machine Translation + - `HuggingFace `__ |external-link| + * - huggingface-translation-t5-large + - False + - 1.1.0 + - 2.75.0 + - Machine Translation + - `HuggingFace `__ |external-link| + * - huggingface-translation-t5-small + - False + - 1.1.0 + - 2.75.0 + - Machine Translation + - `HuggingFace `__ |external-link| + * - lightgbm-classification-model + - True + - 1.2.3 + - 2.75.0 + - Classification + - `LightGBM `__ |external-link| + * - lightgbm-regression-model + - True + - 1.2.3 + - 2.75.0 + - Regression + - `LightGBM `__ |external-link| + * - mxnet-is-mask-rcnn-fpn-resnet101-v1d-coco + - False + - 1.1.0 + - 2.75.0 + - Instance Segmentation + - `GluonCV `__ |external-link| + * - mxnet-is-mask-rcnn-fpn-resnet18-v1b-coco + - False + - 1.1.0 + - 2.75.0 + - Instance Segmentation + - `GluonCV `__ |external-link| + * - mxnet-is-mask-rcnn-fpn-resnet50-v1b-coco + - False + - 1.1.0 + - 2.75.0 + - Instance Segmentation + - `GluonCV `__ |external-link| + * - mxnet-is-mask-rcnn-resnet18-v1b-coco + - False + - 1.1.0 + - 2.75.0 + - Instance Segmentation + - `GluonCV `__ |external-link| + * - mxnet-od-faster-rcnn-fpn-resnet101-v1d-coco + - False + - 1.1.0 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-faster-rcnn-fpn-resnet50-v1b-coco + - False + - 1.1.0 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-faster-rcnn-resnet101-v1d-coco + - False + - 1.1.0 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-faster-rcnn-resnet50-v1b-coco + - False + - 1.1.0 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-faster-rcnn-resnet50-v1b-voc + - False + - 1.1.0 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-300-vgg16-atrous-coco + - True + - 1.2.3 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-300-vgg16-atrous-voc + - True + - 1.2.3 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-mobilenet1-0-coco + - True + - 1.2.3 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-mobilenet1-0-voc + - True + - 1.2.3 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-resnet50-v1-coco + - True + - 1.2.3 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-resnet50-v1-voc + - True + - 1.2.3 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-vgg16-atrous-coco + - True + - 1.2.3 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-ssd-512-vgg16-atrous-voc + - True + - 1.2.3 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-yolo3-darknet53-coco + - False + - 1.1.0 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - 
mxnet-od-yolo3-darknet53-voc + - False + - 1.1.0 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-yolo3-mobilenet1-0-coco + - False + - 1.1.0 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-od-yolo3-mobilenet1-0-voc + - False + - 1.1.0 + - 2.75.0 + - Object Detection + - `GluonCV `__ |external-link| + * - mxnet-semseg-fcn-resnet101-ade + - True + - 1.3.5 + - 2.75.0 + - Semantic Segmentation + - `GluonCV `__ |external-link| + * - mxnet-semseg-fcn-resnet101-coco + - True + - 1.3.5 + - 2.75.0 + - Semantic Segmentation + - `GluonCV `__ |external-link| + * - mxnet-semseg-fcn-resnet101-voc + - True + - 1.3.5 + - 2.75.0 + - Semantic Segmentation + - `GluonCV `__ |external-link| + * - mxnet-semseg-fcn-resnet50-ade + - True + - 1.3.5 + - 2.75.0 + - Semantic Segmentation + - `GluonCV `__ |external-link| + * - mxnet-tcembedding-robertafin-base-uncased + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `GluonCV `__ |external-link| + * - mxnet-tcembedding-robertafin-base-wiki-uncased + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `GluonCV `__ |external-link| + * - mxnet-tcembedding-robertafin-large-uncased + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `GluonCV `__ |external-link| + * - mxnet-tcembedding-robertafin-large-wiki-uncased + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `GluonCV `__ |external-link| + * - pytorch-eqa-bert-base-cased + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-base-multilingual-cased + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-base-multilingual-uncased + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-base-uncased + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-cased + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-cased-whole-word-masking + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-cased-whole-word-masking-finetuned-squad + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-uncased + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-uncased-whole-word-masking + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-bert-large-uncased-whole-word-masking-finetuned-squad + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-distilbert-base-cased + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-distilbert-base-multilingual-cased + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-distilbert-base-uncased + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-distilroberta-base + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-roberta-base + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-roberta-base-openai-detector + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - 
`Pytorch Hub `__ |external-link| + * - pytorch-eqa-roberta-large + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-eqa-roberta-large-openai-detector + - True + - 1.2.0 + - 2.75.0 + - Question Answering + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-alexnet + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-densenet121 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-densenet161 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-densenet169 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-densenet201 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-googlenet + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-mobilenet-v2 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnet101 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnet152 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnet18 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnet34 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnet50 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnext101-32x8d + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-resnext50-32x4d + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-shufflenet-v2-x1-0 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-squeezenet1-0 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-squeezenet1-1 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg11 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg11-bn + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg13 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg13-bn + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg16 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg16-bn + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg19 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-vgg19-bn + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-wide-resnet101-2 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - pytorch-ic-wide-resnet50-2 + - True + - 2.2.3 + - 2.75.0 + - Image Classification + - `Pytorch Hub `__ |external-link| + * - 
pytorch-od-nvidia-ssd + - False + - 1.0.1 + - 2.75.0 + - Object Detection + - `Pytorch Hub `__ |external-link| + * - pytorch-od1-fasterrcnn-mobilenet-v3-large-320-fpn + - False + - 1.0.0 + - 2.75.0 + - Object Detection + - `Pytorch Hub `__ |external-link| + * - pytorch-od1-fasterrcnn-mobilenet-v3-large-fpn + - False + - 1.0.0 + - 2.75.0 + - Object Detection + - `Pytorch Hub `__ |external-link| + * - pytorch-od1-fasterrcnn-resnet50-fpn + - True + - 1.3.2 + - 2.75.0 + - Object Detection + - `Pytorch Hub `__ |external-link| + * - pytorch-tabtransformerclassification-model + - True + - 1.0.1 + - 2.75.0 + - Source + - `Source `__ |external-link| + * - pytorch-tabtransformerregression-model + - True + - 1.0.0 + - 2.75.0 + - Source + - `Source `__ |external-link| + * - sklearn-classification-linear + - True + - 1.1.1 + - 2.75.0 + - Classification + - `ScikitLearn `__ |external-link| + * - sklearn-regression-linear + - True + - 1.1.1 + - 2.75.0 + - Regression + - `ScikitLearn `__ |external-link| + * - tensorflow-ic-bit-m-r101x1-ilsvrc2012-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r101x1-imagenet21k-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r101x3-ilsvrc2012-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r101x3-imagenet21k-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r50x1-ilsvrc2012-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r50x1-imagenet21k-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r50x3-ilsvrc2012-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-m-r50x3-imagenet21k-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-s-r101x1-ilsvrc2012-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-s-r101x3-ilsvrc2012-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-s-r50x1-ilsvrc2012-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-bit-s-r50x3-ilsvrc2012-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b0-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b1-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b2-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b3-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b4-classification-1 + - True + - 2.0.1 + - 2.80.0 + - 
Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b5-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b6-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-b7-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-lite0-classification-2 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-lite1-classification-2 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-lite2-classification-2 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-lite3-classification-2 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-efficientnet-lite4-classification-2 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-inception-resnet-v2-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-inception-v1-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-inception-v2-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-inception-v3-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-025-128-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-025-160-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-025-192-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-025-224-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-050-128-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-050-160-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-050-192-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-050-224-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-075-128-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-075-160-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - 
tensorflow-ic-imagenet-mobilenet-v1-075-192-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-075-224-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-100-128-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-100-160-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-100-192-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v1-100-224-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-035-224-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-050-224-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-075-224-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-100-224-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-130-224-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-mobilenet-v2-140-224-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v1-101-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v1-152-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v1-50-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v2-101-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v2-152-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-imagenet-resnet-v2-50-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-resnet-50-classification-1 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-tf2-preview-inception-v3-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-ic-tf2-preview-mobilenet-v2-classification-4 + - True + - 2.0.1 + - 2.80.0 + - Image Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-m-r101x1-ilsvrc2012-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-m-r101x3-imagenet21k-featurevector-1 + 
- False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-m-r50x1-ilsvrc2012-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-m-r50x3-imagenet21k-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-s-r101x1-ilsvrc2012-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-s-r101x3-ilsvrc2012-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-s-r50x1-ilsvrc2012-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-bit-s-r50x3-ilsvrc2012-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-b0-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-b1-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-b2-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-b3-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-b6-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-lite0-featurevector-2 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-lite1-featurevector-2 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-lite2-featurevector-2 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-lite3-featurevector-2 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-efficientnet-lite4-featurevector-2 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-inception-v1-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-inception-v2-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-inception-v3-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-025-128-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-025-160-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-025-192-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + 
* - tensorflow-icembedding-imagenet-mobilenet-v1-025-224-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-128-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-160-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-192-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-050-224-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-128-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-160-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-192-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-075-224-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-128-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-160-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-192-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v1-100-224-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-035-224-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-050-224-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-075-224-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-100-224-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-130-224-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-mobilenet-v2-140-224-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v1-101-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v1-152-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v1-50-featurevector-4 + - False + - 
2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v2-101-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v2-152-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-imagenet-resnet-v2-50-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-resnet-50-featurevector-1 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-tf2-preview-inception-v3-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-icembedding-tf2-preview-mobilenet-v2-featurevector-4 + - False + - 2.0.0 + - 2.80.0 + - Image Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-hourglass-1024x1024-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-hourglass-1024x1024-kpts-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-hourglass-512x512-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-hourglass-512x512-kpts-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-resnet101v1-fpn-512x512-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-resnet50v1-fpn-512x512-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-resnet50v1-fpn-512x512-kpts-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-resnet50v2-512x512-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-centernet-resnet50v2-512x512-kpts-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d0-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d1-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d2-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d3-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d4-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-efficientdet-d5-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-inception-resnet-v2-1024x1024-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-inception-resnet-v2-640x640-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet101-v1-1024x1024-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - 
`Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet101-v1-640x640-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet101-v1-800x1333-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet152-v1-1024x1024-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet152-v1-640x640-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet152-v1-800x1333-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet50-v1-1024x1024-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet50-v1-640x640-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-faster-rcnn-resnet50-v1-800x1333-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet101-v1-fpn-1024x1024-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet101-v1-fpn-640x640-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet152-v1-fpn-1024x1024-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet152-v1-fpn-640x640-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet50-v1-fpn-1024x1024-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-retinanet-resnet50-v1-fpn-640x640-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-ssd-mobilenet-v1-fpn-640x640-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-ssd-mobilenet-v2-2 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-ssd-mobilenet-v2-fpnlite-320x320-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-od-ssd-mobilenet-v2-fpnlite-640x640-1 + - False + - 2.0.0 + - 2.80.0 + - Object Detection + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-bert-en-cased-L-12-H-768-A-12-2 + - True + - 1.2.2 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-bert-en-uncased-L-12-H-768-A-12-2 + - True + - 1.2.2 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-bert-en-uncased-L-24-H-1024-A-16-2 + - True + - 1.2.2 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-bert-en-wwm-cased-L-24-H-1024-A-16-2 + - True + - 1.2.2 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-bert-en-wwm-uncased-L-24-H-1024-A-16-2 + - True + - 1.2.2 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-bert-multi-cased-L-12-H-768-A-12-2 + - True + - 1.2.2 
+ - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-electra-base-1 + - True + - 1.2.2 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-electra-small-1 + - True + - 1.2.2 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-experts-bert-pubmed-1 + - True + - 1.2.2 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-spc-experts-bert-wiki-books-1 + - True + - 1.2.2 + - 2.75.0 + - Sentence Pair Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-en-cased-L-12-H-768-A-12-2 + - True + - 1.1.2 + - 2.75.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-en-cased-L-24-H-1024-A-16-2 + - True + - 1.1.2 + - 2.75.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2 + - True + - 1.1.2 + - 2.75.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-en-wwm-cased-L-24-H-1024-A-16-2 + - True + - 1.1.2 + - 2.75.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-en-wwm-uncased-L-24-H-1024-A-16-2 + - True + - 1.1.2 + - 2.75.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-bert-multi-cased-L-12-H-768-A-12-2 + - True + - 1.1.2 + - 2.75.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-electra-base-1 + - True + - 1.1.2 + - 2.75.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-electra-small-1 + - True + - 1.1.2 + - 2.75.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-experts-bert-pubmed-1 + - True + - 1.1.2 + - 2.75.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tc-experts-bert-wiki-books-1 + - True + - 1.1.2 + - 2.75.0 + - Text Classification + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-128-A-2-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-256-A-4-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-512-A-8-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-10-H-768-A-12-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-128-A-2-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-256-A-4 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-512-A-8-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-768-A-12-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-12-H-768-A-12-4 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-128-A-2-2 + - False + - 1.1.0 + - 2.75.0 + - 
Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-256-A-4 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-512-A-8-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-2-H-768-A-12-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-128-A-2-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-256-A-4-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-512-A-8-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-4-H-768-A-12-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-128-A-2-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-256-A-4 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-512-A-8-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-6-H-768-A-12-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-8-H-256-A-4-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-8-H-512-A-8-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-en-uncased-L-8-H-768-A-12-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-wiki-books-mnli-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-bert-wiki-books-sst2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-talkheads-ggelu-bert-en-base-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-talkheads-ggelu-bert-en-large-2 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-universal-sentence-encoder-cmlm-en-base-1 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - tensorflow-tcembedding-universal-sentence-encoder-cmlm-en-large-1 + - False + - 1.1.0 + - 2.75.0 + - Text Embedding + - `Tensorflow Hub `__ |external-link| + * - xgboost-classification-model + - True + - 1.2.1 + - 2.75.0 + - Classification + - `XGBoost `__ |external-link| + * - xgboost-regression-model + - True + - 1.2.1 + - 2.75.0 + - Regression + - `XGBoost `__ |external-link| diff --git a/doc/overview.rst b/doc/overview.rst index a6deb7b988..f6ac8a8a2d 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -1370,9 +1370,10 @@ For more details about what can be specified here, see `API docs `__. 
+ +Here is an end-to-end example: + +.. code:: python + + from sagemaker.workflow.pipeline import Pipeline + from sagemaker.workflow.steps import TrainingStep, TransformStep + from sagemaker.workflow.model_step import ModelStep + from sagemaker.workflow.pipeline_context import LocalPipelineSession + from sagemaker.mxnet import MXNet + from sagemaker.model import Model + from sagemaker.transformer import Transformer + + session = LocalPipelineSession() + mxnet_estimator = MXNet('train.py', + role='SageMakerRole', + instance_type='local', + instance_count=1, + framework_version='1.2.1', + sagemaker_session=session) + + train_step_args = mxnet_estimator.fit('file:///tmp/my_training_data') + + # Define training step + train_step = TrainingStep(name='local_mxnet_train', step_args=train_step_args) + + model = Model( + image_uri=inference_image_uri,  # URI of the inference image you provide + model_data=train_step.properties.ModelArtifacts.S3ModelArtifacts, + sagemaker_session=session, + role='SageMakerRole' + ) + + # Define create model step + model_step_args = model.create(instance_type="local", accelerator_type="local") + model_step = ModelStep( + name='local_mxnet_model', + step_args=model_step_args + ) + + transformer = Transformer( + model_name=model_step.properties.ModelName, + instance_type='local', + instance_count=1, + sagemaker_session=session + ) + transform_args = transformer.transform('file:///tmp/my_transform_data') + # Define transform step + transform_step = TransformStep(name='local_mxnet_transform', step_args=transform_args) + + # Define the pipeline + pipeline = Pipeline(name='local_pipeline', + steps=[train_step, model_step, transform_step], + sagemaker_session=session) + + # Create the pipeline + pipeline.upsert(role_arn='SageMakerRole', description='local pipeline example') + + # Start a pipeline execution + execution = pipeline.start() + +.. note:: + Currently Pipelines Local Mode only supports the following step types: Training, Processing, Transform, Model (with Create Model arguments only), Condition, and Fail. + + For detailed examples of running Docker in local mode, see: - `TensorFlow local mode example notebook `__. -- `MXNet local mode CPU example notebook `__. -- `MXNet local mode GPU example notebook `__. +- `MXNet local mode example notebook `__. - `PyTorch local mode example notebook `__. You can also find these notebooks in the **SageMaker Python SDK** section of the **SageMaker Examples** section in a notebook instance. diff --git a/doc/requirements.txt b/doc/requirements.txt index e844537584..21c94775d5 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -2,3 +2,4 @@ sphinx==3.1.2 sphinx-rtd-theme==0.5.0 docutils==0.15.2 packaging==20.9 +jinja2<3.1 diff --git a/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst b/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst index 47e14b5e85..f115b032c4 100644 --- a/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst +++ b/doc/workflows/pipelines/sagemaker.workflow.pipelines.rst @@ -91,6 +91,9 @@ Pipeline Context .. autoclass:: sagemaker.workflow.pipeline_context.PipelineSession :members: +.. 
autoclass:: sagemaker.workflow.pipeline_context.LocalPipelineSession + :members: + Parallelism Configuration ------------------------- diff --git a/src/sagemaker/workflow/pipeline_context.py b/src/sagemaker/workflow/pipeline_context.py index bcf4cbe2d5..6dc3622b3d 100644 --- a/src/sagemaker/workflow/pipeline_context.py +++ b/src/sagemaker/workflow/pipeline_context.py @@ -153,11 +153,34 @@ def init_model_step_arguments(self, model): class LocalPipelineSession(LocalSession, PipelineSession): - """Class representing a local session for SageMaker Pipelines executions.""" + """Managing a session that executes Sagemaker pipelines and jobs locally in a pipeline context. + + This class inherits from the LocalSession and PipelineSession classes. + When running Sagemaker pipelines locally, this class is preferred over LocalSession. + """ def __init__( self, boto_session=None, default_bucket=None, s3_endpoint_url=None, disable_local_code=False ): + """Initialize a ``LocalPipelineSession``. + + Args: + boto_session (boto3.session.Session): The underlying Boto3 session which AWS service + calls are delegated to (default: None). If not provided, one is created with + default AWS configuration chain. + default_bucket (str): The default Amazon S3 bucket to be used by this session. + This will be created the next time an Amazon S3 bucket is needed (by calling + :func:`default_bucket`). + If not provided, a default bucket will be created based on the following format: + "sagemaker-{region}-{aws-account-id}". + Example: "sagemaker-my-custom-bucket". + s3_endpoint_url (str): Override the default endpoint URL for Amazon S3, + if set (default: None). + disable_local_code (bool): Set to True to override the default AWS configuration chain + to disable the `local.local_code` setting, which may not be supported for some SDK + features (default: False). 
+ """ + super().__init__( boto_session=boto_session, default_bucket=default_bucket, From 9d2ab4194ac41d23df0157a51c3b7fe55c5e76a4 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 17 Aug 2022 00:13:18 +0000 Subject: [PATCH 175/526] prepare release v2.104.0 --- CHANGELOG.md | 24 ++++++++++++++++++++++++ VERSION | 2 +- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 878fc2a6e5..acdd80d328 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,29 @@ # Changelog +## v2.104.0 (2022-08-17) + +### Features + + * local mode executor implementation + * Pipelines local mode setup + * Add PT 1.12 support + * added _AnalysisConfigGenerator for clarify + +### Bug Fixes and Other Changes + + * yaml safe_load sagemaker config + * pipelines local mode minor bug fixes + * add local mode integ tests + * implement local JsonGet function + * Add Pipeline annotation in model base class and tensorflow estimator + * Allow users to customize trial component display names for pipeline launched jobs + * Update localmode code to decode urllib response as UTF8 + +### Documentation Changes + + * New content for Pipelines local mode + * Correct documentation error + ## v2.103.0 (2022-08-05) ### Features diff --git a/VERSION b/VERSION index ef1a140e63..5245cdbd1b 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.103.1.dev0 +2.104.0 From 702bef7fa929a665f3bb09b1990f7e8f51483cef Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 17 Aug 2022 00:13:19 +0000 Subject: [PATCH 176/526] update development version to v2.104.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 5245cdbd1b..c574defb95 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.104.0 +2.104.1.dev0 From f92765271936d50f08e3b7c8d01ec01718efa956 Mon Sep 17 00:00:00 2001 From: ragdhall <42956678+ragdhall@users.noreply.github.com> Date: Tue, 16 Aug 2022 20:35:28 -0400 Subject: [PATCH 177/526] documentation: update first-party algorithms and structural updates (#3300) * documentation: update first-party algorithms and structural updates * documentation: minor style changes to pass codebuild * documentation: more minor edits Co-authored-by: Raghav Dhall --- doc/algorithms/index.rst | 19 +++----- doc/algorithms/other/index.rst | 10 +++++ .../sagemaker.amazon.amazon_estimator.rst | 0 doc/algorithms/tabular/autogluon.rst | 28 ++++++++++++ doc/algorithms/tabular/catboost.rst | 37 ++++++++++++++++ .../{ => tabular}/factorization_machines.rst | 2 +- doc/algorithms/tabular/index.rst | 18 ++++++++ doc/algorithms/{ => tabular}/knn.rst | 0 doc/algorithms/tabular/lightgbm.rst | 28 ++++++++++++ .../{ => tabular}/linear_learner.rst | 0 doc/algorithms/{ => tabular}/object2vec.rst | 0 doc/algorithms/tabular/tabtransformer.rst | 28 ++++++++++++ doc/algorithms/tabular/xgboost.rst | 40 +++++++++++++++++ doc/algorithms/text/blazing_text.rst | 27 ++++++++++++ doc/algorithms/text/index.rst | 22 ++++++++++ doc/algorithms/{ => text}/lda.rst | 0 .../text/machine_translation_hugging_face.rst | 10 +++++ .../named_entity_recognition_hugging_face.rst | 10 +++++ doc/algorithms/{ => text}/ntm.rst | 0 .../text/question_answering_pytorch.rst | 9 ++++ ...tence_pair_classification_hugging_face.rst | 9 ++++ ...entence_pair_classification_tensorflow.rst | 9 ++++ doc/algorithms/text/sequence_to_sequence.rst | 14 ++++++ .../text/text_classification_tensorflow.rst | 9 ++++ .../text/text_embedding_tensorflow_mxnet.rst | 9 ++++ .../text/text_generation_hugging_face.rst | 9 ++++ 
.../text/text_summarization_hugging_face.rst | 9 ++++ doc/algorithms/time_series/deep_ar.rst | 11 +++++ doc/algorithms/time_series/index.rst | 10 +++++ doc/algorithms/unsupervised/index.rst | 13 ++++++ .../{ => unsupervised}/ipinsights.rst | 0 doc/algorithms/{ => unsupervised}/kmeans.rst | 0 doc/algorithms/{ => unsupervised}/pca.rst | 0 .../{ => unsupervised}/randomcutforest.rst | 0 .../vision/image_classification_mxnet.rst | 18 ++++++++ .../vision/image_classification_pytorch.rst | 9 ++++ .../image_classification_tensorflow.rst | 9 ++++ .../vision/image_embedding_tensorflow.rst | 9 ++++ doc/algorithms/vision/index.rst | 20 +++++++++ .../vision/instance_segmentation_mxnet.rst | 9 ++++ .../vision/object_detection_mxnet.rst | 9 ++++ .../vision/object_detection_mxnet_gluoncv.rst | 19 ++++++++ .../vision/object_detection_pytorch.rst | 9 ++++ .../vision/object_detection_tensorflow.rst | 9 ++++ .../vision/semantic_segmentation_mxnet.rst | 9 ++++ .../semantic_segmentation_mxnet_gluoncv.rst | 43 +++++++++++++++++++ doc/doc_utils/jumpstart_doc_utils.py | 1 + doc/index.rst | 2 +- doc/overview.rst | 2 + 49 files changed, 553 insertions(+), 14 deletions(-) create mode 100644 doc/algorithms/other/index.rst rename doc/algorithms/{ => other}/sagemaker.amazon.amazon_estimator.rst (100%) create mode 100644 doc/algorithms/tabular/autogluon.rst create mode 100644 doc/algorithms/tabular/catboost.rst rename doc/algorithms/{ => tabular}/factorization_machines.rst (97%) create mode 100644 doc/algorithms/tabular/index.rst rename doc/algorithms/{ => tabular}/knn.rst (100%) create mode 100644 doc/algorithms/tabular/lightgbm.rst rename doc/algorithms/{ => tabular}/linear_learner.rst (100%) rename doc/algorithms/{ => tabular}/object2vec.rst (100%) create mode 100644 doc/algorithms/tabular/tabtransformer.rst create mode 100644 doc/algorithms/tabular/xgboost.rst create mode 100644 doc/algorithms/text/blazing_text.rst create mode 100644 doc/algorithms/text/index.rst rename doc/algorithms/{ => text}/lda.rst (100%) create mode 100644 doc/algorithms/text/machine_translation_hugging_face.rst create mode 100644 doc/algorithms/text/named_entity_recognition_hugging_face.rst rename doc/algorithms/{ => text}/ntm.rst (100%) create mode 100644 doc/algorithms/text/question_answering_pytorch.rst create mode 100644 doc/algorithms/text/sentence_pair_classification_hugging_face.rst create mode 100644 doc/algorithms/text/sentence_pair_classification_tensorflow.rst create mode 100644 doc/algorithms/text/sequence_to_sequence.rst create mode 100644 doc/algorithms/text/text_classification_tensorflow.rst create mode 100644 doc/algorithms/text/text_embedding_tensorflow_mxnet.rst create mode 100644 doc/algorithms/text/text_generation_hugging_face.rst create mode 100644 doc/algorithms/text/text_summarization_hugging_face.rst create mode 100644 doc/algorithms/time_series/deep_ar.rst create mode 100644 doc/algorithms/time_series/index.rst create mode 100644 doc/algorithms/unsupervised/index.rst rename doc/algorithms/{ => unsupervised}/ipinsights.rst (100%) rename doc/algorithms/{ => unsupervised}/kmeans.rst (100%) rename doc/algorithms/{ => unsupervised}/pca.rst (100%) rename doc/algorithms/{ => unsupervised}/randomcutforest.rst (100%) create mode 100644 doc/algorithms/vision/image_classification_mxnet.rst create mode 100644 doc/algorithms/vision/image_classification_pytorch.rst create mode 100644 doc/algorithms/vision/image_classification_tensorflow.rst create mode 100644 doc/algorithms/vision/image_embedding_tensorflow.rst create mode 100644 
doc/algorithms/vision/index.rst create mode 100644 doc/algorithms/vision/instance_segmentation_mxnet.rst create mode 100644 doc/algorithms/vision/object_detection_mxnet.rst create mode 100644 doc/algorithms/vision/object_detection_mxnet_gluoncv.rst create mode 100644 doc/algorithms/vision/object_detection_pytorch.rst create mode 100644 doc/algorithms/vision/object_detection_tensorflow.rst create mode 100644 doc/algorithms/vision/semantic_segmentation_mxnet.rst create mode 100644 doc/algorithms/vision/semantic_segmentation_mxnet_gluoncv.rst diff --git a/doc/algorithms/index.rst b/doc/algorithms/index.rst index 45235a3bfe..bd78267d9b 100644 --- a/doc/algorithms/index.rst +++ b/doc/algorithms/index.rst @@ -1,5 +1,5 @@ ###################### -First-Party Algorithms +Built-in Algorithms ###################### Amazon SageMaker provides implementations of some common machine learning algorithms optimized for GPU architecture and massive datasets. @@ -7,14 +7,9 @@ Amazon SageMaker provides implementations of some common machine learning algori .. toctree:: :maxdepth: 2 - sagemaker.amazon.amazon_estimator - factorization_machines - ipinsights - kmeans - knn - lda - linear_learner - ntm - object2vec - pca - randomcutforest + tabular/index + text/index + time_series/index + unsupervised/index + vision/index + other/index diff --git a/doc/algorithms/other/index.rst b/doc/algorithms/other/index.rst new file mode 100644 index 0000000000..4bd9800221 --- /dev/null +++ b/doc/algorithms/other/index.rst @@ -0,0 +1,10 @@ +###################### +Other +###################### + +:ref:`All Pre-trained Models ` + +.. toctree:: + :maxdepth: 2 + + sagemaker.amazon.amazon_estimator diff --git a/doc/algorithms/sagemaker.amazon.amazon_estimator.rst b/doc/algorithms/other/sagemaker.amazon.amazon_estimator.rst similarity index 100% rename from doc/algorithms/sagemaker.amazon.amazon_estimator.rst rename to doc/algorithms/other/sagemaker.amazon.amazon_estimator.rst diff --git a/doc/algorithms/tabular/autogluon.rst b/doc/algorithms/tabular/autogluon.rst new file mode 100644 index 0000000000..8eae72187e --- /dev/null +++ b/doc/algorithms/tabular/autogluon.rst @@ -0,0 +1,28 @@ +############ +AutoGluon +############ + +`AutoGluon-Tabular `__ is a popular open-source AutoML framework that trains highly accurate machine learning models on an unprocessed tabular dataset. +Unlike existing AutoML frameworks that primarily focus on model and hyperparameter selection, AutoGluon-Tabular succeeds by ensembling multiple models and stacking them in multiple layers. + + +The following table outlines a variety of sample notebooks that address different use cases of Amazon SageMaker AutoGluon-Tabular algorithm. + +.. list-table:: + :widths: 25 25 + :header-rows: 1 + + * - Notebook Title + - Description + * - `Tabular classification with Amazon SageMaker AutoGluon-Tabular algorithm `__ + - This notebook demonstrates the use of the Amazon SageMaker AutoGluon-Tabular algorithm to train and host a tabular classification model. + * - `Tabular regression with Amazon SageMaker AutoGluon-Tabular algorithm `__ + - This notebook demonstrates the use of the Amazon SageMaker AutoGluon-Tabular algorithm to train and host a tabular regression model. + + +For instructions on how to create and access Jupyter notebook instances that you can use to run the example in SageMaker, see +`Use Amazon SageMaker Notebook Instances `__. 
After you have created a notebook +instance and opened it, choose the SageMaker Examples tab to see a list of all of the SageMaker samples. To open a notebook, choose its +Use tab and choose Create copy. + +For detailed documentation, please refer to the `Sagemaker AutoGluon-Tabular Algorithm `__. diff --git a/doc/algorithms/tabular/catboost.rst b/doc/algorithms/tabular/catboost.rst new file mode 100644 index 0000000000..e7c72aa5c4 --- /dev/null +++ b/doc/algorithms/tabular/catboost.rst @@ -0,0 +1,37 @@ +############ +CatBoost +############ + + +`CatBoost `__ is a popular and high-performance open-source implementation of the Gradient Boosting Decision Tree (GBDT) +algorithm. GBDT is a supervised learning algorithm that attempts to accurately predict a target variable by combining an ensemble of +estimates from a set of simpler and weaker models. + +CatBoost introduces two critical algorithmic advances to GBDT: + +* The implementation of ordered boosting, a permutation-driven alternative to the classic algorithm + +* An innovative algorithm for processing categorical features + +Both techniques were created to fight a prediction shift caused by a special kind of target leakage present in all currently existing +implementations of gradient boosting algorithms. + +The following table outlines a variety of sample notebooks that address different use cases of Amazon SageMaker CatBoost algorithm. + +.. list-table:: + :widths: 25 25 + :header-rows: 1 + + * - Notebook Title + - Description + * - `Tabular classification with Amazon SageMaker LightGBM and CatBoost algorithm `__ + - This notebook demonstrates the use of the Amazon SageMaker CatBoost algorithm to train and host a tabular classification model. + * - `Tabular regression with Amazon SageMaker LightGBM and CatBoost algorithm `__ + - This notebook demonstrates the use of the Amazon SageMaker CatBoost algorithm to train and host a tabular regression model. + +For instructions on how to create and access Jupyter notebook instances that you can use to run the example in SageMaker, see +`Use Amazon SageMaker Notebook Instances `__. After you have created a notebook +instance and opened it, choose the SageMaker Examples tab to see a list of all of the SageMaker samples. To open a notebook, choose its +Use tab and choose Create copy. + +For detailed documentation, please refer to the `Sagemaker CatBoost Algorithm `__. diff --git a/doc/algorithms/factorization_machines.rst b/doc/algorithms/tabular/factorization_machines.rst similarity index 97% rename from doc/algorithms/factorization_machines.rst rename to doc/algorithms/tabular/factorization_machines.rst index e6a509d167..77997702bf 100644 --- a/doc/algorithms/factorization_machines.rst +++ b/doc/algorithms/tabular/factorization_machines.rst @@ -1,4 +1,4 @@ -FactorizationMachines +Factorization Machines ------------------------- The Amazon SageMaker Factorization Machines algorithm. diff --git a/doc/algorithms/tabular/index.rst b/doc/algorithms/tabular/index.rst new file mode 100644 index 0000000000..029437fb39 --- /dev/null +++ b/doc/algorithms/tabular/index.rst @@ -0,0 +1,18 @@ +###################### +Tabular +###################### + +Amazon SageMaker provides built-in algorithms that are tailored to the analysis of tabular data. The built-in SageMaker algorithms for tabular data can be used for either classification or regression problems. + +.. 
toctree:: + :maxdepth: 2 + + autogluon + catboost + factorization_machines + knn + lightgbm + linear_learner + tabtransformer + xgboost + object2vec diff --git a/doc/algorithms/knn.rst b/doc/algorithms/tabular/knn.rst similarity index 100% rename from doc/algorithms/knn.rst rename to doc/algorithms/tabular/knn.rst diff --git a/doc/algorithms/tabular/lightgbm.rst b/doc/algorithms/tabular/lightgbm.rst new file mode 100644 index 0000000000..176b10cdba --- /dev/null +++ b/doc/algorithms/tabular/lightgbm.rst @@ -0,0 +1,28 @@ +############ +LightGBM +############ + +`LightGBM `__ is a popular and efficient open-source implementation of the Gradient Boosting +Decision Tree (GBDT) algorithm. GBDT is a supervised learning algorithm that attempts to accurately predict a target variable by +combining an ensemble of estimates from a set of simpler and weaker models. LightGBM uses additional techniques to significantly improve +the efficiency and scalability of conventional GBDT. + +The following table outlines a variety of sample notebooks that address different use cases of Amazon SageMaker LightGBM algorithm. + +.. list-table:: + :widths: 25 25 + :header-rows: 1 + + * - Notebook Title + - Description + * - `Tabular classification with Amazon SageMaker LightGBM and CatBoost algorithm `__ + - This notebook demonstrates the use of the Amazon SageMaker LightGBM algorithm to train and host a tabular classification model. + * - `Tabular regression with Amazon SageMaker LightGBM and CatBoost algorithm `__ + - This notebook demonstrates the use of the Amazon SageMaker LightGBM algorithm to train and host a tabular regression model. + +For instructions on how to create and access Jupyter notebook instances that you can use to run the example in SageMaker, see +`Use Amazon SageMaker Notebook Instances `__. After you have created a notebook +instance and opened it, choose the SageMaker Examples tab to see a list of all of the SageMaker samples. To open a notebook, choose its +Use tab and choose Create copy. + +For detailed documentation, please refer to the `Sagemaker LightGBM Algorithm `__. diff --git a/doc/algorithms/linear_learner.rst b/doc/algorithms/tabular/linear_learner.rst similarity index 100% rename from doc/algorithms/linear_learner.rst rename to doc/algorithms/tabular/linear_learner.rst diff --git a/doc/algorithms/object2vec.rst b/doc/algorithms/tabular/object2vec.rst similarity index 100% rename from doc/algorithms/object2vec.rst rename to doc/algorithms/tabular/object2vec.rst diff --git a/doc/algorithms/tabular/tabtransformer.rst b/doc/algorithms/tabular/tabtransformer.rst new file mode 100644 index 0000000000..facebfcd83 --- /dev/null +++ b/doc/algorithms/tabular/tabtransformer.rst @@ -0,0 +1,28 @@ +############### +TabTransformer +############### + +`TabTransformer `__ is a novel deep tabular data modeling architecture for supervised learning. The TabTransformer architecture is built on self-attention-based Transformers. +The Transformer layers transform the embeddings of categorical features into robust contextual embeddings to achieve higher prediction accuracy. Furthermore, the contextual embeddings learned from TabTransformer +are highly robust against both missing and noisy data features, and provide better interpretability. + + +The following table outlines a variety of sample notebooks that address different use cases of Amazon SageMaker TabTransformer algorithm. + +.. 
list-table:: + :widths: 25 25 + :header-rows: 1 + + * - Notebook Title + - Description + * - `Tabular classification with Amazon SageMaker TabTransformer algorithm `__ + - This notebook demonstrates the use of the Amazon SageMaker TabTransformer algorithm to train and host a tabular classification model. + * - `Tabular regression with Amazon SageMaker TabTransformer algorithm `__ + - This notebook demonstrates the use of the Amazon SageMaker TabTransformer algorithm to train and host a tabular regression model. + +For instructions on how to create and access Jupyter notebook instances that you can use to run the example in SageMaker, see +`Use Amazon SageMaker Notebook Instances `__. After you have created a notebook +instance and opened it, choose the SageMaker Examples tab to see a list of all of the SageMaker samples. To open a notebook, choose its +Use tab and choose Create copy. + +For detailed documentation, please refer to the `Sagemaker TabTransformer Algorithm `__. diff --git a/doc/algorithms/tabular/xgboost.rst b/doc/algorithms/tabular/xgboost.rst new file mode 100644 index 0000000000..829af00ac5 --- /dev/null +++ b/doc/algorithms/tabular/xgboost.rst @@ -0,0 +1,40 @@ +############ +XGBoost +############ + +The `XGBoost `__ (eXtreme Gradient Boosting) is a popular and efficient open-source implementation of the gradient boosted trees algorithm. Gradient boosting is a supervised learning algorithm that attempts to accurately predict a target variable +by combining an ensemble of estimates from a set of simpler and weaker models. The XGBoost algorithm performs well in machine learning competitions because of its robust handling of a variety of data types, relationships, distributions, and the variety of hyperparameters that you can +fine-tune. You can use XGBoost for regression, classification (binary and multiclass), and ranking problems. + +You can use the new release of the XGBoost algorithm either as a Amazon SageMaker built-in algorithm or as a framework to run training scripts in your local environments. This implementation has a smaller memory footprint, better logging, improved hyperparameter validation, and +an expanded set of metrics than the original versions. It provides an XGBoost estimator that executes a training script in a managed XGBoost environment. The current release of SageMaker XGBoost is based on the original XGBoost versions 1.0, 1.2, 1.3, and 1.5. + +The following table outlines a variety of sample notebooks that address different use cases of Amazon SageMaker XGBoost algorithm. + +.. list-table:: + :widths: 25 25 + :header-rows: 1 + + * - Notebook Title + - Description + * - `How to Create a Custom XGBoost container? `__ + - This notebook shows you how to build a custom XGBoost Container with Amazon SageMaker Batch Transform. + * - `Regression with XGBoost using Parquet `__ + - This notebook shows you how to use the Abalone dataset in Parquet to train a XGBoost model. + * - `How to Train and Host a Multiclass Classification Model? `__ + - This notebook shows how to use the MNIST dataset to train and host a multiclass classification model. + * - `How to train a Model for Customer Churn Prediction? `__ + - This notebook shows you how to train a model to Predict Mobile Customer Departure in an effort to identify unhappy customers. + * - `An Introduction to Amazon SageMaker Managed Spot infrastructure for XGBoost Training `__ + - This notebook shows you how to use Spot Instances for training with a XGBoost Container. 
+ * - `How to use Amazon SageMaker Debugger to debug XGBoost Training Jobs? `__ + - This notebook shows you how to use Amazon SageMaker Debugger to monitor training jobs to detect inconsistencies. + * - `How to use Amazon SageMaker Debugger to debug XGBoost Training Jobs in Real-Time? `__ + - This notebook shows you how to use the MNIST dataset and Amazon SageMaker Debugger to perform real-time analysis of XGBoost training jobs while training jobs are running. + +For instructions on how to create and access Jupyter notebook instances that you can use to run the example in SageMaker, see +`Use Amazon SageMaker Notebook Instances `__. After you have created a notebook +instance and opened it, choose the SageMaker Examples tab to see a list of all of the SageMaker samples. To open a notebook, choose its +Use tab and choose Create copy. + +For detailed documentation, please refer to the `Sagemaker XGBoost Algorithm `__. diff --git a/doc/algorithms/text/blazing_text.rst b/doc/algorithms/text/blazing_text.rst new file mode 100644 index 0000000000..e42f4a0cc2 --- /dev/null +++ b/doc/algorithms/text/blazing_text.rst @@ -0,0 +1,27 @@ +############# +Blazing Text +############# + + +The Amazon SageMaker BlazingText algorithm provides highly optimized implementations of the Word2vec and text classification algorithms. The Word2vec algorithm is useful for many downstream natural language processing (NLP) +tasks, such as sentiment analysis, named entity recognition, machine translation, etc. Text classification is an important task for applications that perform web searches, information retrieval, ranking, and document classification. + +The Word2vec algorithm maps words to high-quality distributed vectors. The resulting vector representation of a word is called a word embedding. Words that are semantically similar correspond to vectors that are close together. +That way, word embeddings capture the semantic relationships between words. + +Many natural language processing (NLP) applications learn word embeddings by training on large collections of documents. These pretrained vector representations provide information about semantics and word distributions that +typically improves the generalizability of other models that are later trained on a more limited amount of data. Most implementations of the Word2vec algorithm are not optimized for multi-core CPU architectures. This makes it +difficult to scale to large datasets. + +With the BlazingText algorithm, you can scale to large datasets easily. Similar to Word2vec, it provides the Skip-gram and continuous bag-of-words (CBOW) training architectures. BlazingText's implementation of the supervised +multi-class, multi-label text classification algorithm extends the fastText text classifier to use GPU acceleration with custom `CUDA `__ + +kernels. You can train a model on more than a billion words in a couple of minutes using a multi-core CPU or a GPU. And, you achieve performance on par with the state-of-the-art deep learning text classification algorithms. + +The BlazingText algorithm is not parallelizable. For more information on parameters related to training, see `Docker Registry Paths for SageMaker Built-in Algorithms `__. + +For a sample notebook that uses the SageMaker BlazingText algorithm to train and deploy supervised binary and multiclass classification models, see +`Blazing Text classification on the DBPedia dataset `__. 
+For instructions for creating and accessing Jupyter notebook instances that you can use to run the example in SageMaker, see `Use Amazon SageMaker Notebook Instances `__. +After creating and opening a notebook instance, choose the SageMaker Examples tab to see a list of all the SageMaker examples. The topic modeling example notebooks that use the Blazing Text are located in the Introduction to Amazon +algorithms section. To open a notebook, choose its Use tab, then choose Create copy. diff --git a/doc/algorithms/text/index.rst b/doc/algorithms/text/index.rst new file mode 100644 index 0000000000..a24288fdc7 --- /dev/null +++ b/doc/algorithms/text/index.rst @@ -0,0 +1,22 @@ +###################### +Text +###################### + +Amazon SageMaker provides algorithms that are tailored to the analysis of textual documents used in natural language processing, document classification or summarization, topic modeling or classification, and language transcription or translation. + +.. toctree:: + :maxdepth: 2 + + blazing_text + lda + ntm + sequence_to_sequence + text_classification_tensorflow + sentence_pair_classification_tensorflow + sentence_pair_classification_hugging_face + question_answering_pytorch + named_entity_recognition_hugging_face + text_summarization_hugging_face + text_generation_hugging_face + machine_translation_hugging_face + text_embedding_tensorflow_mxnet diff --git a/doc/algorithms/lda.rst b/doc/algorithms/text/lda.rst similarity index 100% rename from doc/algorithms/lda.rst rename to doc/algorithms/text/lda.rst diff --git a/doc/algorithms/text/machine_translation_hugging_face.rst b/doc/algorithms/text/machine_translation_hugging_face.rst new file mode 100644 index 0000000000..d533d0e64d --- /dev/null +++ b/doc/algorithms/text/machine_translation_hugging_face.rst @@ -0,0 +1,10 @@ +##################################### +Machine Translation - HuggingFace +##################################### + + +This is a supervised machine translation algorithm which supports many pre-trained models available in Hugging Face. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Machine Translation for using these algorithms. + +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK `. diff --git a/doc/algorithms/text/named_entity_recognition_hugging_face.rst b/doc/algorithms/text/named_entity_recognition_hugging_face.rst new file mode 100644 index 0000000000..fc0fbd212c --- /dev/null +++ b/doc/algorithms/text/named_entity_recognition_hugging_face.rst @@ -0,0 +1,10 @@ +######################################## +Named Entity Recognition - HuggingFace +######################################## + +This is a supervised named entity recognition algorithm which supports fine-tuning of many pre-trained models available in Hugging Face. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Named Entity Recognition for using these algorithms. 
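The fine-tuning workflow follows the same pattern as the other JumpStart-based algorithms: retrieve the training container, training script bundle, and pre-trained model artifact for a chosen model ID, then launch a generic ``Estimator``. The sketch below is illustrative only; the model ID, IAM role, and S3 paths are placeholders, and the entry point and default hyperparameters depend on the specific model you choose.

.. code:: python

    from sagemaker import hyperparameters, image_uris, model_uris, script_uris
    from sagemaker.estimator import Estimator

    # Illustrative JumpStart model ID; substitute any named entity recognition model ID
    model_id, model_version = "huggingface-ner-distilbert-base-cased-finetuned-conll03-english", "*"
    instance_type = "ml.p3.2xlarge"

    # Retrieve the training container, training script bundle and pre-trained model artifact
    train_image_uri = image_uris.retrieve(
        region=None, framework=None, model_id=model_id, model_version=model_version,
        image_scope="training", instance_type=instance_type,
    )
    train_source_uri = script_uris.retrieve(
        model_id=model_id, model_version=model_version, script_scope="training"
    )
    train_model_uri = model_uris.retrieve(
        model_id=model_id, model_version=model_version, model_scope="training"
    )

    # Start from the model's default hyperparameters; override individual values as needed
    hps = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version)

    estimator = Estimator(
        image_uri=train_image_uri,
        source_dir=train_source_uri,
        model_uri=train_model_uri,
        entry_point="transfer_learning.py",                    # entry point shipped in the script bundle
        role="arn:aws:iam::111122223333:role/SageMakerRole",   # hypothetical IAM role
        instance_count=1,
        instance_type=instance_type,
        hyperparameters=hps,
        output_path="s3://my-bucket/ner-output",               # hypothetical bucket
    )
    estimator.fit({"training": "s3://my-bucket/ner-training-data"})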
+ +For detailed documentation please refer `Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK `__ + diff --git a/doc/algorithms/ntm.rst b/doc/algorithms/text/ntm.rst similarity index 100% rename from doc/algorithms/ntm.rst rename to doc/algorithms/text/ntm.rst diff --git a/doc/algorithms/text/question_answering_pytorch.rst b/doc/algorithms/text/question_answering_pytorch.rst new file mode 100644 index 0000000000..9d9d74ccb1 --- /dev/null +++ b/doc/algorithms/text/question_answering_pytorch.rst @@ -0,0 +1,9 @@ +##################################### +Question Answering - PyTorch +##################################### + +This is a supervised question answering algorithm which supports fine-tuning of many pre-trained models available in Hugging Face. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Question Answering for using these algorithms. + +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` diff --git a/doc/algorithms/text/sentence_pair_classification_hugging_face.rst b/doc/algorithms/text/sentence_pair_classification_hugging_face.rst new file mode 100644 index 0000000000..2892b9d516 --- /dev/null +++ b/doc/algorithms/text/sentence_pair_classification_hugging_face.rst @@ -0,0 +1,9 @@ +############################################ +Sentence Pair Classification - HuggingFace +############################################ + +This is a supervised sentence pair classification algorithm which supports fine-tuning of many pre-trained models available in Hugging Face. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Sentence Pair Classification for using these algorithms. + +For detailed documentation please refer `Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK `__ diff --git a/doc/algorithms/text/sentence_pair_classification_tensorflow.rst b/doc/algorithms/text/sentence_pair_classification_tensorflow.rst new file mode 100644 index 0000000000..80264e84f3 --- /dev/null +++ b/doc/algorithms/text/sentence_pair_classification_tensorflow.rst @@ -0,0 +1,9 @@ +############################################ +Sentence Pair Classification - TensorFlow +############################################ + +This is a supervised sentence pair classification algorithm which supports fine-tuning of many pre-trained models available in Tensorflow Hub. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Sentence Pair Classification for using these algorithms. + +For detailed documentation please refer `Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK `__ diff --git a/doc/algorithms/text/sequence_to_sequence.rst b/doc/algorithms/text/sequence_to_sequence.rst new file mode 100644 index 0000000000..00d9302a01 --- /dev/null +++ b/doc/algorithms/text/sequence_to_sequence.rst @@ -0,0 +1,14 @@ +####################### +Sequence-to-Sequence +####################### + +Amazon SageMaker Sequence to Sequence is a supervised learning algorithm where the input is a sequence of tokens (for example, text, audio) and the output generated is another sequence of tokens. 
Example applications include: machine +translation (input a sentence from one language and predict what that sentence would be in another language), text summarization (input a longer string of words and predict a shorter string of words that is a summary), speech-to-text +(audio clips converted into output sentences in tokens). Recently, problems in this domain have been successfully modeled with deep neural networks that show a significant performance boost over previous methodologies. Amazon SageMaker +seq2seq uses Recurrent Neural Networks (RNNs) and Convolutional Neural Network (CNN) models with attention as encoder-decoder architectures. + +For a sample notebook that shows how to use the SageMaker Sequence to Sequence algorithm to train a English-German translation model, see +`Machine Translation English-German Example Using SageMaker Seq2Seq `__. +For instructions how to create and access Jupyter notebook instances that you can use to run the example in SageMaker, see `Use Amazon SageMaker Notebook Instances `__. Once you have +created a notebook instance and opened it, select the SageMaker Examples tab to see a list of all the SageMaker samples. The topic modeling example notebooks using the NTM algorithms are located in the Introduction to Amazon algorithms section. +To open a notebook, click on its Use tab and select Create copy. diff --git a/doc/algorithms/text/text_classification_tensorflow.rst b/doc/algorithms/text/text_classification_tensorflow.rst new file mode 100644 index 0000000000..c60a5b3e1c --- /dev/null +++ b/doc/algorithms/text/text_classification_tensorflow.rst @@ -0,0 +1,9 @@ +################################## +Text Classification - TensorFlow +################################## + +This is a supervised text classification algorithm which supports fine-tuning of many pre-trained models available in Tensorflow Hub. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Text Classification for using these algorithms. + +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` diff --git a/doc/algorithms/text/text_embedding_tensorflow_mxnet.rst b/doc/algorithms/text/text_embedding_tensorflow_mxnet.rst new file mode 100644 index 0000000000..d015c2ef30 --- /dev/null +++ b/doc/algorithms/text/text_embedding_tensorflow_mxnet.rst @@ -0,0 +1,9 @@ +#################################### +Text Embedding - TensorFlow, MxNet +#################################### + +This is a supervised text embedding algorithm which supports many pre-trained models available in MXNet and Tensorflow Hub. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Text Embedding for using these algorithms. + +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` diff --git a/doc/algorithms/text/text_generation_hugging_face.rst b/doc/algorithms/text/text_generation_hugging_face.rst new file mode 100644 index 0000000000..30fae26196 --- /dev/null +++ b/doc/algorithms/text/text_generation_hugging_face.rst @@ -0,0 +1,9 @@ +############################################ +Text Generation - HuggingFace +############################################ + +This is a supervised text generation algorithm which supports many pre-trained models available in Hugging Face. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Text Generation for using these algorithms. 
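A rough sketch of hosting one of these pre-trained models for inference is shown below. It retrieves the inference container, inference script bundle, and model artifact for an illustrative model ID and deploys them with the generic ``Model`` class; the model ID, IAM role, instance type, and request/response format are assumptions that vary by model.

.. code:: python

    from sagemaker import image_uris, model_uris, script_uris
    from sagemaker.model import Model
    from sagemaker.predictor import Predictor

    model_id, model_version = "huggingface-textgeneration-gpt2", "*"  # illustrative model ID
    instance_type = "ml.g4dn.xlarge"

    # Retrieve the inference container, inference script bundle and model artifact
    deploy_image_uri = image_uris.retrieve(
        region=None, framework=None, model_id=model_id, model_version=model_version,
        image_scope="inference", instance_type=instance_type,
    )
    deploy_source_uri = script_uris.retrieve(
        model_id=model_id, model_version=model_version, script_scope="inference"
    )
    model_data = model_uris.retrieve(
        model_id=model_id, model_version=model_version, model_scope="inference"
    )

    model = Model(
        image_uri=deploy_image_uri,
        source_dir=deploy_source_uri,
        model_data=model_data,
        entry_point="inference.py",                            # entry point shipped in the script bundle
        role="arn:aws:iam::111122223333:role/SageMakerRole",   # hypothetical IAM role
        predictor_cls=Predictor,
    )
    predictor = model.deploy(initial_instance_count=1, instance_type=instance_type)

    # Request format assumed here; check the model's documentation for the exact content type
    response = predictor.predict(
        "The Amazon rainforest is",
        {"ContentType": "application/x-text", "Accept": "application/json"},
    )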
+ +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` diff --git a/doc/algorithms/text/text_summarization_hugging_face.rst b/doc/algorithms/text/text_summarization_hugging_face.rst new file mode 100644 index 0000000000..206c880ba3 --- /dev/null +++ b/doc/algorithms/text/text_summarization_hugging_face.rst @@ -0,0 +1,9 @@ +############################################ +Text Summarization - HuggingFace +############################################ + +This is a supervised text summarization algorithm which supports many pre-trained models available in Hugging Face. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Text Summarization for using these algorithms. + +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` diff --git a/doc/algorithms/time_series/deep_ar.rst b/doc/algorithms/time_series/deep_ar.rst new file mode 100644 index 0000000000..c373cb7405 --- /dev/null +++ b/doc/algorithms/time_series/deep_ar.rst @@ -0,0 +1,11 @@ +################################## +Deep AR Forecasting +################################## + +The Amazon SageMaker DeepAR forecasting algorithm is a supervised learning algorithm for forecasting scalar (one-dimensional) time series using recurrent neural networks (RNN). Classical forecasting methods, such as autoregressive integrated moving average (ARIMA) or exponential smoothing (ETS), fit a single model to each individual time series. They then use that model to extrapolate the time series into the future. + +In many applications, however, you have many similar time series across a set of cross-sectional units. For example, you might have time series groupings for demand for different products, server loads, and requests for webpages. For this type of application, you can benefit from training a single model jointly over all of the time series. DeepAR takes this approach. When your dataset contains hundreds of related time series, DeepAR outperforms the standard ARIMA and ETS methods. You can also use the trained model to generate forecasts for new time series that are similar to the ones it has been trained on. + +The training input for the DeepAR algorithm is one or, preferably, more target time series that have been generated by the same process or similar processes. Based on this input dataset, the algorithm trains a model that learns an approximation of this process/processes and uses it to predict how the target time series evolves. Each target time series can be optionally associated with a vector of static (time-independent) categorical features provided by the cat field and a vector of dynamic (time-dependent) time series provided by the dynamic_feat field. SageMaker trains the DeepAR model by randomly sampling training examples from each target time series in the training dataset. Each training example consists of a pair of adjacent context and prediction windows with fixed predefined lengths. To control how far in the past the network can see, use the context_length hyperparameter. To control how far in the future predictions can be made, use the prediction_length hyperparameter. For more information, see `How the DeepAR Algorithm Works `__. 
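As a minimal sketch of how these controls surface in the SageMaker Python SDK, the snippet below trains DeepAR with the generic ``Estimator`` and sets ``context_length`` and ``prediction_length`` as hyperparameters. The IAM role, S3 locations, and hyperparameter values are placeholders chosen for illustration; the train and optional test channels expect JSON Lines time series files.

.. code:: python

    import sagemaker
    from sagemaker import image_uris
    from sagemaker.estimator import Estimator

    session = sagemaker.Session()
    image_uri = image_uris.retrieve("forecasting-deepar", session.boto_region_name)

    estimator = Estimator(
        image_uri=image_uri,
        role="arn:aws:iam::111122223333:role/SageMakerRole",   # hypothetical IAM role
        instance_count=1,
        instance_type="ml.c5.2xlarge",
        output_path="s3://my-bucket/deepar-output",            # hypothetical bucket
        sagemaker_session=session,
    )

    estimator.set_hyperparameters(
        time_freq="H",            # hourly observations
        context_length="72",      # how far in the past the network can see
        prediction_length="24",   # how far in the future to forecast
        epochs="20",
    )

    # Each channel points to a prefix containing JSON Lines time series files
    estimator.fit({
        "train": "s3://my-bucket/deepar/train/",
        "test": "s3://my-bucket/deepar/test/",
    })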
+
+For a sample notebook that shows how to prepare a time series dataset for training the SageMaker DeepAR algorithm and how to deploy the trained model for performing inferences, see `Time series forecasting with DeepAR - Synthetic data `__ as well as `DeepAR demo on electricity dataset `__, which illustrates the advanced features of DeepAR on a real world dataset. For instructions on creating and accessing Jupyter notebook instances that you can use to run the example in SageMaker, see `Use Amazon SageMaker Notebook Instances `__. After creating and opening a notebook instance, choose the SageMaker Examples tab to see a list of all of the SageMaker examples. To open a notebook, choose its Use tab, and choose Create copy.
diff --git a/doc/algorithms/time_series/index.rst b/doc/algorithms/time_series/index.rst
new file mode 100644
index 0000000000..05d2464ccd
--- /dev/null
+++ b/doc/algorithms/time_series/index.rst
@@ -0,0 +1,10 @@
+######################
+Time-series
+######################
+
+Amazon SageMaker provides built-in algorithms that are tailored to the analysis and forecasting of time-series data.
+
+.. toctree::
+    :maxdepth: 2
+
+    deep_ar
diff --git a/doc/algorithms/unsupervised/index.rst b/doc/algorithms/unsupervised/index.rst
new file mode 100644
index 0000000000..a3e6af9801
--- /dev/null
+++ b/doc/algorithms/unsupervised/index.rst
@@ -0,0 +1,13 @@
+######################
+Unsupervised
+######################
+
+Amazon SageMaker provides several built-in algorithms that can be used for a variety of unsupervised learning tasks such as clustering, dimension reduction, pattern recognition, and anomaly detection.
+
+.. toctree::
+    :maxdepth: 2
+
+    ipinsights
+    kmeans
+    pca
+    randomcutforest
diff --git a/doc/algorithms/ipinsights.rst b/doc/algorithms/unsupervised/ipinsights.rst
similarity index 100%
rename from doc/algorithms/ipinsights.rst
rename to doc/algorithms/unsupervised/ipinsights.rst
diff --git a/doc/algorithms/kmeans.rst b/doc/algorithms/unsupervised/kmeans.rst
similarity index 100%
rename from doc/algorithms/kmeans.rst
rename to doc/algorithms/unsupervised/kmeans.rst
diff --git a/doc/algorithms/pca.rst b/doc/algorithms/unsupervised/pca.rst
similarity index 100%
rename from doc/algorithms/pca.rst
rename to doc/algorithms/unsupervised/pca.rst
diff --git a/doc/algorithms/randomcutforest.rst b/doc/algorithms/unsupervised/randomcutforest.rst
similarity index 100%
rename from doc/algorithms/randomcutforest.rst
rename to doc/algorithms/unsupervised/randomcutforest.rst
diff --git a/doc/algorithms/vision/image_classification_mxnet.rst b/doc/algorithms/vision/image_classification_mxnet.rst
new file mode 100644
index 0000000000..1550a6026c
--- /dev/null
+++ b/doc/algorithms/vision/image_classification_mxnet.rst
@@ -0,0 +1,18 @@
+#############################
+Image Classification - MxNet
+#############################
+
+The Amazon SageMaker image classification algorithm is a supervised learning algorithm that supports multi-label classification. It takes an image as input and outputs one or more labels assigned to that image.
+It uses a convolutional neural network that can be trained from scratch or trained using transfer learning when a large number of training images are not available.
+
+The recommended input format for the Amazon SageMaker image classification algorithms is Apache MXNet `RecordIO `__.
+However, you can also use raw images in .jpg or .png format. Refer to `this discussion `__ for a broad overview of efficient
+data preparation and loading for machine learning systems.
+
+For a sample notebook that uses the SageMaker image classification algorithm to train a model on the caltech-256 dataset and then to deploy it to perform inferences, see the
+`End-to-End Multiclass Image Classification Example `__.
+For instructions how to create and access Jupyter notebook instances that you can use to run the example in SageMaker, see `Use Amazon SageMaker Notebook Instances `__.
+Once you have created a notebook instance and opened it, select the SageMaker Examples tab to see a list of all the SageMaker samples. The example image classification notebooks are located in the Introduction to Amazon
+algorithms section. To open a notebook, click on its Use tab and select Create copy.
+
+For detailed documentation, please refer to the `Sagemaker Image Classification Algorithm `__
diff --git a/doc/algorithms/vision/image_classification_pytorch.rst b/doc/algorithms/vision/image_classification_pytorch.rst
new file mode 100644
index 0000000000..3c154c6cfe
--- /dev/null
+++ b/doc/algorithms/vision/image_classification_pytorch.rst
@@ -0,0 +1,9 @@
+###############################
+Image Classification - PyTorch
+###############################
+
+This is a supervised image classification algorithm which supports fine-tuning of many pre-trained models available in Pytorch Hub. The following
+`sample notebook `__
+demonstrates how to use the Sagemaker Python SDK for Image Classification for using these algorithms.
+
+For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK `
diff --git a/doc/algorithms/vision/image_classification_tensorflow.rst b/doc/algorithms/vision/image_classification_tensorflow.rst
new file mode 100644
index 0000000000..e49820ee50
--- /dev/null
+++ b/doc/algorithms/vision/image_classification_tensorflow.rst
@@ -0,0 +1,9 @@
+##################################
+Image Classification - TensorFlow
+##################################
+
+This is a supervised image classification algorithm which supports fine-tuning of many pre-trained models available in Tensorflow Hub. The following
+`sample notebook `__
+demonstrates how to use the Sagemaker Python SDK for Image Classification for using these algorithms.
+
+For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK `
diff --git a/doc/algorithms/vision/image_embedding_tensorflow.rst b/doc/algorithms/vision/image_embedding_tensorflow.rst
new file mode 100644
index 0000000000..0938377354
--- /dev/null
+++ b/doc/algorithms/vision/image_embedding_tensorflow.rst
@@ -0,0 +1,9 @@
+#############################
+Image Embedding - TensorFlow
+#############################
+
+This is a supervised image embedding algorithm which supports many pre-trained models available in Tensorflow Hub. The following
+`sample notebook `__
+demonstrates how to use the Sagemaker Python SDK for Image Embedding for using these algorithms.
+ +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` diff --git a/doc/algorithms/vision/index.rst b/doc/algorithms/vision/index.rst new file mode 100644 index 0000000000..50af5003b1 --- /dev/null +++ b/doc/algorithms/vision/index.rst @@ -0,0 +1,20 @@ +###################### +Vision +###################### + +Amazon SageMaker provides image processing algorithms that are used for image classification, object detection, and computer vision. + +.. toctree:: + :maxdepth: 2 + + image_classification_mxnet + image_classification_pytorch + image_classification_tensorflow + object_detection_mxnet_gluoncv + object_detection_mxnet + object_detection_pytorch + object_detection_tensorflow + semantic_segmentation_mxnet_gluoncv + semantic_segmentation_mxnet + instance_segmentation_mxnet + image_embedding_tensorflow diff --git a/doc/algorithms/vision/instance_segmentation_mxnet.rst b/doc/algorithms/vision/instance_segmentation_mxnet.rst new file mode 100644 index 0000000000..a38611bc9a --- /dev/null +++ b/doc/algorithms/vision/instance_segmentation_mxnet.rst @@ -0,0 +1,9 @@ +############################## +Instance Segmentation - MXNet +############################## + +This is a supervised image segmentation algorithm which supports many pre-trained models available in MXNet. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Image Segmentation for using these algorithms. + +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` diff --git a/doc/algorithms/vision/object_detection_mxnet.rst b/doc/algorithms/vision/object_detection_mxnet.rst new file mode 100644 index 0000000000..9ce52f992b --- /dev/null +++ b/doc/algorithms/vision/object_detection_mxnet.rst @@ -0,0 +1,9 @@ +########################## +Object Detection - MxNet +########################## + +This is a supervised object detection algorithm which supports fine-tuning of many pre-trained models available in MXNet. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Object Detection for using these algorithms. + +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` diff --git a/doc/algorithms/vision/object_detection_mxnet_gluoncv.rst b/doc/algorithms/vision/object_detection_mxnet_gluoncv.rst new file mode 100644 index 0000000000..857360b68e --- /dev/null +++ b/doc/algorithms/vision/object_detection_mxnet_gluoncv.rst @@ -0,0 +1,19 @@ +################################## +Object Detection - MxNet GluonCV +################################## + + +The Amazon SageMaker Object Detection algorithm detects and classifies objects in images using a single deep neural network. +It is a supervised learning algorithm that takes images as input and identifies all instances of objects within the image scene. +The object is categorized into one of the classes in a specified collection with a confidence score that it belongs to the class. +Its location and scale in the image are indicated by a rectangular bounding box. It uses the `Single Shot multibox Detector (SSD) `__ +framework and supports two base networks: `VGG `__ and `ResNet `__. The network can be trained from scratch, +or trained with models that have been pre-trained on the `ImageNet `__ dataset. 
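A minimal training sketch with the SageMaker Python SDK looks roughly like the following. The choice of base network and whether to start from ImageNet-pretrained weights are ordinary hyperparameters; the IAM role, S3 locations, and values such as ``num_classes`` and ``num_training_samples`` are placeholders that must match your own dataset.

.. code:: python

    import sagemaker
    from sagemaker import image_uris
    from sagemaker.estimator import Estimator
    from sagemaker.inputs import TrainingInput

    session = sagemaker.Session()
    training_image = image_uris.retrieve("object-detection", session.boto_region_name)

    od_estimator = Estimator(
        image_uri=training_image,
        role="arn:aws:iam::111122223333:role/SageMakerRole",    # hypothetical IAM role
        instance_count=1,
        instance_type="ml.p3.2xlarge",
        output_path="s3://my-bucket/object-detection/output",   # hypothetical bucket
        sagemaker_session=session,
    )

    od_estimator.set_hyperparameters(
        base_network="resnet-50",      # or "vgg-16"
        use_pretrained_model=1,        # start from ImageNet-pretrained weights
        num_classes=20,
        mini_batch_size=16,
        epochs=30,
        num_training_samples=16551,    # size of the training set
    )

    train_data = TrainingInput(
        "s3://my-bucket/object-detection/train",
        content_type="application/x-recordio",
    )
    validation_data = TrainingInput(
        "s3://my-bucket/object-detection/validation",
        content_type="application/x-recordio",
    )
    od_estimator.fit({"train": train_data, "validation": validation_data})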
+ +For a sample notebook that shows how to use the SageMaker Object Detection algorithm to train and host a model on the `Caltech Birds (CUB 200 2011) `__ +dataset using the Single Shot multibox Detector algorithm, see `Amazon SageMaker Object Detection for Bird Species `__. +For instructions how to create and access Jupyter notebook instances that you can use to run the example in SageMaker, see `Use Amazon SageMaker Notebook Instances `__. +Once you have created a notebook instance and opened it, select the SageMaker Examples tab to see a list of all the SageMaker samples. The object detection example notebook using the Object Detection +algorithm is located in the Introduction to Amazon Algorithms section. To open a notebook, click on its Use tab and select Create copy. + +For detailed documentation, please refer to the `Sagemaker Object Detection Algorithm `__ diff --git a/doc/algorithms/vision/object_detection_pytorch.rst b/doc/algorithms/vision/object_detection_pytorch.rst new file mode 100644 index 0000000000..aa703e74b5 --- /dev/null +++ b/doc/algorithms/vision/object_detection_pytorch.rst @@ -0,0 +1,9 @@ +########################### +Object Detection - PyTorch +########################### + +This is a supervised object detection algorithm which supports fine-tuning of many pre-trained models available in Pytorch Hub. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Object Detection for using these algorithms. + +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` diff --git a/doc/algorithms/vision/object_detection_tensorflow.rst b/doc/algorithms/vision/object_detection_tensorflow.rst new file mode 100644 index 0000000000..2536322847 --- /dev/null +++ b/doc/algorithms/vision/object_detection_tensorflow.rst @@ -0,0 +1,9 @@ +############################### +Object Detection - TensorFlow +############################### + +This is a supervised object detection algorithm which supports fine-tuning of many pre-trained models available in Tensorflow Hub. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Object Detection for using these algorithms. + +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` diff --git a/doc/algorithms/vision/semantic_segmentation_mxnet.rst b/doc/algorithms/vision/semantic_segmentation_mxnet.rst new file mode 100644 index 0000000000..b0c60cd560 --- /dev/null +++ b/doc/algorithms/vision/semantic_segmentation_mxnet.rst @@ -0,0 +1,9 @@ +############################## +Semantic Segmentation - MxNet +############################## + +This is a supervised semantic segmentation algorithm which supports fine-tuning of many pre-trained models available in MXNet. The following +`sample notebook `__ +demonstrates how to use the Sagemaker Python SDK for Semantic Segmentation for using these algorithms. 
+ +For detailed documentation please refer :ref:`Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK ` diff --git a/doc/algorithms/vision/semantic_segmentation_mxnet_gluoncv.rst b/doc/algorithms/vision/semantic_segmentation_mxnet_gluoncv.rst new file mode 100644 index 0000000000..53e532f6ea --- /dev/null +++ b/doc/algorithms/vision/semantic_segmentation_mxnet_gluoncv.rst @@ -0,0 +1,43 @@ +##################################### +Semantic Segmentation - MxNet GluonCV +##################################### + +The SageMaker semantic segmentation algorithm provides a fine-grained, pixel-level approach to developing computer vision applications. +It tags every pixel in an image with a class label from a predefined set of classes. Tagging is fundamental for understanding scenes, which is +critical to an increasing number of computer vision applications, such as self-driving vehicles, medical imaging diagnostics, and robot sensing. + +For comparison, the `SageMaker Image Classification Algorithm `__ is a +supervised learning algorithm that analyzes only whole images, classifying them into one of multiple output categories. The +`Object Detection Algorithm `__ is a supervised learning algorithm that detects and +classifies all instances of an object in an image. It indicates the location and scale of each object in the image with a rectangular bounding box. + +Because the semantic segmentation algorithm classifies every pixel in an image, it also provides information about the shapes of the objects contained in the image. +The segmentation output is represented as a grayscale image, called a segmentation mask. A segmentation mask is a grayscale image with the same shape as the input image. + +The SageMaker semantic segmentation algorithm is built using the `MXNet Gluon framework and the Gluon CV toolkit `__ +. It provides you with a choice of three built-in algorithms to train a deep neural network. You can use the `Fully-Convolutional Network (FCN) algorithm `__ , +`Pyramid Scene Parsing (PSP) algorithm `__, or `DeepLabV3 `__. + + +Each of the three algorithms has two distinct components: + +* The backbone (or encoder)—A network that produces reliable activation maps of features. + +* The decoder—A network that constructs the segmentation mask from the encoded activation maps. + +You also have a choice of backbones for the FCN, PSP, and DeepLabV3 algorithms: `ResNet50 or ResNet101 `__. +These backbones include pretrained artifacts that were originally trained on the `ImageNet `__ classification task. You can fine-tune these backbones +for segmentation using your own data. Or, you can initialize and train these networks from scratch using only your own data. The decoders are never pretrained. + +To deploy the trained model for inference, use the SageMaker hosting service. During inference, you can request the segmentation mask either as a +PNG image or as a set of probabilities for each class for each pixel. You can use these masks as part of a larger pipeline that includes additional downstream image processing or other applications. + + +For a sample Jupyter notebook that uses the SageMaker semantic segmentation algorithm to train a model and deploy it to perform inferences, see the +`Semantic Segmentation Example `__. For instructions +on how to create and access Jupyter notebook instances that you can use to run the example in SageMaker, see `Use Amazon SageMaker Notebook Instances `__. 
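Beyond the example notebook, the algorithm and backbone choices described above are exposed as hyperparameters when training with the SageMaker Python SDK. The following sketch is illustrative only; the IAM role, S3 locations, and values such as ``num_classes`` and ``num_training_samples`` are placeholders for your own dataset.

.. code:: python

    import sagemaker
    from sagemaker import image_uris
    from sagemaker.estimator import Estimator

    session = sagemaker.Session()
    training_image = image_uris.retrieve("semantic-segmentation", session.boto_region_name)

    ss_estimator = Estimator(
        image_uri=training_image,
        role="arn:aws:iam::111122223333:role/SageMakerRole",        # hypothetical IAM role
        instance_count=1,
        instance_type="ml.p3.2xlarge",
        output_path="s3://my-bucket/semantic-segmentation/output",  # hypothetical bucket
        sagemaker_session=session,
    )

    ss_estimator.set_hyperparameters(
        algorithm="fcn",             # "fcn", "psp" or "deeplab"
        backbone="resnet-50",        # or "resnet-101"
        use_pretrained_model=True,   # fine-tune an ImageNet-pretrained backbone
        num_classes=21,
        epochs=10,
        num_training_samples=1464,   # size of the training set
    )

    ss_estimator.fit({
        "train": "s3://my-bucket/semantic-segmentation/train",
        "train_annotation": "s3://my-bucket/semantic-segmentation/train_annotation",
        "validation": "s3://my-bucket/semantic-segmentation/validation",
        "validation_annotation": "s3://my-bucket/semantic-segmentation/validation_annotation",
    })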
+ +To see a list of all of the SageMaker samples, create and open a notebook instance, and choose the SageMaker Examples tab. The example semantic segmentation notebooks are located under +Introduction to Amazon algorithms. To open a notebook, choose its Use tab, and choose Create copy. + +For detailed documentation, please refer to the `Sagemaker Semantic Segmentation Algorithm `__ diff --git a/doc/doc_utils/jumpstart_doc_utils.py b/doc/doc_utils/jumpstart_doc_utils.py index 94096fbf1d..92a418a6b4 100644 --- a/doc/doc_utils/jumpstart_doc_utils.py +++ b/doc/doc_utils/jumpstart_doc_utils.py @@ -140,6 +140,7 @@ def create_jumpstart_model_table(): file_content = [] + file_content.append(".. _all-pretrained-models:\n\n") file_content.append(".. |external-link| raw:: html\n\n") file_content.append(' \n\n') diff --git a/doc/index.rst b/doc/index.rst index c0269452f9..2d4ebe32c1 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -39,7 +39,7 @@ The SageMaker Python SDK supports managed training and inference for a variety o ******************************** -SageMaker First-Party Algorithms +SageMaker Built-in Algorithms ******************************** Amazon SageMaker provides implementations of some common machine learning algorithms optimized for GPU architecture and massive datasets. diff --git a/doc/overview.rst b/doc/overview.rst index f6ac8a8a2d..f6fab65bae 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -573,6 +573,8 @@ Here is an example: # When you are done using your endpoint model.sagemaker_session.delete_endpoint('my-endpoint') +.. _built-in-algos: + *********************************************************************** Use Built-in Algorithms with Pre-trained Models in SageMaker Python SDK *********************************************************************** From e46d26f0f30ae0158562000c4c3816a323782cc3 Mon Sep 17 00:00:00 2001 From: Basil Beirouti Date: Wed, 17 Aug 2022 16:10:04 -0700 Subject: [PATCH 178/526] fix: using unique name for lineage test to unblock PR checks (#3313) Co-authored-by: Basil Beirouti --- tests/integ/sagemaker/lineage/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 5e201eef42..3c416ffd36 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -141,7 +141,7 @@ def action_objs(sagemaker_session): @pytest.fixture def artifact_obj(sagemaker_session): obj = artifact.Artifact.create( - artifact_name="SDKIntegrationTest", + artifact_name=name(), artifact_type="SDKIntegrationTest", source_uri=name(), properties={"k1": "v1"}, From 5e93760a1d74cc31b6557ea5960e672e34ac3fb2 Mon Sep 17 00:00:00 2001 From: Basil Beirouti Date: Wed, 17 Aug 2022 17:28:00 -0700 Subject: [PATCH 179/526] feature: adding workgroup functionality to athena query (#3276) Co-authored-by: Basil Beirouti --- src/sagemaker/feature_store/feature_group.py | 6 +++++- src/sagemaker/session.py | 6 ++++++ tests/unit/sagemaker/feature_store/test_feature_store.py | 6 +++++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index 6e6caa6988..0d7c72783c 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -81,7 +81,9 @@ class AthenaQuery: _result_bucket: str = attr.ib(init=False, default=None) _result_file_prefix: str = attr.ib(init=False, default=None) - def run(self, 
query_string: str, output_location: str, kms_key: str = None) -> str: + def run( + self, query_string: str, output_location: str, kms_key: str = None, workgroup: str = None + ) -> str: """Execute a SQL query given a query string, output location and kms key. This method executes the SQL query using Athena and outputs the results to output_location @@ -91,6 +93,7 @@ def run(self, query_string: str, output_location: str, kms_key: str = None) -> s query_string: SQL query string. output_location: S3 URI of the query result. kms_key: KMS key id. If set, will be used to encrypt the query result file. + workgroup (str): The name of the workgroup in which the query is being started. Returns: Execution id of the query. @@ -101,6 +104,7 @@ def run(self, query_string: str, output_location: str, kms_key: str = None) -> s query_string=query_string, output_location=output_location, kms_key=kms_key, + workgroup=workgroup, ) self._current_query_execution_id = response["QueryExecutionId"] parse_result = urlparse(output_location, allow_fragments=False) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 221434d7db..bd7ae440c5 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4187,6 +4187,7 @@ def start_query_execution( query_string: str, output_location: str, kms_key: str = None, + workgroup: str = None, ) -> Dict[str, str]: """Start Athena query execution. @@ -4196,6 +4197,8 @@ def start_query_execution( query_string (str): SQL expression. output_location (str): S3 location of the output file. kms_key (str): KMS key id will be used to encrypt the result if given. + workgroup (str): The name of the workgroup in which the query is being started. + If the workgroup is not specified, the default workgroup is used. Returns: Response dict from the service. 
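# Usage sketch for the new ``workgroup`` parameter (illustrative, not part of this patch):
# the feature group name, result bucket, and Athena workgroup name are placeholders.
import sagemaker
from sagemaker.feature_store.feature_group import FeatureGroup

sagemaker_session = sagemaker.Session()
feature_group = FeatureGroup(name="transactions", sagemaker_session=sagemaker_session)

query = feature_group.athena_query()
query.run(
    query_string=f'SELECT * FROM "{query.table_name}"',
    output_location="s3://example-bucket/athena-results/",
    workgroup="example-workgroup",  # forwarded to Athena StartQueryExecution as WorkGroup
)
query.wait()
df = query.as_dataframe()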
@@ -4210,6 +4213,9 @@ def start_query_execution( ) kwargs.update(ResultConfiguration=result_config) + if workgroup: + kwargs.update(WorkGroup=workgroup) + athena_client = self.boto_session.client("athena", region_name=self.boto_region_name) return athena_client.start_query_execution(**kwargs) diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py index 8f2f0eb3f9..21c22512a2 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_store.py +++ b/tests/unit/sagemaker/feature_store/test_feature_store.py @@ -468,14 +468,18 @@ def query(sagemaker_session_mock): def test_athena_query_run(sagemaker_session_mock, query): + WORKGROUP = "workgroup" sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"} - query.run(query_string="query", output_location="s3://some-bucket/some-path") + query.run( + query_string="query", output_location="s3://some-bucket/some-path", workgroup=WORKGROUP + ) sagemaker_session_mock.start_query_execution.assert_called_with( catalog="catalog", database="database", query_string="query", output_location="s3://some-bucket/some-path", kms_key=None, + workgroup=WORKGROUP, ) assert "some-bucket" == query._result_bucket assert "some-path" == query._result_file_prefix From 7a3274cfbcaaafd58ba1f24236910d0feb8fe63f Mon Sep 17 00:00:00 2001 From: Allen Liu Date: Wed, 17 Aug 2022 23:06:43 -0700 Subject: [PATCH 180/526] change: disable debugger/profiler in cgk region (#3312) * change: disable debugger/profiler in cgk region * fix: indent and format fix Co-authored-by: Basil Beirouti --- src/sagemaker/estimator.py | 42 +++++-- src/sagemaker/fw_utils.py | 22 +++- .../sagemaker/tensorflow/test_estimator.py | 1 + tests/unit/sagemaker/workflow/test_steps.py | 4 + tests/unit/test_estimator.py | 110 +++++++++++++++++- tests/unit/test_fw_utils.py | 7 ++ 6 files changed, 171 insertions(+), 15 deletions(-) diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 35c726cd0a..4800c9ed11 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -823,6 +823,29 @@ def _prepare_debugger_for_training(self): self.debugger_hook_config.s3_output_path = self.output_path self.debugger_rule_configs = self._prepare_debugger_rules() self._prepare_collection_configs() + self._validate_and_set_debugger_configs() + if not self.debugger_hook_config: + if self.environment is None: + self.environment = {} + self.environment[DEBUGGER_FLAG] = "0" + + def _validate_and_set_debugger_configs(self): + """Set defaults for debugging.""" + region_supports_debugger = _region_supports_debugger( + self.sagemaker_session.boto_region_name + ) + + if region_supports_debugger: + if self.debugger_hook_config in [None, {}]: + self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path) + else: + if self.debugger_hook_config is not False and self.debugger_hook_config: + # when user set debugger config in a unsupported region + raise ValueError( + "Current region does not support debugger but debugger hook config is set!" + ) + # disable debugger in unsupported regions + self.debugger_hook_config = False def _prepare_debugger_rules(self): """Set any necessary values in debugger rules, if they are provided.""" @@ -1766,6 +1789,8 @@ def enable_default_profiling(self): Debugger monitoring is disabled. 
""" self._ensure_latest_training_job() + if not _region_supports_debugger(self.sagemaker_session.boto_region_name): + raise ValueError("Current region does not support profiler / debugger!") training_job_details = self.latest_training_job.describe() @@ -1799,6 +1824,8 @@ def disable_profiling(self): """ self._ensure_latest_training_job() + if not _region_supports_debugger(self.sagemaker_session.boto_region_name): + raise ValueError("Current region does not support profiler / debugger!") training_job_details = self.latest_training_job.describe() @@ -1852,6 +1879,8 @@ def update_profiler( """ self._ensure_latest_training_job() + if not _region_supports_debugger(self.sagemaker_session.boto_region_name): + raise ValueError("Current region does not support profiler / debugger!") if ( not rules @@ -2872,13 +2901,7 @@ def _script_mode_hyperparam_update(self, code_dir: str, script: str) -> None: def _validate_and_set_debugger_configs(self): """Set defaults for debugging.""" - if self.debugger_hook_config is None and _region_supports_debugger( - self.sagemaker_session.boto_region_name - ): - self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path) - elif not self.debugger_hook_config: - # set hook config to False if _region_supports_debugger is False - self.debugger_hook_config = False + super(Framework, self)._validate_and_set_debugger_configs() # Disable debugger if checkpointing is enabled by the customer if self.checkpoint_s3_uri and self.checkpoint_local_path and self.debugger_hook_config: @@ -2901,11 +2924,6 @@ def _validate_and_set_debugger_configs(self): ) self.debugger_hook_config = False - if self.debugger_hook_config is False: - if self.environment is None: - self.environment = {} - self.environment[DEBUGGER_FLAG] = "0" - def _model_source_dir(self): """Get the appropriate value to pass as ``source_dir`` to a model constructor. diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 5b7b5da656..4f87d32d5f 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -53,8 +53,26 @@ "only one worker per host regardless of the number of GPUs." 
) -DEBUGGER_UNSUPPORTED_REGIONS = ("us-iso-east-1",) -PROFILER_UNSUPPORTED_REGIONS = ("us-iso-east-1",) +DEBUGGER_UNSUPPORTED_REGIONS = ( + "us-iso-east-1", + "ap-southeast-3", + "ap-southeast-4", + "eu-south-2", + "me-central-1", + "ap-south-2", + "eu-central-2", + "us-gov-east-1", +) +PROFILER_UNSUPPORTED_REGIONS = ( + "us-iso-east-1", + "ap-southeast-3", + "ap-southeast-4", + "eu-south-2", + "me-central-1", + "ap-south-2", + "eu-central-2", + "us-gov-east-1", +) SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge") SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = ( diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 2e7576421f..37e548de9e 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -483,6 +483,7 @@ def test_fit_ps(time, strftime, sagemaker_session): expected_train_args = _create_train_job("1.11", ps=True, py_version="py2") expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs expected_train_args["hyperparameters"][TensorFlow.LAUNCH_PS_ENV_NAME] = json.dumps(True) + expected_train_args["environment"] = {"USE_SMDEBUG": "0"} actual_train_args = sagemaker_session.method_calls[0][2] assert actual_train_args == expected_train_args diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index cd8b7522d1..42157272b0 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -370,6 +370,10 @@ def test_training_step_base_estimator(sagemaker_session): }, "RoleArn": ROLE, "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, + "DebugHookConfig": { + "S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}}, + "CollectionConfigurations": [], + }, "ProfilerConfig": { "ProfilingIntervalInMilliseconds": 500, "S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}}, diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 5906dcb0a1..0f8954e984 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -726,6 +726,110 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region): assert args.get("profiler_rule_configs") is None +@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS) +def test_framework_with_debugger_config_set_up_in_unsupported_region(region): + with pytest.raises(ValueError) as error: + boto_mock = Mock(name="boto_session", region_name=region) + sms = MagicMock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=region, + config=None, + local_mode=False, + s3_client=None, + s3_resource=None, + ) + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sms, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + debugger_hook_config=DebuggerHookConfig(s3_output_path="s3://output"), + ) + f.fit("s3://mydata") + + assert "Current region does not support debugger but debugger hook config is set!" 
in str(error) + + +@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS) +def test_framework_enable_profiling_in_unsupported_region(region): + with pytest.raises(ValueError) as error: + boto_mock = Mock(name="boto_session", region_name=region) + sms = MagicMock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=region, + config=None, + local_mode=False, + s3_client=None, + s3_resource=None, + ) + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sms, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + ) + f.fit("s3://mydata") + f.enable_default_profiling() + + assert "Current region does not support profiler / debugger!" in str(error) + + +@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS) +def test_framework_update_profiling_in_unsupported_region(region): + with pytest.raises(ValueError) as error: + boto_mock = Mock(name="boto_session", region_name=region) + sms = MagicMock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=region, + config=None, + local_mode=False, + s3_client=None, + s3_resource=None, + ) + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sms, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + ) + f.fit("s3://mydata") + f.update_profiler(system_monitor_interval_millis=1000) + + assert "Current region does not support profiler / debugger!" in str(error) + + +@pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS) +def test_framework_disable_profiling_in_unsupported_region(region): + with pytest.raises(ValueError) as error: + boto_mock = Mock(name="boto_session", region_name=region) + sms = MagicMock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=region, + config=None, + local_mode=False, + s3_client=None, + s3_resource=None, + ) + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sms, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + ) + f.fit("s3://mydata") + f.disable_profiling() + + assert "Current region does not support profiler / debugger!" 
in str(error) + + def test_framework_with_profiler_config_and_profiler_disabled(sagemaker_session): with pytest.raises(RuntimeError) as error: f = DummyFramework( @@ -2683,6 +2787,7 @@ def test_generic_to_fit_no_input(time, sagemaker_session): args.pop("job_name") args.pop("role") + args.pop("debugger_hook_config") assert args == NO_INPUT_TRAIN_CALL @@ -2707,6 +2812,7 @@ def test_generic_to_fit_no_hps(time, sagemaker_session): args.pop("job_name") args.pop("role") + args.pop("debugger_hook_config") assert args == BASE_TRAIN_CALL @@ -2733,6 +2839,7 @@ def test_generic_to_fit_with_hps(time, sagemaker_session): args.pop("job_name") args.pop("role") + args.pop("debugger_hook_config") assert args == HP_TRAIN_CALL @@ -2764,6 +2871,7 @@ def test_generic_to_fit_with_experiment_config(time, sagemaker_session): args.pop("job_name") args.pop("role") + args.pop("debugger_hook_config") assert args == EXP_TRAIN_CALL @@ -2917,6 +3025,7 @@ def test_generic_to_deploy(time, sagemaker_session): args.pop("job_name") args.pop("role") + args.pop("debugger_hook_config") assert args == HP_TRAIN_CALL @@ -3727,7 +3836,6 @@ def test_script_mode_estimator_same_calls_as_framework( source_dir=script_uri, image_uri=IMAGE_URI, model_uri=model_uri, - environment={"USE_SMDEBUG": "0"}, dependencies=[], debugger_hook_config={}, ) diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index fbbc27be37..e378b7a0a2 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -552,6 +552,13 @@ def test_region_supports_debugger_feature_returns_true_for_supported_regions(): def test_region_supports_debugger_feature_returns_false_for_unsupported_regions(): assert fw_utils._region_supports_debugger("us-iso-east-1") is False + assert fw_utils._region_supports_debugger("ap-southeast-3") is False + assert fw_utils._region_supports_debugger("ap-southeast-4") is False + assert fw_utils._region_supports_debugger("eu-south-2") is False + assert fw_utils._region_supports_debugger("me-central-1") is False + assert fw_utils._region_supports_debugger("ap-south-2") is False + assert fw_utils._region_supports_debugger("eu-central-2") is False + assert fw_utils._region_supports_debugger("us-gov-east-1") is False def test_warn_if_parameter_server_with_multi_gpu(caplog): From 37c58efdfc8f168ce77e4bdded52d7e94eb0de36 Mon Sep 17 00:00:00 2001 From: Yeldos Balgabekov Date: Thu, 18 Aug 2022 08:43:13 +0200 Subject: [PATCH 181/526] feat: Added endpoint_name to clarify.ModelConfig (#3296) * added run_bias_and_explainability method * feat: added endpoint_name to clarify.ModelConfig Co-authored-by: Yeldos Balgabekov --- src/sagemaker/clarify.py | 450 +++++++++++++++++++++++++++--------- tests/integ/test_clarify.py | 64 ++++- tests/unit/test_clarify.py | 168 +++++++++++++- 3 files changed, 572 insertions(+), 110 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 7f00a78268..0bdfa7db98 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -25,7 +25,7 @@ import tempfile from abc import ABC, abstractmethod -from typing import List, Union +from typing import List, Union, Dict from sagemaker import image_uris, s3, utils from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor @@ -172,7 +172,11 @@ def __init__( _set(joinsource, "joinsource_name_or_index", self.analysis_config) _set(facet_dataset_uri, "facet_dataset_uri", self.analysis_config) _set(facet_headers, "facet_headers", self.analysis_config) - _set(predicted_label_dataset_uri, 
"predicted_label_dataset_uri", self.analysis_config) + _set( + predicted_label_dataset_uri, + "predicted_label_dataset_uri", + self.analysis_config, + ) _set(predicted_label_headers, "predicted_label_headers", self.analysis_config) _set(predicted_label, "predicted_label", self.analysis_config) _set(excluded_columns, "excluded_columns", self.analysis_config) @@ -271,26 +275,33 @@ class ModelConfig: def __init__( self, - model_name, - instance_count, - instance_type, - accept_type=None, - content_type=None, - content_template=None, - custom_attributes=None, - accelerator_type=None, - endpoint_name_prefix=None, - target_model=None, + model_name: str = None, + instance_count: int = None, + instance_type: str = None, + accept_type: str = None, + content_type: str = None, + content_template: str = None, + custom_attributes: str = None, + accelerator_type: str = None, + endpoint_name_prefix: str = None, + target_model: str = None, + endpoint_name: str = None, ): r"""Initializes a configuration of a model and the endpoint to be created for it. Args: model_name (str): Model name (as created by `CreateModel `_. + Cannot be set when ``endpoint_name`` is set. + Must be set with ``instance_count``, ``instance_type`` instance_count (int): The number of instances of a new endpoint for model inference. + Cannot be set when ``endpoint_name`` is set. + Must be set with ``model_name``, ``instance_type`` instance_type (str): The type of `EC2 instance `_ to use for model inference; for example, ``"ml.c5.xlarge"``. + Cannot be set when ``endpoint_name`` is set. + Must be set with ``instance_count``, ``model_name`` accept_type (str): The model output format to be used for getting inferences with the shadow endpoint. Valid values are ``"text/csv"`` for CSV and ``"application/jsonlines"``. Default is the same as ``content_type``. @@ -320,17 +331,41 @@ def __init__( target_model (str): Sets the target model name when using a multi-model endpoint. For more information about multi-model endpoints, see https://docs.aws.amazon.com/sagemaker/latest/dg/multi-model-endpoints.html + endpoint_name (str): Sets the endpoint_name when re-uses an existing endpoint. 
+ Cannot be set when ``model_name``, ``instance_count``, + and ``instance_type`` set Raises: - ValueError: when the ``endpoint_name_prefix`` is invalid, ``accept_type`` is invalid, - ``content_type`` is invalid, or ``content_template`` has no placeholder "features" + ValueError: when the + - ``endpoint_name_prefix`` is invalid, + - ``accept_type`` is invalid, + - ``content_type`` is invalid, + - ``content_template`` has no placeholder "features" + - both [``endpoint_name``] + AND [``model_name``, ``instance_count``, ``instance_type``] are set + - both [``endpoint_name``] AND [``endpoint_name_prefix``] are set """ - self.predictor_config = { - "model_name": model_name, - "instance_type": instance_type, - "initial_instance_count": instance_count, - } - if endpoint_name_prefix is not None: + + # validation + _model_endpoint_config_rule = ( + all([model_name, instance_count, instance_type]), + all([endpoint_name]), + ) + assert any(_model_endpoint_config_rule) and not all(_model_endpoint_config_rule) + if endpoint_name: + assert not endpoint_name_prefix + + # main init logic + self.predictor_config = ( + { + "model_name": model_name, + "instance_type": instance_type, + "initial_instance_count": instance_count, + } + if not endpoint_name + else {"endpoint_name": endpoint_name} + ) + if endpoint_name_prefix: if re.search("^[a-zA-Z0-9](-*[a-zA-Z0-9])", endpoint_name_prefix) is None: raise ValueError( "Invalid endpoint_name_prefix." @@ -491,7 +526,10 @@ def __init__(self, features=None, grid_resolution=15, top_k_features=10): top_k_features (int): Sets the number of top SHAP attributes used to compute partial dependence plots. """ # noqa E501 - self.pdp_config = {"grid_resolution": grid_resolution, "top_k_features": top_k_features} + self.pdp_config = { + "grid_resolution": grid_resolution, + "top_k_features": top_k_features, + } if features is not None: self.pdp_config["features"] = features @@ -824,7 +862,11 @@ def __init__( image_config (:class:`~sagemaker.clarify.ImageConfig`): Config for handling image features. Default is None. """ # noqa E501 # pylint: disable=c0301 - if agg_method is not None and agg_method not in ["mean_abs", "median", "mean_sq"]: + if agg_method is not None and agg_method not in [ + "mean_abs", + "median", + "mean_sq", + ]: raise ValueError( f"Invalid agg_method {agg_method}." f" Please choose mean_abs, median, or mean_sq." ) @@ -1167,7 +1209,11 @@ def run_post_training_bias( * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. """ # noqa E501 # pylint: disable=c0301 analysis_config = _AnalysisConfigGenerator.bias_post_training( - data_config, data_bias_config, model_predicted_label_config, methods, model_config + data_config, + data_bias_config, + model_predicted_label_config, + methods, + model_config, ) # when name is either not provided (is None) or an empty string ("") job_name = job_name or utils.name_from_base( @@ -1368,68 +1414,198 @@ def run_explainability( experiment_config, ) + def run_bias_and_explainability( + self, + data_config: DataConfig, + model_config: ModelConfig, + explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], + bias_config: BiasConfig, + pre_training_methods: Union[str, List[str]] = "all", + post_training_methods: Union[str, List[str]] = "all", + model_predicted_label_config: ModelPredictedLabelConfig = None, + wait=True, + logs=True, + job_name=None, + kms_key=None, + experiment_config=None, + ): + """Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions. 
+ + For bias: + Computes metrics for both the pre-training and the post-training methods. + To calculate post-training methods, it spins up a model endpoint and runs inference over the + input examples in 's3_data_input_path' (from the :class:`~sagemaker.clarify.DataConfig`) + to obtain predicted labels. + + For Explainability: + Spins up a model endpoint. + + Currently, only SHAP and Partial Dependence Plots (PDP) are supported + as explainability methods. + You can request both methods or one at a time with the ``explainability_config`` parameter. + + When SHAP is requested in the ``explainability_config``, + the SHAP algorithm calculates the feature importance for each input example + in the ``s3_data_input_path`` of the :class:`~sagemaker.clarify.DataConfig`, + by creating ``num_samples`` copies of the example with a subset of features + replaced with values from the ``baseline``. + It then runs model inference to see how the model's prediction changes with the replaced + features. If the model output returns multiple scores importance is computed for each score. + Across examples, feature importance is aggregated using ``agg_method``. + + When PDP is requested in the ``explainability_config``, + the PDP algorithm calculates the dependence of the target response + on the input features and marginalizes over the values of all other input features. + The Partial Dependence Plots are included in the output + `report `__ + and the corresponding values are included in the analysis output. + + Args: + data_config (:class:`~sagemaker.clarify.DataConfig`): Config of the input/output data. + model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its + endpoint to be created. + explainability_config (:class:`~sagemaker.clarify.ExplainabilityConfig` or list): + Config of the specific explainability method or a list of + :class:`~sagemaker.clarify.ExplainabilityConfig` objects. + Currently, SHAP and PDP are the two methods supported. + You can request multiple methods at once by passing in a list of + `~sagemaker.clarify.ExplainabilityConfig`. + bias_config (:class:`~sagemaker.clarify.BiasConfig`): Config of sensitive groups. + pre_training_methods (str or list[str]): Selector of a subset of potential metrics: + ["`CI `_", + "`DPL `_", + "`KL `_", + "`JS `_", + "`LP `_", + "`TVD `_", + "`KS `_", + "`CDDL `_"]. + Defaults to str "all" to run all metrics if left unspecified. + post_training_methods (str or list[str]): Selector of a subset of potential metrics: + ["`DPPL `_" + , "`DI `_", + "`DCA `_", + "`DCR `_", + "`RD `_", + "`DAR `_", + "`DRR `_", + "`AD `_", + "`CDDPL `_ + ", "`TE `_", + "`FT `_"]. + Defaults to str "all" to run all metrics if left unspecified. + model_predicted_label_config ( + int or + str or + :class:`~sagemaker.clarify.ModelPredictedLabelConfig` + ): + Index or JSONPath to locate the predicted scores in the model output. This is not + required if the model output is a single score. Alternatively, it can be an instance + of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` + to provide more parameters like ``label_headers``. + wait (bool): Whether the call should wait until the job completes (default: True). + logs (bool): Whether to show the logs produced by the job. + Only meaningful when ``wait`` is True (default: True). + job_name (str): Processing job name. 
When ``job_name`` is not specified, + if ``job_name_prefix`` in :class:`~sagemaker.clarify.SageMakerClarifyProcessor` + is specified, the job name will be composed of ``job_name_prefix`` and current + timestamp; otherwise use ``"Clarify-Explainability"`` as prefix. + kms_key (str): The ARN of the KMS key that is used to encrypt the + user code file (default: None). + experiment_config (dict[str, str]): Experiment management configuration. + Optionally, the dict can contain three keys: + ``'ExperimentName'``, ``'TrialName'``, and ``'TrialComponentDisplayName'``. + + The behavior of setting these keys is as follows: + + * If ``'ExperimentName'`` is supplied but ``'TrialName'`` is not, a Trial will be + automatically created and the job's Trial Component associated with the Trial. + * If ``'TrialName'`` is supplied and the Trial already exists, + the job's Trial Component will be associated with the Trial. + * If both ``'ExperimentName'`` and ``'TrialName'`` are not supplied, + the Trial Component will be unassociated. + * ``'TrialComponentDisplayName'`` is used for display in Amazon SageMaker Studio. + """ # noqa E501 # pylint: disable=c0301 + analysis_config = _AnalysisConfigGenerator.bias_and_explainability( + data_config, + model_config, + model_predicted_label_config, + explainability_config, + bias_config, + pre_training_methods, + post_training_methods, + ) + # when name is either not provided (is None) or an empty string ("") + job_name = job_name or utils.name_from_base( + self.job_name_prefix or "Clarify-Bias-And-Explainability" + ) + return self._run( + data_config, + analysis_config, + wait, + logs, + job_name, + kms_key, + experiment_config, + ) + class _AnalysisConfigGenerator: """Creates analysis_config objects for different type of runs.""" + @classmethod + def bias_and_explainability( + cls, + data_config: DataConfig, + model_config: ModelConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], + bias_config: BiasConfig, + pre_training_methods: Union[str, List[str]] = "all", + post_training_methods: Union[str, List[str]] = "all", + ): + """Generates a config for Bias and Explainability""" + analysis_config = {**data_config.get_config(), **bias_config.get_config()} + analysis_config = cls._add_methods( + analysis_config, + pre_training_methods=pre_training_methods, + post_training_methods=post_training_methods, + explainability_config=explainability_config, + ) + analysis_config = cls._add_predictor( + analysis_config, model_config, model_predicted_label_config + ) + return analysis_config + @classmethod def explainability( cls, data_config: DataConfig, model_config: ModelConfig, - model_scores: ModelPredictedLabelConfig, - explainability_config: ExplainabilityConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], ): """Generates a config for Explainability""" - analysis_config = data_config.get_config() - predictor_config = model_config.get_predictor_config() - if isinstance(model_scores, ModelPredictedLabelConfig): - ( - probability_threshold, - predicted_label_config, - ) = model_scores.get_predictor_config() - _set(probability_threshold, "probability_threshold", analysis_config) - predictor_config.update(predicted_label_config) - else: - _set(model_scores, "label", predictor_config) - - explainability_methods = {} - if isinstance(explainability_config, list): - if len(explainability_config) 
== 0: - raise ValueError("Please provide at least one explainability config.") - for config in explainability_config: - explain_config = config.get_explainability_config() - explainability_methods.update(explain_config) - if not len(explainability_methods.keys()) == len(explainability_config): - raise ValueError("Duplicate explainability configs are provided") - if ( - "shap" not in explainability_methods - and explainability_methods["pdp"].get("features", None) is None - ): - raise ValueError("PDP features must be provided when ShapConfig is not provided") - else: - if ( - isinstance(explainability_config, PDPConfig) - and explainability_config.get_explainability_config()["pdp"].get("features", None) - is None - ): - raise ValueError("PDP features must be provided when ShapConfig is not provided") - explainability_methods = explainability_config.get_explainability_config() - analysis_config["methods"] = explainability_methods - analysis_config["predictor"] = predictor_config - return cls._common(analysis_config) + analysis_config = data_config.analysis_config + analysis_config = cls._add_predictor( + analysis_config, model_config, model_predicted_label_config + ) + analysis_config = cls._add_methods( + analysis_config, explainability_config=explainability_config + ) + return analysis_config @classmethod def bias_pre_training( - cls, data_config: DataConfig, bias_config: BiasConfig, methods: Union[str, List[str]] + cls, + data_config: DataConfig, + bias_config: BiasConfig, + methods: Union[str, List[str]], ): """Generates a config for Bias Pre Training""" - analysis_config = { - **data_config.get_config(), - **bias_config.get_config(), - "methods": {"pre_training_bias": {"methods": methods}}, - } - return cls._common(analysis_config) + analysis_config = {**data_config.get_config(), **bias_config.get_config()} + analysis_config = cls._add_methods(analysis_config, pre_training_methods=methods) + return analysis_config @classmethod def bias_post_training( @@ -1441,21 +1617,12 @@ def bias_post_training( model_config: ModelConfig, ): """Generates a config for Bias Post Training""" - analysis_config = { - **data_config.get_config(), - **bias_config.get_config(), - "predictor": {**model_config.get_predictor_config()}, - "methods": {"post_training_bias": {"methods": methods}}, - } - if model_predicted_label_config: - ( - probability_threshold, - predictor_config, - ) = model_predicted_label_config.get_predictor_config() - if predictor_config: - analysis_config["predictor"].update(predictor_config) - _set(probability_threshold, "probability_threshold", analysis_config) - return cls._common(analysis_config) + analysis_config = {**data_config.get_config(), **bias_config.get_config()} + analysis_config = cls._add_methods(analysis_config, post_training_methods=methods) + analysis_config = cls._add_predictor( + analysis_config, model_config, model_predicted_label_config + ) + return analysis_config @classmethod def bias( @@ -1468,16 +1635,28 @@ def bias( post_training_methods: Union[str, List[str]] = "all", ): """Generates a config for Bias""" - analysis_config = { - **data_config.get_config(), - **bias_config.get_config(), - "predictor": model_config.get_predictor_config(), - "methods": { - "pre_training_bias": {"methods": pre_training_methods}, - "post_training_bias": {"methods": post_training_methods}, - }, - } - if model_predicted_label_config: + analysis_config = {**data_config.get_config(), **bias_config.get_config()} + analysis_config = cls._add_methods( + analysis_config, + 
pre_training_methods=pre_training_methods, + post_training_methods=post_training_methods, + ) + analysis_config = cls._add_predictor( + analysis_config, model_config, model_predicted_label_config + ) + return analysis_config + + @classmethod + def _add_predictor( + cls, + analysis_config: Dict, + model_config: ModelConfig, + model_predicted_label_config: ModelPredictedLabelConfig, + ): + """Extends analysis config with predictor.""" + analysis_config = {**analysis_config} + analysis_config["predictor"] = model_config.get_predictor_config() + if isinstance(model_predicted_label_config, ModelPredictedLabelConfig): ( probability_threshold, predictor_config, @@ -1485,17 +1664,82 @@ def bias( if predictor_config: analysis_config["predictor"].update(predictor_config) _set(probability_threshold, "probability_threshold", analysis_config) - return cls._common(analysis_config) - - @staticmethod - def _common(analysis_config): - """Extends analysis config with common values""" - analysis_config["methods"]["report"] = { - "name": "report", - "title": "Analysis Report", - } + else: + _set(model_predicted_label_config, "label", analysis_config["predictor"]) + return analysis_config + + @classmethod + def _add_methods( + cls, + analysis_config: Dict, + pre_training_methods: Union[str, List[str]] = None, + post_training_methods: Union[str, List[str]] = None, + explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]] = None, + report=True, + ): + """Extends analysis config with methods.""" + # validate + params = [pre_training_methods, post_training_methods, explainability_config] + if not any(params): + raise AttributeError( + "analysis_config must have at least one working method: " + "One of the " + "`pre_training_methods`, `post_training_methods`, `explainability_config`." 
+ ) + + # main logic + analysis_config = {**analysis_config} + if "methods" not in analysis_config: + analysis_config["methods"] = {} + + if report: + analysis_config["methods"]["report"] = { + "name": "report", + "title": "Analysis Report", + } + + if pre_training_methods: + analysis_config["methods"]["pre_training_bias"] = {"methods": pre_training_methods} + + if post_training_methods: + analysis_config["methods"]["post_training_bias"] = {"methods": post_training_methods} + + if explainability_config is not None: + explainability_methods = cls._merge_explainability_configs(explainability_config) + analysis_config["methods"] = { + **analysis_config["methods"], + **explainability_methods, + } return analysis_config + @classmethod + def _merge_explainability_configs( + cls, + explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]], + ): + """Merges explainability configs, when more than one.""" + if isinstance(explainability_config, list): + explainability_methods = {} + if len(explainability_config) == 0: + raise ValueError("Please provide at least one explainability config.") + for config in explainability_config: + explain_config = config.get_explainability_config() + explainability_methods.update(explain_config) + if not len(explainability_methods) == len(explainability_config): + raise ValueError("Duplicate explainability configs are provided") + if ( + "shap" not in explainability_methods + and "features" not in explainability_methods["pdp"] + ): + raise ValueError("PDP features must be provided when ShapConfig is not provided") + return explainability_methods + if ( + isinstance(explainability_config, PDPConfig) + and "features" not in explainability_config.get_explainability_config()["pdp"] + ): + raise ValueError("PDP features must be provided when ShapConfig is not provided") + return explainability_config.get_explainability_config() + def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key): """Uploads the local ``analysis_config_file`` to the ``s3_output_path``. 
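# Hedged usage sketch of the new combined entry point (illustrative only): the role ARN,
# S3 paths, headers, and the existing endpoint name below are placeholders, not patch values.
import sagemaker
from sagemaker.clarify import (
    BiasConfig,
    DataConfig,
    ModelConfig,
    SageMakerClarifyProcessor,
    SHAPConfig,
)

session = sagemaker.Session()

processor = SageMakerClarifyProcessor(
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    instance_count=1,
    instance_type="ml.c5.xlarge",
    sagemaker_session=session,
)

data_config = DataConfig(
    s3_data_input_path="s3://example-bucket/clarify/train.csv",
    s3_output_path="s3://example-bucket/clarify/output",
    label="Label",
    headers=["Label", "F1", "F2", "F3", "F4"],
    dataset_type="text/csv",
)

bias_config = BiasConfig(label_values_or_threshold=[1], facet_name="F1")

# Re-use an already-deployed endpoint instead of creating a shadow model endpoint.
model_config = ModelConfig(endpoint_name="existing-endpoint")

processor.run_bias_and_explainability(
    data_config=data_config,
    model_config=model_config,
    explainability_config=SHAPConfig(),
    bias_config=bias_config,
    pre_training_methods="all",
    post_training_methods="all",
)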
diff --git a/tests/integ/test_clarify.py b/tests/integ/test_clarify.py index a107c00859..eaa75bce64 100644 --- a/tests/integ/test_clarify.py +++ b/tests/integ/test_clarify.py @@ -138,7 +138,9 @@ def data_path_no_label_index(training_set_no_label): def data_path_label_index(training_set_label_index): features, label, index = training_set_label_index data = pd.concat( - [pd.DataFrame(label), pd.DataFrame(features), pd.DataFrame(index)], axis=1, sort=False + [pd.DataFrame(label), pd.DataFrame(features), pd.DataFrame(index)], + axis=1, + sort=False, ) with tempfile.TemporaryDirectory() as tmpdirname: filename = os.path.join(tmpdirname, "train_label_index.csv") @@ -151,7 +153,12 @@ def data_path_label_index(training_set_label_index): def data_path_label_index_6col(training_set_label_index): features, label, index = training_set_label_index data = pd.concat( - [pd.DataFrame(label), pd.DataFrame(features), pd.DataFrame(features), pd.DataFrame(index)], + [ + pd.DataFrame(label), + pd.DataFrame(features), + pd.DataFrame(features), + pd.DataFrame(index), + ], axis=1, sort=False, ) @@ -551,7 +558,10 @@ def test_pre_training_bias(clarify_processor, data_config, data_bias_config, sag def test_pre_training_bias_facets_not_included( - clarify_processor, data_config_facets_not_included, data_bias_config, sagemaker_session + clarify_processor, + data_config_facets_not_included, + data_bias_config, + sagemaker_session, ): with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES): clarify_processor.run_pre_training_bias( @@ -643,7 +653,9 @@ def test_post_training_bias_facets_not_included_excluded_columns( <= 1.0 ) check_analysis_config( - data_config_facets_not_included_multiple_files, sagemaker_session, "post_training_bias" + data_config_facets_not_included_multiple_files, + sagemaker_session, + "post_training_bias", ) @@ -704,6 +716,50 @@ def test_shap(clarify_processor, data_config, model_config, shap_config, sagemak check_analysis_config(data_config, sagemaker_session, "shap") +def test_bias_and_explainability( + clarify_processor, + data_config, + model_config, + shap_config, + data_bias_config, + sagemaker_session, +): + with timeout.timeout(minutes=CLARIFY_DEFAULT_TIMEOUT_MINUTES): + clarify_processor.run_bias_and_explainability( + data_config, + model_config, + shap_config, + data_bias_config, + pre_training_methods="all", + post_training_methods="all", + model_predicted_label_config="score", + job_name=utils.unique_name_from_base("clarify-bias-and-explainability"), + wait=True, + ) + analysis_result_json = s3.S3Downloader.read_file( + data_config.s3_output_path + "/analysis.json", + sagemaker_session, + ) + analysis_result = json.loads(analysis_result_json) + assert ( + math.fabs( + analysis_result["explanations"]["kernel_shap"]["label0"]["global_shap_values"]["F2"] + ) + <= 1 + ) + check_analysis_config(data_config, sagemaker_session, "shap") + + assert ( + math.fabs( + analysis_result["post_training_bias_metrics"]["facets"]["F1"][0]["metrics"][0][ + "value" + ] + ) + <= 1.0 + ) + check_analysis_config(data_config, sagemaker_session, "post_training_bias") + + def check_analysis_config(data_config, sagemaker_session, method): analysis_config_json = s3.S3Downloader.read_file( data_config.s3_output_path + "/analysis_config.json", diff --git a/tests/unit/test_clarify.py b/tests/unit/test_clarify.py index 7375657944..c9400a7be4 100644 --- a/tests/unit/test_clarify.py +++ b/tests/unit/test_clarify.py @@ -232,7 +232,8 @@ def test_invalid_bias_config(): # Two facets but only one value with 
pytest.raises( - ValueError, match="The number of facet names doesn't match the number of facet values" + ValueError, + match="The number of facet names doesn't match the number of facet values", ): BiasConfig( label_values_or_threshold=[1], @@ -295,7 +296,10 @@ def test_invalid_bias_config(): { "facet": [ {"name_or_index": "Feature1", "value_or_threshold": [1]}, - {"name_or_index": 1, "value_or_threshold": ["category1, category2"]}, + { + "name_or_index": 1, + "value_or_threshold": ["category1, category2"], + }, {"name_or_index": "Feature3", "value_or_threshold": [0.5]}, ], }, @@ -735,6 +739,41 @@ def pdp_config(): return PDPConfig(features=["F1", "F2"], grid_resolution=20) +def test_model_config_validations(): + new_model_endpoint_definition = { + "model_name": "xgboost-model", + "instance_type": "ml.c5.xlarge", + "instance_count": 1, + } + existing_endpoint_definition = {"endpoint_name": "existing_endpoint"} + + with pytest.raises(AssertionError): + # should be one of them + ModelConfig( + **new_model_endpoint_definition, + **existing_endpoint_definition, + ) + + with pytest.raises(AssertionError): + # should be one of them + ModelConfig( + endpoint_name_prefix="prefix", + **existing_endpoint_definition, + ) + + # success path for new model + assert ModelConfig(**new_model_endpoint_definition).predictor_config == { + "initial_instance_count": 1, + "instance_type": "ml.c5.xlarge", + "model_name": "xgboost-model", + } + + # success path for existing endpoint + assert ( + ModelConfig(**existing_endpoint_definition).predictor_config == existing_endpoint_definition + ) + + @patch("sagemaker.utils.name_from_base", return_value=JOB_NAME) def test_pre_training_bias( name_from_base, @@ -1094,7 +1133,9 @@ def test_explainability_with_invalid_config( "initial_instance_count": 1, } with pytest.raises( - AttributeError, match="'NoneType' object has no attribute 'get_explainability_config'" + AttributeError, + match="analysis_config must have at least one working method: " + "One of the `pre_training_methods`, `post_training_methods`, `explainability_config`.", ): _run_test_explain( name_from_base, @@ -1320,6 +1361,127 @@ def test_analysis_config_generator_for_explainability(data_config, model_config) assert actual == expected +def test_analysis_config_generator_for_explainability_failing(data_config, model_config): + model_scores = ModelPredictedLabelConfig( + probability="pr", + label_headers=["success"], + ) + with pytest.raises( + ValueError, + match="PDP features must be provided when ShapConfig is not provided", + ): + _AnalysisConfigGenerator.explainability( + data_config, + model_config, + model_scores, + PDPConfig(), + ) + + with pytest.raises(ValueError, match="Duplicate explainability configs are provided"): + _AnalysisConfigGenerator.explainability( + data_config, + model_config, + model_scores, + [SHAPConfig(), SHAPConfig()], + ) + + with pytest.raises( + AttributeError, + match="analysis_config must have at least one working method: " + "One of the " + "`pre_training_methods`, `post_training_methods`, `explainability_config`.", + ): + _AnalysisConfigGenerator.explainability( + data_config, + model_config, + model_scores, + [], + ) + + +def test_analysis_config_generator_for_bias_explainability( + data_config, data_bias_config, model_config +): + model_predicted_label_config = ModelPredictedLabelConfig( + probability="pr", + label_headers=["success"], + ) + actual = _AnalysisConfigGenerator.bias_and_explainability( + data_config, + model_config, + model_predicted_label_config, + 
[SHAPConfig(), PDPConfig()], + data_bias_config, + pre_training_methods="all", + post_training_methods="all", + ) + expected = { + "dataset_type": "text/csv", + "facet": [{"name_or_index": "F1"}], + "group_variable": "F2", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "label_values_or_threshold": [1], + "methods": { + "pdp": {"grid_resolution": 15, "top_k_features": 10}, + "post_training_bias": {"methods": "all"}, + "pre_training_bias": {"methods": "all"}, + "report": {"name": "report", "title": "Analysis Report"}, + "shap": {"save_local_shap_values": True, "use_logit": False}, + }, + "predictor": { + "initial_instance_count": 1, + "instance_type": "ml.c5.xlarge", + "label_headers": ["success"], + "model_name": "xgboost-model", + "probability": "pr", + }, + } + assert actual == expected + + +def test_analysis_config_generator_for_bias_explainability_with_existing_endpoint( + data_config, data_bias_config +): + model_config = ModelConfig(endpoint_name="existing_endpoint_name") + model_predicted_label_config = ModelPredictedLabelConfig( + probability="pr", + label_headers=["success"], + ) + actual = _AnalysisConfigGenerator.bias_and_explainability( + data_config, + model_config, + model_predicted_label_config, + [SHAPConfig(), PDPConfig()], + data_bias_config, + pre_training_methods="all", + post_training_methods="all", + ) + expected = { + "dataset_type": "text/csv", + "facet": [{"name_or_index": "F1"}], + "group_variable": "F2", + "headers": ["Label", "F1", "F2", "F3", "F4"], + "joinsource_name_or_index": "F4", + "label": "Label", + "label_values_or_threshold": [1], + "methods": { + "pdp": {"grid_resolution": 15, "top_k_features": 10}, + "post_training_bias": {"methods": "all"}, + "pre_training_bias": {"methods": "all"}, + "report": {"name": "report", "title": "Analysis Report"}, + "shap": {"save_local_shap_values": True, "use_logit": False}, + }, + "predictor": { + "label_headers": ["success"], + "endpoint_name": "existing_endpoint_name", + "probability": "pr", + }, + } + assert actual == expected + + def test_analysis_config_generator_for_bias_pre_training(data_config, data_bias_config): actual = _AnalysisConfigGenerator.bias_pre_training( data_config, data_bias_config, methods="all" From 1487c662daa87d8756be4fcb81ee8c46bf431884 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 19 Aug 2022 06:29:04 +0000 Subject: [PATCH 182/526] prepare release v2.105.0 --- CHANGELOG.md | 16 ++++++++++++++++ VERSION | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index acdd80d328..39cfa83cc5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,21 @@ # Changelog +## v2.105.0 (2022-08-19) + +### Features + + * Added endpoint_name to clarify.ModelConfig + * adding workgroup functionality to athena query + +### Bug Fixes and Other Changes + + * disable debugger/profiler in cgk region + * using unique name for lineage test to unblock PR checks + +### Documentation Changes + + * update first-party algorithms and structural updates + ## v2.104.0 (2022-08-17) ### Features diff --git a/VERSION b/VERSION index c574defb95..cfb30c732f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.104.1.dev0 +2.105.0 From 416d1b0cfcb8bbf13f6cb43ca7a62b13a9fd1156 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 19 Aug 2022 06:29:05 +0000 Subject: [PATCH 183/526] update development version to v2.105.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index cfb30c732f..7d96ae0efb 
100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.105.0 +2.105.1.dev0 From d3c8d380d755b5bfa4ee8387de0ce774bd1610c9 Mon Sep 17 00:00:00 2001 From: JEET PATEL <111349713+jepatelk@users.noreply.github.com> Date: Tue, 23 Aug 2022 03:21:22 +0530 Subject: [PATCH 184/526] change: Add CGK in config for Spark Image (#3309) --- src/sagemaker/image_uri_config/spark.json | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/image_uri_config/spark.json b/src/sagemaker/image_uri_config/spark.json index 69347fab1a..e0e5ef44ae 100644 --- a/src/sagemaker/image_uri_config/spark.json +++ b/src/sagemaker/image_uri_config/spark.json @@ -27,7 +27,8 @@ "cn-northwest-1": "844356804704", "eu-south-1": "753923664805", "af-south-1": "309385258863", - "us-gov-west-1": "271483468897" + "us-gov-west-1": "271483468897", + "ap-southeast-3": "732049463269" }, "repository": "sagemaker-spark-processing" }, @@ -56,7 +57,8 @@ "cn-northwest-1": "844356804704", "eu-south-1": "753923664805", "af-south-1": "309385258863", - "us-gov-west-1": "271483468897" + "us-gov-west-1": "271483468897", + "ap-southeast-3": "732049463269" }, "repository": "sagemaker-spark-processing" }, @@ -85,7 +87,8 @@ "cn-northwest-1": "844356804704", "eu-south-1": "753923664805", "af-south-1": "309385258863", - "us-gov-west-1": "271483468897" + "us-gov-west-1": "271483468897", + "ap-southeast-3": "732049463269" }, "repository": "sagemaker-spark-processing" } From 18e4cd478e2c0472e87417c6099d75526c4ebd40 Mon Sep 17 00:00:00 2001 From: Rahul Venkatesh <105655261+rahven14@users.noreply.github.com> Date: Tue, 23 Aug 2022 10:08:41 +0530 Subject: [PATCH 185/526] fix: remove specifying env-vars when creating model from model package (#3301) Co-authored-by: Basil Beirouti --- src/sagemaker/model.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 8772fa724f..f81b591809 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1581,9 +1581,6 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar container_def = {"ModelPackageName": model_package_name} - if self.env != {}: - container_def["Environment"] = self.env - self._ensure_base_name_if_needed(model_package_name.split("/")[-1]) self._set_model_name_if_needed() From 38ed502964d3988acd763d692b50b9e29e3a6fa1 Mon Sep 17 00:00:00 2001 From: atqy <95724753+atqy@users.noreply.github.com> Date: Tue, 23 Aug 2022 09:50:29 -0700 Subject: [PATCH 186/526] feat: Implement Kendra Search in RTD website (#3294) --- LICENSE.txt | 18 + doc/_static/kendrasearchtools.js | 692 +++++++++++++++++++++++++++++ doc/_static/pagination.css | 17 + doc/_static/search_accessories.css | 29 ++ doc/_templates/search.html | 56 +++ doc/conf.py | 8 +- licenses/2-CLAUSE-BSD | 28 ++ 7 files changed, 847 insertions(+), 1 deletion(-) create mode 100644 doc/_static/kendrasearchtools.js create mode 100644 doc/_static/pagination.css create mode 100644 doc/_static/search_accessories.css create mode 100644 doc/_templates/search.html create mode 100644 licenses/2-CLAUSE-BSD diff --git a/LICENSE.txt b/LICENSE.txt index a1ce8c3b5e..0633468f44 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -173,3 +173,21 @@ of your accepting any such warranty or additional liability. 
END OF TERMS AND CONDITIONS + + ====================================================================================== + Amazon SageMaker Examples Subcomponents: + + The Amazon SageMaker Examples project contains subcomponents with separate + copyright notices and license terms. Your use of the source code for the + these subcomponents is subject to the terms and conditions of the following + licenses. See licenses/ for text of these licenses. + + If a folder hierarchy is listed as subcomponent, separate listings of + further subcomponents (files or folder hierarchies) part of the hierarchy + take precedence. + + ======================================================================================= + 2-clause BSD license + ======================================================================================= + _static/kendrasearchtools.js + _templates/search.html diff --git a/doc/_static/kendrasearchtools.js b/doc/_static/kendrasearchtools.js new file mode 100644 index 0000000000..f2d47ef889 --- /dev/null +++ b/doc/_static/kendrasearchtools.js @@ -0,0 +1,692 @@ +/* + * kendrasearchtools.js + * ~~~~~~~~~~~~~~~~ + * + * A modification of searchtools.js (https://github.com/sphinx-doc/sphinx/blob/275d9/sphinx/themes/basic/static/searchtools.js) + * where the default full-text search implemented in searchtools.js is replaced with AWS Kendra searching over multiple + * websites. The default full-text search is still kept and implemented as a fallback in the case that the Kendra search doesn't work. + * + * :copyright: Copyright 2007-2021 by the Sphinx team, see AUTHORS. + * :license: BSD, see LICENSE for details. + * + */ + +if (!Scorer) { + /** + * Simple result scoring code. + */ + var Scorer = { + // Implement the following function to further tweak the score for each result + // The function takes a result array [filename, title, anchor, descr, score] + // and returns the new score. + /* + score: function(result) { + return result[4]; + }, + */ + + // query matches the full name of an object + objNameMatch: 11, + // or matches in the last dotted part of the object name + objPartialMatch: 6, + // Additive scores depending on the priority of the object + objPrio: {0: 15, // used to be importantResults + 1: 5, // used to be objectResults + 2: -5}, // used to be unimportantResults + // Used when the priority is not in the mapping. + objPrioDefault: 0, + + // query found in title + title: 15, + partialTitle: 7, + // query found in terms + term: 5, + partialTerm: 2 + }; +} + +if (!splitQuery) { + function splitQuery(query) { + return query.split(/\s+/); + } +} + +/** + * default rtd search (used as fallback) + */ +var Search = { + + _index : null, + _queued_query : null, + _pulse_status : -1, + + htmlToText : function(htmlString) { + var virtualDocument = document.implementation.createHTMLDocument('virtual'); + var htmlElement = $(htmlString, virtualDocument); + htmlElement.find('.headerlink').remove(); + docContent = htmlElement.find('[role=main]')[0]; + if(docContent === undefined) { + console.warn("Content block not found. Sphinx search tries to obtain it " + + "via '[role=main]'. 
Could you check your theme or template."); + return ""; + } + return docContent.textContent || docContent.innerText; + }, + + init : function() { + var params = $.getQueryParameters(); + if (params.q) { + var query = params.q[0]; + $('input[name="q"]')[0].value = query; + // this.performSearch(query); + } + }, + + loadIndex : function(url) { + $.ajax({type: "GET", url: url, data: null, + dataType: "script", cache: true, + complete: function(jqxhr, textstatus) { + if (textstatus != "success") { + document.getElementById("searchindexloader").src = url; + } + }}); + }, + + setIndex : function(index) { + var q; + this._index = index; + if ((q = this._queued_query) !== null) { + this._queued_query = null; + Search.query(q); + } + }, + + hasIndex : function() { + return this._index !== null; + }, + + deferQuery : function(query) { + this._queued_query = query; + }, + + stopPulse : function() { + this._pulse_status = 0; + }, + + startPulse : function() { + if (this._pulse_status >= 0) + return; + function pulse() { + var i; + Search._pulse_status = (Search._pulse_status + 1) % 4; + var dotString = ''; + for (i = 0; i < Search._pulse_status; i++) + dotString += '.'; + Search.dots.text(dotString); + if (Search._pulse_status > -1) + window.setTimeout(pulse, 500); + } + pulse(); + }, + + /** + * perform a search for something (or wait until index is loaded) + */ + performSearch : function(query) { + // create the required interface elements + this.out = $('#search-results'); + this.title = $('#search-results h2:first'); // $('
<h2>' + _('Searching') + '</h2>').appendTo(this.out);
+    this.dots = $('#search-results span:first'); //$('<span></span>').appendTo(this.title);
+    this.status = $('#search-results p:first'); // $('<p class="search-summary">&nbsp;</p>').appendTo(this.out);
+    this.output = $('#search-results ul:first'); //$('