From c97c4678b89ff89bda6f1ea4573803d7e1bcb947 Mon Sep 17 00:00:00 2001 From: Kevin Date: Fri, 2 Dec 2022 12:48:09 -0800 Subject: [PATCH 01/58] fix: type hint of PySparkProcessor __init__ (#3297) From de589419595fbf7bf76e55745f454864cc5998be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Perez?= Date: Fri, 2 Dec 2022 22:01:39 +0100 Subject: [PATCH 02/58] fix: fix PySparkProcessor __init__ params type (#3354) From 41dd3305c2673a4f85e54eec9858f37393c89431 Mon Sep 17 00:00:00 2001 From: Shreya Pandit Date: Fri, 2 Dec 2022 13:18:14 -0800 Subject: [PATCH 03/58] fix: Allow Py 3.7 for MMS Test Docker env (#3080) Co-authored-by: Mufaddal Rohawala --- tests/data/multimodel/container/Dockerfile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/data/multimodel/container/Dockerfile b/tests/data/multimodel/container/Dockerfile index 4792a429c1..71c38a6605 100644 --- a/tests/data/multimodel/container/Dockerfile +++ b/tests/data/multimodel/container/Dockerfile @@ -1,4 +1,5 @@ -FROM public.ecr.aws/ubuntu/ubuntu:18.04 +# added latest image from https://gallery.ecr.aws/lts/ubuntu +FROM public.ecr.aws/ubuntu/ubuntu:22.04 # Set a docker label to advertise multi-model support on the container LABEL com.amazonaws.sagemaker.capabilities.multi-models=true @@ -15,7 +16,7 @@ RUN apt-get update && \ curl \ vim \ && rm -rf /var/lib/apt/lists/* \ - && curl -O https://bootstrap.pypa.io/pip/3.6/get-pip.py \ + && curl -O https://bootstrap.pypa.io/pip/get-pip.py \ && python3 get-pip.py RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1 From 1e23a3f6a7cf554aa537c5c4e21e35548053a6ee Mon Sep 17 00:00:00 2001 From: maldil Date: Fri, 2 Dec 2022 13:19:59 -0800 Subject: [PATCH 04/58] refactoring : using with statement (#3286) --- src/sagemaker/git_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/git_utils.py b/src/sagemaker/git_utils.py index 80bd62d5be..c424753286 100644 --- a/src/sagemaker/git_utils.py +++ b/src/sagemaker/git_utils.py @@ -279,9 +279,8 @@ def _run_clone_command(repo_url, dest_dir): subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) elif repo_url.startswith("git@"): with tempfile.NamedTemporaryFile() as sshnoprompt: - write_pipe = open(sshnoprompt.name, "w") - write_pipe.write("ssh -oBatchMode=yes $@") - write_pipe.close() + with open(sshnoprompt.name, "w") as write_pipe: + write_pipe.write("ssh -oBatchMode=yes $@") os.chmod(sshnoprompt.name, 0o511) my_env["GIT_SSH"] = sshnoprompt.name subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) From 19efadf043678a6c7da4122368d6141e1ec2df10 Mon Sep 17 00:00:00 2001 From: Shreya Pandit Date: Fri, 2 Dec 2022 13:21:34 -0800 Subject: [PATCH 05/58] Update local_requirements.txt PyYAML version (#3095) Co-authored-by: Basil Beirouti Co-authored-by: Kalyani Nikure <110067132+knikure@users.noreply.github.com> --- requirements/extras/local_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/extras/local_requirements.txt b/requirements/extras/local_requirements.txt index 5304d82b2a..5f2c85c2fe 100644 --- a/requirements/extras/local_requirements.txt +++ b/requirements/extras/local_requirements.txt @@ -1,4 +1,4 @@ urllib3==1.26.8 docker-compose==1.29.2 docker>=5.0.2,<7.0.0 -PyYAML==5.4.1 +PyYAML==6.0.0 From 76f7782db112b38cb7e058dffb1508f2d34fb50b Mon Sep 17 00:00:00 2001 From: arjkesh <33526713+arjkesh@users.noreply.github.com> Date: Fri, 2 Dec 2022 13:22:35 -0800 Subject: [PATCH 06/58] feature: 
Update TF 2.9 and TF 2.10 inference DLCs (#3465) --- .../image_uri_config/tensorflow.json | 66 ++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 6a01c3e3e6..0122dcd3ca 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -285,7 +285,9 @@ "2.5": "2.5.1", "2.6": "2.6.3", "2.7": "2.7.0", - "2.8": "2.8.0" + "2.8": "2.8.0", + "2.9": "2.9.2", + "2.10": "2.10.0" }, "versions": { "1.10.0": { @@ -1468,6 +1470,68 @@ "us-west-2": "763104351884" }, "repository": "tensorflow-inference" + }, + "2.9.2": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, + "2.10.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" } } }, From fde07388dc26cb270a0a0dfba91439c64e87751a Mon Sep 17 00:00:00 2001 From: Keshav Chandak Date: Sat, 3 Dec 2022 03:41:10 +0530 Subject: [PATCH 07/58] feature: Added transform with monitoring pipeline step in transformer (#3438) Co-authored-by: Keshav Chandak --- src/sagemaker/transformer.py | 158 +++++++++++++++++++++++++++++++- tests/integ/test_transformer.py | 66 ++++++++++++- 2 files changed, 220 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index cfcc637b99..97278abdd0 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -14,14 +14,17 @@ from __future__ import absolute_import from typing import Union, Optional, List, Dict -from botocore import exceptions +import logging +import copy +import time +from botocore import exceptions from sagemaker.job import _Job -from sagemaker.session import Session +from sagemaker.session import Session, get_execution_role from sagemaker.inputs import 
BatchDataCaptureConfig
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.functions import Join
-from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.execution_variables import ExecutionVariables
 from sagemaker.utils import base_name_from_image, name_from_base
@@ -266,6 +269,155 @@ def transform(
         if wait:
             self.latest_transform_job.wait(logs=logs)

+    def transform_with_monitoring(
+        self,
+        monitoring_config,
+        monitoring_resource_config,
+        data: str,
+        data_type: str = "S3Prefix",
+        content_type: str = None,
+        compression_type: str = None,
+        split_type: str = None,
+        input_filter: str = None,
+        output_filter: str = None,
+        join_source: str = None,
+        model_client_config: Dict[str, str] = None,
+        batch_data_capture_config: BatchDataCaptureConfig = None,
+        monitor_before_transform: bool = False,
+        supplied_baseline_statistics: str = None,
+        supplied_baseline_constraints: str = None,
+        wait: bool = True,
+        pipeline_name: str = None,
+        role: str = None,
+    ):
+        """Runs a transform job with an attached monitoring job.
+
+        Note that this function will not start a transform job immediately;
+        instead, it will create a SageMaker Pipeline and execute it.
+        If you provide an existing pipeline_name, no new pipeline will be created; otherwise,
+        each transform_with_monitoring call will create a new pipeline and execute it.
+
+        Args:
+            monitoring_config (Union[
+                `sagemaker.workflow.quality_check_step.QualityCheckConfig`,
+                `sagemaker.workflow.quality_check_step.ClarifyCheckConfig`
+            ]): the monitoring configuration used to run model monitoring.
+            monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`):
+                the check job (processing job) cluster resource configuration.
+            data (str): Input data location in S3 for the transform job.
+            data_type (str): What the S3 location defines (default: 'S3Prefix').
+                Valid values:
+                * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
+                will be used as inputs for the transform job.
+                * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
+                object to use as an input for the transform job.
+            content_type (str): MIME type of the input data (default: None).
+            compression_type (str): Compression type of the input data, if
+                compressed (default: None). Valid values: 'Gzip', None.
+            split_type (str): The record delimiter for the input object
+                (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and
+                'TFRecord'.
+            input_filter (str): A JSONPath to select a portion of the input to
+                pass to the algorithm container for inference. If you omit the
+                field, it gets the value '$', representing the entire input.
+                For CSV data, each row is taken as a JSON array,
+                so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
+                CSV data should follow the `RFC format `_.
+                See `Supported JSONPath Operators
+                `_
+                for a table of supported JSONPath operators.
+                For more information, see the SageMaker API documentation for
+                `CreateTransformJob
+                `_.
+                Some examples: "$[1:]", "$.features" (default: None).
+            output_filter (str): A JSONPath to select a portion of the
+                joined/original output to return as the output.
+                For more information, see the SageMaker API documentation for
+                `CreateTransformJob
+                `_.
+                Some examples: "$[1:]", "$.prediction" (default: None).
+            join_source (str): The source of data to be joined to the transform
+                output. It can be set to 'Input', meaning the entire input record
+                will be joined to the inference result. You can use OutputFilter
+                to select the useful portion before uploading to S3 (default:
+                None). Valid values: Input, None.
+            model_client_config (dict[str, str]): Model configuration.
+                Dictionary contains two optional keys,
+                'InvocationsTimeoutInSeconds' and 'InvocationsMaxRetries'
+                (default: ``None``).
+            batch_data_capture_config (BatchDataCaptureConfig): Configuration object which
+                specifies the configurations related to batch data capture for the transform job
+                (default: ``None``).
+            monitor_before_transform (bool): Whether to run the check step before,
+                rather than after, the transform job. Applies to the data quality
+                and model explainability monitoring types (default: False).
+            supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path
+                to the supplied statistics object representing the statistics JSON file
+                which will be used for the drift check (default: None).
+            supplied_baseline_constraints (Union[str, PipelineVariable]): The S3 path
+                to the supplied constraints object representing the constraints JSON file
+                which will be used for the drift check (default: None).
+            wait (bool): Whether to wait for the pipeline execution to complete (default: True).
+            pipeline_name (str): The name of the Pipeline for the monitoring and transform steps.
+            role (str): Execution role.
+        """
+
+        transformer = self
+        if not isinstance(self.sagemaker_session, PipelineSession):
+            sagemaker_session = self.sagemaker_session
+            self.sagemaker_session = None
+            transformer = copy.deepcopy(self)
+            transformer.sagemaker_session = PipelineSession()
+            self.sagemaker_session = sagemaker_session
+
+        transform_step_args = transformer.transform(
+            data=data,
+            data_type=data_type,
+            content_type=content_type,
+            compression_type=compression_type,
+            split_type=split_type,
+            input_filter=input_filter,
+            output_filter=output_filter,
+            batch_data_capture_config=batch_data_capture_config,
+            join_source=join_source,
+            model_client_config=model_client_config,
+        )
+
+        from sagemaker.workflow.monitor_batch_transform_step import MonitorBatchTransformStep
+
+        monitoring_batch_step = MonitorBatchTransformStep(
+            name="MonitorBatchTransformStep",
+            display_name="MonitorBatchTransformStep",
+            description="",
+            transform_step_args=transform_step_args,
+            monitor_configuration=monitoring_config,
+            check_job_configuration=monitoring_resource_config,
+            monitor_before_transform=monitor_before_transform,
+            supplied_baseline_constraints=supplied_baseline_constraints,
+            supplied_baseline_statistics=supplied_baseline_statistics,
+        )
+
+        pipeline_name = (
+            pipeline_name if pipeline_name else f"TransformWithMonitoring{int(time.time())}"
+        )
+        # upsert creates the pipeline if needed (or updates it), then we start an execution
+        from sagemaker.workflow.pipeline import Pipeline
+
+        pipeline = Pipeline(
+            name=pipeline_name,
+            steps=[monitoring_batch_step],
+            sagemaker_session=transformer.sagemaker_session,
+        )
+        pipeline.upsert(role_arn=role if role else get_execution_role())
+        execution = pipeline.start()
+        if wait:
+            logging.info("Waiting for transform with monitoring to execute ...")
+            execution.wait()
+        return execution
+
     def delete_model(self):
         """Delete the corresponding SageMaker model for this
Transformer.""" self.sagemaker_session.delete_model(self.model_name) diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index a0e37ffc77..1de333b987 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -25,6 +25,7 @@ from sagemaker.transformer import Transformer from sagemaker.estimator import Estimator from sagemaker.inputs import BatchDataCaptureConfig +from sagemaker.xgboost import XGBoostModel from sagemaker.utils import unique_name_from_base from tests.integ import ( datasets, @@ -36,7 +37,7 @@ from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer from tests.integ.vpc_test_utils import get_or_create_vpc_resources -from sagemaker.model_monitor import DatasetFormat, Statistics +from sagemaker.model_monitor import DatasetFormat, Statistics, Constraints from sagemaker.workflow.check_job_config import CheckJobConfig from sagemaker.workflow.quality_check_step import ( @@ -645,3 +646,66 @@ def _create_transformer_and_transform_job( job_name=unique_name_from_base("test-transform"), ) return transformer + + +def test_transformer_and_monitoring_job( + pipeline_session, + sagemaker_session, + role, + pipeline_name, + check_job_config, + data_bias_check_config, +): + xgb_model_data_s3 = pipeline_session.upload_data( + path=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + data_bias_supplied_baseline_constraints = Constraints.from_file_path( + constraints_file_path=os.path.join( + DATA_DIR, "pipeline/clarify_check_step/data_bias/good_cases/analysis.json" + ), + sagemaker_session=sagemaker_session, + ).file_s3_uri + + xgb_model = XGBoostModel( + model_data=xgb_model_data_s3, + framework_version="1.3-1", + role=role, + sagemaker_session=sagemaker_session, + entry_point=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "inference.py"), + enable_network_isolation=True, + ) + + xgb_model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE) + + transform_output = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform" + transformer = Transformer( + model_name=xgb_model.name, + strategy="SingleRecord", + instance_type="ml.m5.xlarge", + instance_count=1, + output_path=transform_output, + sagemaker_session=pipeline_session, + ) + + transform_input = pipeline_session.upload_data( + path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"), + key_prefix="integ-test-data/xgboost_abalone/abalone", + ) + + execution = transformer.transform_with_monitoring( + monitoring_config=data_bias_check_config, + monitoring_resource_config=check_job_config, + data=transform_input, + content_type="text/libsvm", + supplied_baseline_constraints=data_bias_supplied_baseline_constraints, + role=role, + ) + + execution_steps = execution.list_steps() + assert len(execution_steps) == 2 + + for execution_step in execution_steps: + assert execution_step["StepStatus"] == "Succeeded" + + xgb_model.delete_model() From 7f9f3b04b6704a4d2378b5d9aa3d37de9db45729 Mon Sep 17 00:00:00 2001 From: Clayton Parnell <42805768+claytonparnell@users.noreply.github.com> Date: Fri, 2 Dec 2022 17:12:34 -0500 Subject: [PATCH 08/58] fix: Fix bug forcing uploaded tar to be named sourcedir (#3412) --- src/sagemaker/processing.py | 19 +++++++++++-------- tests/integ/test_xgboost.py | 20 ++++++++++++++++++++ 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index db6ce2badd..308783578d 100644 --- 
a/src/sagemaker/processing.py
+++ b/src/sagemaker/processing.py
@@ -1587,13 +1587,13 @@ def run(  # type: ignore[override]
                 framework script to run. Path (absolute or relative) to the local
                 Python source file which should be executed as the entry point to
                 training. When `code` is an S3 URI, ignore `source_dir`,
-                `dependencies, and `git_config`. If ``source_dir`` is specified,
+                `dependencies`, and `git_config`. If ``source_dir`` is specified,
                 then ``code`` must point to a file located at the root of ``source_dir``.
             source_dir (str): Path (absolute, relative or an S3 URI) to a directory
                 with any other processing source code dependencies aside from the entry
                 point file (default: None). If ``source_dir`` is an S3 URI, it must
-                point to a tar.gz file. Structure within this directory are preserved
-                when processing on Amazon SageMaker (default: None).
+                point to a file named `sourcedir.tar.gz`. Structure within this directory
+                is preserved when processing on Amazon SageMaker (default: None).
             dependencies (list[str]): A list of paths to directories (absolute
                 or relative) with any additional libraries that will be exported
                 to the container (default: []). The library folders will be
@@ -1730,12 +1730,15 @@ def _pack_and_upload_code(
                 "sagemaker_session unspecified when creating your Processor to have one set up "
                 "automatically."
             )
+        if "/sourcedir.tar.gz" in estimator.uploaded_code.s3_prefix:
+            # Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.
+            entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace(
+                "sourcedir.tar.gz",
+                "runproc.sh",
+            )
+        else:
+            raise RuntimeError("S3 source_dir file must be named `sourcedir.tar.gz`.")
-
-        # Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.
-        entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace(
-            "sourcedir.tar.gz",
-            "runproc.sh",
-        )
         script = estimator.uploaded_code.script_name
         s3_runproc_sh = S3Uploader.upload_string_as_file_body(
             self._generate_framework_script(script),
diff --git a/tests/integ/test_xgboost.py
index 733ab4665a..df06a8863a 100644
--- a/tests/integ/test_xgboost.py
+++ b/tests/integ/test_xgboost.py
@@ -40,6 +40,26 @@ def xgboost_training_job(
     )


+def test_sourcedir_naming(
+    sagemaker_session,
+    xgboost_latest_version,
+    xgboost_latest_py_version,
+    cpu_instance_type,
+):
+    with pytest.raises(RuntimeError):
+        processor = XGBoostProcessor(
+            framework_version=xgboost_latest_version,
+            role=ROLE,
+            instance_count=1,
+            instance_type=cpu_instance_type,
+            sagemaker_session=sagemaker_session,
+        )
+        processor.run(
+            source_dir="s3://bucket/deps.tar.gz",
+            code="main_script.py",
+        )
+
+
 @pytest.mark.release
 def test_framework_processing_job_with_deps(
     sagemaker_session,

From 5d5976726cb8e0cf7143d86b4abb4b665842fd14 Mon Sep 17 00:00:00 2001
From: Navin Soni
Date: Fri, 2 Dec 2022 14:32:01 -0800
Subject: [PATCH 09/58] feature: Add Code Owners file (#3503)

Co-authored-by: Navin Soni
---
 CODEOWNERS                                 | 1 +
 requirements/extras/local_requirements.txt | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)
 create mode 100644 CODEOWNERS

diff --git a/CODEOWNERS b/CODEOWNERS
new file mode 100644
index 0000000000..7f7ac28644
--- /dev/null
+++ b/CODEOWNERS
@@ -0,0 +1 @@
+* @aws/sagemaker-ml-frameworks
diff --git a/requirements/extras/local_requirements.txt b/requirements/extras/local_requirements.txt
index 5f2c85c2fe..5304d82b2a 100644
--- a/requirements/extras/local_requirements.txt
+++ b/requirements/extras/local_requirements.txt
@@ -1,4 +1,4 @@
 urllib3==1.26.8
docker-compose==1.29.2 docker>=5.0.2,<7.0.0 -PyYAML==6.0.0 +PyYAML==5.4.1 From 0f5cf1824c0b116c9b218c803f3b94a85e09fd45 Mon Sep 17 00:00:00 2001 From: ci Date: Sat, 3 Dec 2022 03:22:39 +0000 Subject: [PATCH 10/58] prepare release v2.119.0 --- CHANGELOG.md | 28 ++++++++++++++++++++++++++++ VERSION | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 95e4a7b9cf..b8b3155231 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,33 @@ # Changelog +## v2.119.0 (2022-12-03) + +### Features + + * Add Code Owners file + * Added transform with monitoring pipeline step in transformer + * Update TF 2.9 and TF 2.10 inference DLCs + * make estimator accept json file as modelparallel config + * SageMaker Training Compiler does not support p4de instances + * Add support for SparkML v3.3 + +### Bug Fixes and Other Changes + + * Fix bug forcing uploaded tar to be named sourcedir + * Update local_requirements.txt PyYAML version + * refactoring : using with statement + * Allow Py 3.7 for MMS Test Docker env + * fix PySparkProcessor __init__ params type + * type hint of PySparkProcessor __init__ + * Return ARM XGB/SKLearn tags if `image_scope` is `inference_graviton` + * Update scipy to 1.7.3 to support M1 development envs + * Fixing type hints for Spark processor that has instance type/count params in reverse order + * Add DeepAR ap-northeast-3 repository. + * Fix AsyncInferenceConfig documentation typo + * fix ml_inf to ml_inf1 in Neo multi-version support + * Fix type annotations + * add neo mvp region accounts + ## v2.118.0 (2022-12-01) ### Features diff --git a/VERSION b/VERSION index 34d47b7f52..23fe2bf317 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.118.1.dev0 +2.119.0 From f1f0013dc0375aa22805b3a59b82cd2b1a08d40a Mon Sep 17 00:00:00 2001 From: ci Date: Sat, 3 Dec 2022 03:22:41 +0000 Subject: [PATCH 11/58] update development version to v2.119.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 23fe2bf317..dda4128cf2 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.119.0 +2.119.1.dev0 From bb4b6897971a4e5ae0cbde948ef1682a64232b41 Mon Sep 17 00:00:00 2001 From: Radhika Bhat <78102284+RadhikaB-97@users.noreply.github.com> Date: Mon, 5 Dec 2022 10:06:58 -0800 Subject: [PATCH 12/58] feature: Add DXB region to frameworks by DLC (#3387) * Add DXB region * Remove change from neuron * Adding DXB to TF 2.1.0 and 2.1.1 --- src/sagemaker/image_uri_config/autogluon.json | 12 ++++ .../huggingface-training-compiler.json | 3 + .../image_uri_config/huggingface.json | 31 +++++++++ src/sagemaker/image_uri_config/mxnet.json | 13 ++++ src/sagemaker/image_uri_config/pytorch.json | 28 ++++++++ .../image_uri_config/tensorflow.json | 65 +++++++++++++++++++ 6 files changed, 152 insertions(+) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 3cc488c55d..0963520e02 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -26,6 +26,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -56,6 +57,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -86,6 +88,7 @@ "eu-west-3": 
"763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -116,6 +119,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -146,6 +150,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -176,6 +181,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -217,6 +223,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -250,6 +257,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -283,6 +291,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -316,6 +325,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -349,6 +359,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -382,6 +393,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-training-compiler.json b/src/sagemaker/image_uri_config/huggingface-training-compiler.json index e771e2a548..482264b773 100644 --- a/src/sagemaker/image_uri_config/huggingface-training-compiler.json +++ b/src/sagemaker/image_uri_config/huggingface-training-compiler.json @@ -60,6 +60,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -89,6 +90,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -123,6 +125,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 317c17030a..e995c6e8ea 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -38,6 +38,7 @@ 
"eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -70,6 +71,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -108,6 +110,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -140,6 +143,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -180,6 +184,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -213,6 +218,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -246,6 +252,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -279,6 +286,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -320,6 +328,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -353,6 +362,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -386,6 +396,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -419,6 +430,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -458,6 +470,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -491,6 +504,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -530,6 +544,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -563,6 +578,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": 
"914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -602,6 +618,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -635,6 +652,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -687,6 +705,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -720,6 +739,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -753,6 +773,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -794,6 +815,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -827,6 +849,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -860,6 +883,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -893,6 +917,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -932,6 +957,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -965,6 +991,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1004,6 +1031,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1037,6 +1065,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1076,6 +1105,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1109,6 +1139,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", 
diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 12bc40fccf..14bb74f6a6 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -245,6 +245,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -277,6 +278,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -309,6 +311,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -341,6 +344,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -373,6 +377,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -632,6 +637,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -664,6 +670,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -696,6 +703,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -728,6 +736,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -760,6 +769,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -865,6 +875,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -897,6 +908,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -929,6 +941,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 3bf8016ba8..e1de6ca663 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -195,6 +195,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": 
"217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -230,6 +231,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -264,6 +266,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -298,6 +301,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -333,6 +337,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -368,6 +373,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -403,6 +409,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -438,6 +445,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -472,6 +480,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -506,6 +515,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -540,6 +550,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -574,6 +585,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -608,6 +620,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -642,6 +655,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -879,6 +893,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -914,6 +929,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": 
"763104351884", "us-east-2": "763104351884", @@ -949,6 +965,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -983,6 +1000,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1018,6 +1036,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1053,6 +1072,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1088,6 +1108,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1123,6 +1144,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1157,6 +1179,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1191,6 +1214,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1225,6 +1249,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1259,6 +1284,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1293,6 +1319,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1327,6 +1354,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 0122dcd3ca..bb05682f67 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -154,6 +154,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -185,6 +186,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -216,6 +218,7 @@ 
"eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -247,6 +250,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -401,6 +405,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -432,6 +437,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -463,6 +469,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -494,6 +501,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -525,6 +533,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -556,6 +565,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -587,6 +597,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -810,6 +821,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -841,6 +853,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -872,6 +885,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -903,6 +917,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -934,6 +949,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -965,6 +981,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -996,6 +1013,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": 
"914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1027,6 +1045,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1058,6 +1077,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1089,6 +1109,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1120,6 +1141,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1151,6 +1173,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1182,6 +1205,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1213,6 +1237,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1244,6 +1269,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1275,6 +1301,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1306,6 +1333,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1337,6 +1365,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1368,6 +1397,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1399,6 +1429,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1430,6 +1461,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1461,6 +1493,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": 
"763104351884", @@ -1760,6 +1793,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1796,6 +1830,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1831,6 +1866,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1867,6 +1903,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1903,6 +1940,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1939,6 +1977,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1975,6 +2014,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2202,6 +2242,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2237,6 +2278,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2272,6 +2314,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2306,6 +2349,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2340,6 +2384,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2375,6 +2420,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2410,6 +2456,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2444,6 +2491,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2478,6 +2526,7 @@ "eu-west-2": "763104351884", "eu-west-3": 
"763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2512,6 +2561,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2546,6 +2596,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2580,6 +2631,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2614,6 +2666,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2648,6 +2701,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2682,6 +2736,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2716,6 +2771,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2750,6 +2806,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2784,6 +2841,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2818,6 +2876,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2852,6 +2911,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2886,6 +2946,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2920,6 +2981,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2954,6 +3016,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2988,6 +3051,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", 
"sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3022,6 +3086,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", From b68bcd9344deba8e3bedf7ccb0adb31498735b13 Mon Sep 17 00:00:00 2001 From: Brock Wade Date: Mon, 5 Dec 2022 14:11:34 -0800 Subject: [PATCH 13/58] fix: support idempotency for framework and spark processors (#3460) Co-authored-by: Brock Wade Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> --- src/sagemaker/processing.py | 8 +- src/sagemaker/spark/processing.py | 37 +- src/sagemaker/workflow/utilities.py | 7 +- tests/data/spark/code/java/TestJarFile.jar | Bin 0 -> 1714 bytes .../hello-java-spark/HelloJavaSparkApp.jar | Bin 0 -> 1714 bytes .../unit/sagemaker/workflow/test_pipeline.py | 8 +- .../workflow/test_processing_step.py | 277 +++++++++++++- .../sagemaker/workflow/test_training_step.py | 354 +++++++++++++++--- .../sagemaker/workflow/test_transform_step.py | 8 + .../sagemaker/workflow/test_tuning_step.py | 58 +-- 10 files changed, 661 insertions(+), 96 deletions(-) create mode 100644 tests/data/spark/code/java/TestJarFile.jar create mode 100644 tests/data/spark/code/java/hello-java-spark/HelloJavaSparkApp.jar diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 308783578d..81e3d34b1d 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -23,6 +23,7 @@ import logging from textwrap import dedent from typing import Dict, List, Optional, Union +from copy import copy import attr @@ -1830,14 +1831,17 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput # a7399455f5386d83ddc5cb15c0db00c04bd518ec/src/sagemaker/processing.py#L425-L426 if inputs is None: inputs = [] - inputs.append( + + # make a shallow copy of user inputs + patched_inputs = copy(inputs) + patched_inputs.append( ProcessingInput( input_name="code", source=s3_payload, destination="/opt/ml/processing/input/code/", ) ) - return inputs + return patched_inputs def _set_entrypoint(self, command, user_script_name): """Framework processor override for setting processing job entrypoint. 
diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index dc3d26a355..912bc90d80 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -30,6 +30,7 @@ from enum import Enum from io import BytesIO from urllib.parse import urlparse +from copy import copy from typing import Union, List, Dict, Optional @@ -279,6 +280,10 @@ def run( def _extend_processing_args(self, inputs, outputs, **kwargs): """Extends processing job args such as inputs.""" + # make a shallow copy of user outputs + outputs = outputs or [] + extended_outputs = copy(outputs) + if kwargs.get("spark_event_logs_s3_uri"): spark_event_logs_s3_uri = kwargs.get("spark_event_logs_s3_uri") self._validate_s3_uri(spark_event_logs_s3_uri) @@ -297,16 +302,21 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): s3_upload_mode="Continuous", ) - outputs = outputs or [] - outputs.append(output) + extended_outputs.append(output) + + # make a shallow copy of user inputs + inputs = inputs or [] + extended_inputs = copy(inputs) if kwargs.get("configuration"): configuration = kwargs.get("configuration") self._validate_configuration(configuration) - inputs = inputs or [] - inputs.append(self._stage_configuration(configuration)) + extended_inputs.append(self._stage_configuration(configuration)) - return inputs, outputs + return ( + extended_inputs if extended_inputs else None, + extended_outputs if extended_outputs else None, + ) def start_history_server(self, spark_event_logs_s3_uri=None): """Starts a Spark history server. @@ -940,9 +950,16 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): outputs: Processing outputs. kwargs: Additional keyword arguments passed to `super()`. """ + + if inputs is None: + inputs = [] + + # make a shallow copy of user inputs + extended_inputs = copy(inputs) + self.command = [_SparkProcessorBase._default_command] extended_inputs = self._handle_script_dependencies( - inputs, kwargs.get("submit_py_files"), FileType.PYTHON + extended_inputs, kwargs.get("submit_py_files"), FileType.PYTHON ) extended_inputs = self._handle_script_dependencies( extended_inputs, kwargs.get("submit_jars"), FileType.JAR @@ -1199,8 +1216,14 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): else: raise ValueError("submit_class is required") + if inputs is None: + inputs = [] + + # make a shallow copy of user inputs + extended_inputs = copy(inputs) + extended_inputs = self._handle_script_dependencies( - inputs, kwargs.get("submit_jars"), FileType.JAR + extended_inputs, kwargs.get("submit_jars"), FileType.JAR ) extended_inputs = self._handle_script_dependencies( extended_inputs, kwargs.get("submit_files"), FileType.FILE diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index 89d7c5dfd9..08c170d424 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -114,11 +114,12 @@ def get_code_hash(step: Entity) -> str: if isinstance(step, ProcessingStep) and step.step_args: kwargs = step.step_args.func_kwargs source_dir = kwargs.get("source_dir") + submit_class = kwargs.get("submit_class") dependencies = get_processing_dependencies( [ kwargs.get("dependencies"), kwargs.get("submit_py_files"), - kwargs.get("submit_class"), + [submit_class] if submit_class else None, kwargs.get("submit_jars"), kwargs.get("submit_files"), ] @@ -168,7 +169,7 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str] str: A hash string representing the unique code artifact(s) 
for the step """ - # FrameworkProcessor + # If FrameworkProcessor contains source_dir if source_dir: source_dir_url = urlparse(source_dir) if source_dir_url.scheme == "" or source_dir_url.scheme == "file": @@ -400,5 +401,5 @@ def execute_job_functions(step_args: _StepArguments): """ chained_args = step_args.func(*step_args.func_args, **step_args.func_kwargs) - if chained_args: + if isinstance(chained_args, _StepArguments): execute_job_functions(chained_args) diff --git a/tests/data/spark/code/java/TestJarFile.jar b/tests/data/spark/code/java/TestJarFile.jar new file mode 100644 index 0000000000000000000000000000000000000000..d528331d557da00908e31c46b2a0dd3dc250a2bf GIT binary patch literal 1714 zcmWIWW@Zs#;Nak32&_&EWk3R)3@i-3t|5-Po_=on|4uP5Ff#;rvvYt{FhP|C;M6Pv zQ~}rQ>*(j{<{BKL=j-;__snS@Z(Y5MyxzK6=gyqp9At3C_`%a6JuhD!Pv48Bt5`TA zUPvC1mX_4auz02x_VoEn)#uN(DxRsn&iqvLv4|1u2Dgc~Y@C2LfH24nTnr3AcNxV* zp?E+PD4UU*lasHTl~|UjTU?M>l&znfpR12si##qZiMfeY`FV-u#dtJp64qRtn4X%O zn4MaL#~6K5jDdIxw}(tfH>@PJxCHDxNU}f=RWCA4^Z><#7ce4%LGj>NP@o5nmEPT4 zha3c4fB)?=3}v}1?~$s$O;hJc(w#MhC*L&B^>k7F|4!|RkwJmw^vM@^ZbTbg%>MGD zm+}6Zogs4UvrW~{G>e_Sq%rjvyCrkm1j%#LG~x;ll{xxp9`tuS$+mI96m!?>jqc*F zmUDYt@G4ul|1M+k`QN#lmi*livHr1!m8grxm8+{7vd>(Rb%^U*cxRE#x^uBx^RN9~ zllGMzl-M^*=l9J9diW#|BN97$kjP>SlHA0+%rsz7>XlTK08{;$%cbW$b@a9cd7L|c z)%%R^np5Y!b@Z=k`}z2v^*!UKdr4c*L+8{rZBJm{`0R1^1h54%;aZXMFvtWh2HbfL zVZuQm6GsljZ3HL}BET0Q6RQ!(ITE*Fpgf5HhKvLaL(ZYNjRoaV1gIdzSXhq5Z8#{; zBEV774Tt7nL=pidSmdM(%EJgC4oo=&f*27h5a)w!z@DR#(-+8I*(j{<{BKL=j-;__snS@Z(Y5MyxzK6=gyqp9At3C_`%a6JuhD!Pv48Bt5`TA zUPvC1mX_4auz02x_VoEn)#uN(DxRsn&iqvLv4|1u2Dgc~Y@C2LfH24nTnr3AcNxV* zp?E+PD4UU*lasHTl~|UjTU?M>l&znfpR12si##qZiMfeY`FV-u#dtJp64qRtn4X%O zn4MaL#~6K5jDdIxw}(tfH>@PJxCHDxNU}f=RWCA4^Z><#7ce4%LGj>NP@o5nmEPT4 zha3c4fB)?=3}v}1?~$s$O;hJc(w#MhC*L&B^>k7F|4!|RkwJmw^vM@^ZbTbg%>MGD zm+}6Zogs4UvrW~{G>e_Sq%rjvyCrkm1j%#LG~x;ll{xxp9`tuS$+mI96m!?>jqc*F zmUDYt@G4ul|1M+k`QN#lmi*livHr1!m8grxm8+{7vd>(Rb%^U*cxRE#x^uBx^RN9~ zllGMzl-M^*=l9J9diW#|BN97$kjP>SlHA0+%rsz7>XlTK08{;$%cbW$b@a9cd7L|c z)%%R^np5Y!b@Z=k`}z2v^*!UKdr4c*L+8{rZBJm{`0R1^1h54%;aZXMFvtWh2HbfL zVZuQm6GsljZ3HL}BET0Q6RQ!(ITE*Fpgf5HhKvLaL(ZYNjRoaV1gIdzSXhq5Z8#{; zBEV774Tt7nL=pidSmdM(%EJgC4oo=&f*27h5a)w!z@DR#(-+8I Date: Mon, 5 Dec 2022 18:18:10 -0600 Subject: [PATCH 14/58] feature: Update registries with new region account number mappings. 
(#3492) --- src/sagemaker/image_uri_config/autogluon.json | 18 ++++ .../image_uri_config/huggingface-neuron.json | 3 + .../image_uri_config/huggingface.json | 39 +++++++ src/sagemaker/image_uri_config/mxnet.json | 24 +++++ src/sagemaker/image_uri_config/pytorch.json | 54 ++++++++++ .../image_uri_config/tensorflow.json | 102 ++++++++++++++++++ 6 files changed, 240 insertions(+) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 0963520e02..3a9f02142c 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -210,6 +210,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -217,11 +218,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -244,6 +247,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -251,11 +255,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -278,6 +284,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -285,11 +292,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -312,6 +321,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -319,11 +329,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -346,6 +358,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", 
"ap-southeast-3": "907027046896", @@ -353,11 +366,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -380,6 +395,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -387,11 +403,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-neuron.json b/src/sagemaker/image_uri_config/huggingface-neuron.json index 1e2246cb11..47d6dbd1dc 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuron.json +++ b/src/sagemaker/image_uri_config/huggingface-neuron.json @@ -15,17 +15,20 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index e995c6e8ea..5b98fc0d02 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -692,6 +692,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -699,11 +700,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -726,6 +729,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -733,11 +737,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + 
"eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -760,6 +766,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -767,8 +774,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -802,6 +811,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -809,11 +819,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -836,6 +848,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -843,11 +856,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -870,6 +885,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -877,8 +893,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -904,6 +922,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -911,8 +930,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -944,6 +965,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -951,11 +973,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", 
"eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -978,6 +1002,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -985,8 +1010,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1018,6 +1045,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1025,11 +1053,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -1052,6 +1082,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1059,8 +1090,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1092,6 +1125,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1099,11 +1133,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -1126,6 +1162,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1133,8 +1170,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 
14bb74f6a6..8d8733e480 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -624,6 +624,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -631,11 +632,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -657,6 +660,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -664,11 +668,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -690,6 +696,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -697,11 +704,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -723,6 +732,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -730,11 +740,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -756,6 +768,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -763,11 +776,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -862,6 +877,7 @@ "ap-northeast-2": "763104351884", 
"ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -869,11 +885,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -895,6 +913,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -902,11 +921,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -928,6 +949,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -935,11 +957,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index e1de6ca663..18a382e591 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -17,6 +17,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -25,7 +26,9 @@ "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-north-1": "763104351884", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", + "eu-south-2": "503227376785", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-west-2": "763104351884" @@ -39,8 +42,11 @@ "registries": { "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-3": "907027046896", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", + "eu-south-2": "503227376785", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-west-2": "763104351884" @@ -182,6 +188,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -189,11 +196,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": 
"763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -218,6 +227,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -225,11 +235,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -253,6 +265,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -260,11 +273,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -288,6 +303,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -295,11 +311,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -324,6 +342,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -331,11 +350,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -360,6 +381,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -367,11 +389,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": 
"914824155844", "sa-east-1": "763104351884", @@ -396,6 +420,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -403,11 +428,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -432,6 +459,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -439,11 +467,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -467,6 +497,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -474,11 +505,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -502,6 +535,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -509,11 +543,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -537,6 +573,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -544,11 +581,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -572,6 +611,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": 
"772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -579,11 +619,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -607,6 +649,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -614,11 +657,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -642,6 +687,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -649,11 +695,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -677,6 +725,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -684,11 +733,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", @@ -721,6 +772,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -728,11 +780,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index bb05682f67..a0f2bba014 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -141,6 +141,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": 
"364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -148,12 +149,14 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "eu-south-2": "503227376785", "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", @@ -173,6 +176,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -180,8 +184,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -205,6 +211,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -212,8 +219,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -237,6 +246,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -244,8 +254,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -392,6 +404,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -399,8 +412,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -424,6 +439,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -431,8 +447,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -456,6 
+474,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -463,8 +482,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -488,6 +509,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -495,8 +517,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -520,6 +544,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -527,8 +552,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -552,6 +579,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -559,8 +587,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -584,6 +614,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -591,8 +622,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -808,6 +841,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -815,8 +849,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -840,6 +876,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": 
"364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -847,8 +884,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -872,6 +911,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -879,8 +919,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -904,6 +946,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -911,8 +954,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -936,6 +981,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -943,8 +989,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -968,6 +1016,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -975,8 +1024,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1000,6 +1051,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1007,8 +1059,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1032,6 +1086,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + 
"ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1039,8 +1094,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1064,6 +1121,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1071,8 +1129,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1096,6 +1156,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1103,8 +1164,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1128,6 +1191,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1135,8 +1199,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1160,6 +1226,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1167,8 +1234,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1192,6 +1261,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1199,8 +1269,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1224,6 +1296,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", 
"ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1231,8 +1304,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1256,6 +1331,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1263,8 +1339,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1288,6 +1366,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1295,8 +1374,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1320,6 +1401,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1327,8 +1409,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1352,6 +1436,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1359,8 +1444,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1384,6 +1471,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1391,8 +1479,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1416,6 +1506,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", 
"ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1423,8 +1514,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1448,6 +1541,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1455,8 +1549,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1480,6 +1576,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1487,8 +1584,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1587,6 +1686,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1594,11 +1694,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", From 767da0afc5cfb11eb96b324debf9a310abaafbcc Mon Sep 17 00:00:00 2001 From: Loki Date: Wed, 7 Dec 2022 06:06:34 +0530 Subject: [PATCH 15/58] feature: Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 (#3500) Co-authored-by: Ubuntu --- src/sagemaker/fw_utils.py | 2 +- .../pytorch-training-compiler.json | 41 ++ src/sagemaker/image_uris.py | 2 +- src/sagemaker/pytorch/__init__.py | 2 + src/sagemaker/pytorch/estimator.py | 60 +- .../pytorch/training_compiler/__init__.py | 0 .../pytorch/training_compiler/config.py | 151 +++++ tests/conftest.py | 1 + tests/data/huggingface_byoc/requirements.txt | 2 + tests/data/huggingface_byoc/run_glue.py | 568 ++++++++++++++++ tests/data/huggingface_byoc/train/dummy.csv | 1 + tests/integ/__init__.py | 2 +- tests/integ/test_training_compiler.py | 50 +- .../test_pytorch_compiler.py | 616 ++++++++++++++++++ 14 files changed, 1467 insertions(+), 31 deletions(-) create mode 100644 src/sagemaker/image_uri_config/pytorch-training-compiler.json create mode 100644 src/sagemaker/pytorch/training_compiler/__init__.py create mode 100644 src/sagemaker/pytorch/training_compiler/config.py create mode 100644 tests/data/huggingface_byoc/requirements.txt create mode 100644 tests/data/huggingface_byoc/run_glue.py create mode 100644 
tests/data/huggingface_byoc/train/dummy.csv create mode 100644 tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index d82d3596ac..5efe530396 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -493,7 +493,7 @@ def framework_name_from_image(image_uri): # We must support both the legacy and current image name format. name_pattern = re.compile( r"""^(?:sagemaker(?:-rl)?-)? - (tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost + (tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost |huggingface-tensorflow|huggingface-pytorch |huggingface-tensorflow-trcomp|huggingface-pytorch-trcomp)(?:-)? (scriptmode|training)? diff --git a/src/sagemaker/image_uri_config/pytorch-training-compiler.json b/src/sagemaker/image_uri_config/pytorch-training-compiler.json new file mode 100644 index 0000000000..892ff4237d --- /dev/null +++ b/src/sagemaker/image_uri_config/pytorch-training-compiler.json @@ -0,0 +1,41 @@ +{ + "training": { + "processors": [ + "gpu" + ], + "version_aliases": { + "1.12": "1.12.0" + }, + "versions": { + "1.12.0": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" + } + } + } +} diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 7d1d3bd835..c42ce02188 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -146,7 +146,7 @@ def retrieve( tolerate_deprecated_model, ) - if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK): + if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): final_image_scope = image_scope config = _config_for_framework_and_scope( framework + "-training-compiler", final_image_scope, accelerator_type diff --git a/src/sagemaker/pytorch/__init__.py b/src/sagemaker/pytorch/__init__.py index cac5f94b9a..e2d14f4163 100644 --- a/src/sagemaker/pytorch/__init__.py +++ b/src/sagemaker/pytorch/__init__.py @@ -16,3 +16,5 @@ from sagemaker.pytorch.estimator import PyTorch # noqa: F401 from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor # noqa: F401 from sagemaker.pytorch.processing import PyTorchProcessor # noqa: F401 + +from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig # noqa: F401 diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 686de4a78c..29e254662f 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -28,6 +28,7 @@ ) from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel +from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable @@ -51,7 +52,8 @@ def __init__( 
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, distribution: Optional[Dict] = None, - **kwargs + compiler_config: Optional[TrainingCompilerConfig] = None, + **kwargs, ): """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment. @@ -208,6 +210,31 @@ def __init__( To learn more, see `Training with parameter servers `_. + **To enable distributed training with + `SageMaker Training Compiler `_ + for PyTorch:** + + .. code:: python + + { + "pytorchxla": { + "enabled": True + } + } + + To learn more, see `SageMaker Training Compiler + `_ + in the *Amazon SageMaker Developer Guide*. + + .. note:: + + When you use this PyTorch XLA option for distributed training strategy, + you must add the ``compiler_config`` parameter and activate SageMaker + Training Compiler. + + compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`): + Configures SageMaker Training Compiler to accelerate training. + **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor. @@ -250,6 +277,25 @@ def __init__( self.distribution = distribution or {} + if compiler_config is not None: + if not isinstance(compiler_config, TrainingCompilerConfig): + error_string = ( + f"Expected instance of type {TrainingCompilerConfig}" + f"for argument compiler_config. " + f"Instead got {type(compiler_config)}" + ) + raise ValueError(error_string) + if compiler_config: + compiler_config.validate(self) + elif distribution is not None and "pytorchxla" in distribution: + raise ValueError( + "Distributed training through PyTorch XLA is currently only supported " + "when SageMaker Training Compiler is enabled. To learn more, " + "see Enable SageMaker Training Compiler at " + "https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html." + ) + self.compiler_config = compiler_config + def _pytorch_distribution_configuration(self, distribution): """Returns a dict of distribution config for PyTorch training @@ -289,6 +335,12 @@ def hyperparameters(self): hyperparameters.update( EstimatorBase._json_encode_hyperparameters(additional_hyperparameters) ) + if self.compiler_config: + training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict() + hyperparameters.update( + EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters) + ) + return hyperparameters def create_model( @@ -299,7 +351,7 @@ def create_model( entry_point=None, source_dir=None, dependencies=None, - **kwargs + **kwargs, ): """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``. 
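Taken together, the additions above are exercised like this; a minimal sketch, assuming an EFA-capable supported instance type and illustrative values for the entry point, role, and S3 input (framework_version 1.12.0 is the only entry registered in the new pytorch-training-compiler.json):

    from sagemaker.pytorch import PyTorch, TrainingCompilerConfig

    # A single-instance job needs only compiler_config; a multi-instance job
    # must also enable the "pytorchxla" distribution, or the validation added
    # in this patch raises a ValueError.
    estimator = PyTorch(
        entry_point="train.py",  # illustrative
        role="SageMakerRole",  # illustrative
        framework_version="1.12.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.p4d.24xlarge",  # listed in SUPPORTED_INSTANCE_TYPES_WITH_EFA
        compiler_config=TrainingCompilerConfig(),
        distribution={"pytorchxla": {"enabled": True}},
    )
    estimator.fit("s3://my-bucket/train")  # illustrative S3 prefix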
@@ -350,7 +402,7 @@ def create_model( sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), dependencies=(dependencies or self.dependencies), - **kwargs + **kwargs, ) @classmethod @@ -371,6 +423,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na ) image_uri = init_params.pop("image_uri") framework, py_version, tag, _ = framework_name_from_image(image_uri) + if framework: + framework = framework.split("-")[0] if tag is None: framework_version = None diff --git a/src/sagemaker/pytorch/training_compiler/__init__.py b/src/sagemaker/pytorch/training_compiler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/pytorch/training_compiler/config.py b/src/sagemaker/pytorch/training_compiler/config.py new file mode 100644 index 0000000000..7faf8acbbd --- /dev/null +++ b/src/sagemaker/pytorch/training_compiler/config.py @@ -0,0 +1,151 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Configuration for the SageMaker Training Compiler.""" +from __future__ import absolute_import +import logging +from typing import Union +from packaging.specifiers import SpecifierSet +from packaging.version import Version + +from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig +from sagemaker.workflow.entities import PipelineVariable + +logger = logging.getLogger(__name__) + + +class TrainingCompilerConfig(BaseConfig): + """The SageMaker Training Compiler configuration class.""" + + SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"] + SUPPORTED_INSTANCE_TYPES_WITH_EFA = [ + "ml.g4dn.8xlarge", + "ml.g4dn.12xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + "ml.p4d.24xlarge", + ] + + def __init__( + self, + enabled: Union[bool, PipelineVariable] = True, + debug: Union[bool, PipelineVariable] = False, + ): + """This class initializes a ``TrainingCompilerConfig`` instance. + + `Amazon SageMaker Training Compiler + `_ + is a feature of SageMaker Training + and speeds up training jobs by optimizing model execution graphs. + + You can compile PyTorch models + by passing the object of this configuration class to the ``compiler_config`` + parameter of the :class:`~sagemaker.pytorch.PyTorch` + estimator. + + Args: + enabled (bool or PipelineVariable): Optional. Switch to enable SageMaker + Training Compiler. The default is ``True``. + debug (bool or PipelineVariable): Optional. Whether to dump detailed logs + for debugging. This comes with a potential performance slowdown. + The default is ``False``. + + **Example**: The following code shows the basic usage of the + :class:`sagemaker.pytorch.TrainingCompilerConfig()` class + to run a PyTorch training job with the compiler. + + .. code-block:: python + + from sagemaker.pytorch import PyTorch, TrainingCompilerConfig + + pytorch_estimator=PyTorch( + ... + compiler_config=TrainingCompilerConfig() + ) + + .. 
seealso:: + + For more information about how to enable SageMaker Training Compiler + for various training settings such as distributed training, + see `Enable SageMaker Training Compiler + `_ + in the `Amazon SageMaker Training Compiler developer guide + `_. + + """ + + super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug) + + @classmethod + def validate( + cls, + estimator, + ): + """Checks if SageMaker Training Compiler is configured correctly. + + Args: + estimator (:class:`sagemaker.pytorch.PyTorch`): An estimator object. + If SageMaker Training Compiler is enabled, it will validate whether + the estimator is configured to be compatible with Training Compiler. + + Raises: + ValueError: Raised if the requested configuration is not compatible + with SageMaker Training Compiler. + """ + + super(TrainingCompilerConfig, cls).validate(estimator) + + if estimator.image_uri: + error_helper_string = ( + "Overriding the image URI is currently not supported " + "for SageMaker Training Compiler." + "Specify the following parameters to run the PyTorch training job " + "with SageMaker Training Compiler enabled: " + "framework_version, and compiler_config." + ) + raise ValueError(error_helper_string) + + if estimator.distribution: + pt_xla_present = "pytorchxla" in estimator.distribution + pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False) + if pt_xla_enabled: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet("< 1.12"): + error_helper_string = ( + "Distribution mechanism 'pytorchxla' is currently only supported for " + "PyTorch >= 1.12 when SageMaker Training Compiler is enabled." + " Received framework_version={} which is unsupported." + ) + raise ValueError(error_helper_string.format(estimator.framework_version)) + if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA: + logger.warning( + "Consider using instances with EFA support when " + "training with PyTorch >= 1.12 and SageMaker Training Compiler " + "enabled. SageMaker Training Compiler leverages EFA to provide better " + "performance for distributed training." + ) + if not pt_xla_present: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet(">= 1.12"): + error_helper_string = ( + "'pytorchxla' is the only distribution mechanism currently supported " + "for PyTorch >= 1.12 when SageMaker Training Compiler is enabled." + " Received distribution={} which is unsupported." + ) + raise ValueError(error_helper_string.format(estimator.distribution)) + elif estimator.instance_count and estimator.instance_count > 1: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet(">= 1.12"): + logger.warning( + "Consider setting 'distribution' to 'pytorchxla' for distributed " + "training with PyTorch >= 1.12 and SageMaker Training Compiler enabled." 
+ ) diff --git a/tests/conftest.py b/tests/conftest.py index e92d98112b..f6682ebb8c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,6 +73,7 @@ "neo_pytorch", "neo_tensorflow", "pytorch", + "pytorch_training_compiler", "ray_pytorch", "ray_tensorflow", "sklearn", diff --git a/tests/data/huggingface_byoc/requirements.txt b/tests/data/huggingface_byoc/requirements.txt new file mode 100644 index 0000000000..462542f1c1 --- /dev/null +++ b/tests/data/huggingface_byoc/requirements.txt @@ -0,0 +1,2 @@ +transformers +datasets diff --git a/tests/data/huggingface_byoc/run_glue.py b/tests/data/huggingface_byoc/run_glue.py new file mode 100644 index 0000000000..1060398fa4 --- /dev/null +++ b/tests/data/huggingface_byoc/run_glue.py @@ -0,0 +1,568 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Finetuning the library models for sequence classification on GLUE.""" +# You can also adapt this script on your own text classification task. Pointers for this are left as comments. + +import logging +import os +import random +import sys +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +from datasets import load_dataset, load_metric + +import transformers +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + HfArgumentParser, + PretrainedConfig, + Trainer, + TrainingArguments, + default_data_collator, + set_seed, +) +from transformers.trainer_utils import get_last_checkpoint, is_main_process +from transformers.utils import check_min_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.5.0") + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +logger = logging.getLogger(__name__) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + task_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, + ) + max_seq_length: int = field( + default=128, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." 
+ }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} + ) + pad_to_max_length: bool = field( + default=True, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_val_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the training data."} + ) + validation_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the validation data."} + ) + test_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the test data."} + ) + + def __post_init__(self): + if self.task_name is not None: + self.task_name = self.task_name.lower() + if self.task_name not in task_to_keys.keys(): + raise ValueError( + "Unknown task, you should pick one in " + ",".join(task_to_keys.keys()) + ) + elif self.train_file is None or self.validation_file is None: + raise ValueError("Need either a GLUE task or a training/validation file.") + else: + train_extension = self.train_file.split(".")[-1] + assert train_extension in [ + "csv", + "json", + ], "`train_file` should be a csv or a json file." + validation_extension = self.validation_file.split(".")[-1] + assert ( + validation_extension == train_extension + ), "`validation_file` should have the same extension (csv or json) as `train_file`." + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained config name or path if not the same as model_name"}, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from huggingface.co" + }, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={ + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + }, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. 
+    # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+        # If we pass only one argument to the script and it's the path to a json file,
+        # let's parse it to get our arguments.
+        model_args, data_args, training_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1])
+        )
+    else:
+        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+    # Detecting last checkpoint.
+    last_checkpoint = None
+    if (
+        os.path.isdir(training_args.output_dir)
+        and training_args.do_train
+        and not training_args.overwrite_output_dir
+    ):
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+                "Use --overwrite_output_dir to overcome."
+            )
+        elif last_checkpoint is not None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
+
+    # Setup logging
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        handlers=[logging.StreamHandler(sys.stdout)],
+    )
+    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+
+    # Log on each process the small summary:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
+        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+    )
+    # Set the verbosity to info of the Transformers logger (on main process only):
+    if is_main_process(training_args.local_rank):
+        transformers.utils.logging.set_verbosity_info()
+        transformers.utils.logging.enable_default_handler()
+        transformers.utils.logging.enable_explicit_format()
+    logger.info(f"Training/evaluation parameters {training_args}")
+
+    # Set seed before initializing model.
+    set_seed(training_args.seed)
+
+    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
+    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
+    #
+    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
+    # sentences in columns called 'sentence1' and 'sentence2' if such columns exist or the first two columns not named
+    # label if at least two columns are provided.
+    #
+    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
+    # single column. You can easily tweak this behavior (see below)
+    #
+    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+    # download the dataset.
+    if data_args.task_name is not None:
+        # Downloading and loading a dataset from the hub.
+        datasets = load_dataset("glue", data_args.task_name)
+    else:
+        # Loading a dataset from your local files.
+        # CSV/JSON training and evaluation files are needed.
+ data_files = {"train": data_args.train_file, "validation": data_args.validation_file} + + # Get the test dataset: you can provide your own CSV/JSON test file (see below) + # when you use `do_predict` without specifying a GLUE benchmark task. + if training_args.do_predict: + if data_args.test_file is not None: + train_extension = data_args.train_file.split(".")[-1] + test_extension = data_args.test_file.split(".")[-1] + assert ( + test_extension == train_extension + ), "`test_file` should have the same extension (csv or json) as `train_file`." + data_files["test"] = data_args.test_file + else: + raise ValueError("Need either a GLUE task or a test file for `do_predict`.") + + for key in data_files.keys(): + logger.info(f"load a local file for {key}: {data_files[key]}") + + if data_args.train_file.endswith(".csv"): + # Loading a dataset from local csv files + datasets = load_dataset("csv", data_files=data_files) + else: + # Loading a dataset from local json files + datasets = load_dataset("json", data_files=data_files) + # See more about loading any type of standard or custom dataset at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Labels + if data_args.task_name is not None: + is_regression = data_args.task_name == "stsb" + if not is_regression: + label_list = datasets["train"].features["label"].names + num_labels = len(label_list) + else: + num_labels = 1 + else: + # Trying to have good defaults here, don't hesitate to tweak to your needs. + is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"] + if is_regression: + num_labels = 1 + else: + # A useful fast method: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique + label_list = datasets["train"].unique("label") + label_list.sort() # Let's sort it for determinism + num_labels = len(label_list) + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + finetuning_task=data_args.task_name, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForSequenceClassification.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Preprocessing the datasets + if data_args.task_name is not None: + sentence1_key, sentence2_key = task_to_keys[data_args.task_name] + else: + # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 
+        non_label_column_names = [
+            name for name in datasets["train"].column_names if name != "label"
+        ]
+        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
+            sentence1_key, sentence2_key = "sentence1", "sentence2"
+        else:
+            if len(non_label_column_names) >= 2:
+                sentence1_key, sentence2_key = non_label_column_names[:2]
+            else:
+                sentence1_key, sentence2_key = non_label_column_names[0], None
+
+    # Padding strategy
+    if data_args.pad_to_max_length:
+        padding = "max_length"
+    else:
+        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
+        padding = False
+
+    # Some models have set the order of the labels to use, so let's make sure we do use it.
+    label_to_id = None
+    if (
+        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
+        and data_args.task_name is not None
+        and not is_regression
+    ):
+        # Some have all caps in their config, some don't.
+        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
+        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+            label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
+        else:
+            logger.warn(
+                "Your model seems to have been trained with labels, but they don't match the dataset: "
+                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                "\nIgnoring the model labels as a result.",
+            )
+    elif data_args.task_name is None and not is_regression:
+        label_to_id = {v: i for i, v in enumerate(label_list)}
+
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warn(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+ ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + def preprocess_function(examples): + # Tokenize the texts + args = ( + (examples[sentence1_key],) + if sentence2_key is None + else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) + + # Map labels to IDs (not necessary for GLUE tasks) + if label_to_id is not None and "label" in examples: + result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] + return result + + datasets = datasets.map( + preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache + ) + if training_args.do_train: + if "train" not in datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + + if training_args.do_eval: + if "validation" not in datasets and "validation_matched" not in datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = datasets[ + "validation_matched" if data_args.task_name == "mnli" else "validation" + ] + if data_args.max_val_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) + + if ( + training_args.do_predict + or data_args.task_name is not None + or data_args.test_file is not None + ): + if "test" not in datasets and "test_matched" not in datasets: + raise ValueError("--do_predict requires a test dataset") + test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] + if data_args.max_test_samples is not None: + test_dataset = test_dataset.select(range(data_args.max_test_samples)) + + # Log a few random samples from the training set: + if training_args.do_train: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # Get the metric function + if data_args.task_name is not None: + metric = load_metric("glue", data_args.task_name) + # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from + # compute_metrics + + # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a + # predictions and label_ids field) and has to return a dictionary string to float. + def compute_metrics(p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) + if data_args.task_name is not None: + result = metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() + return result + elif is_regression: + return {"mse": ((preds - p.label_ids) ** 2).mean().item()} + else: + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} + + # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 
+ if data_args.pad_to_max_length: + data_collator = default_data_collator + elif training_args.fp16: + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + else: + data_collator = None + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + compute_metrics=compute_metrics, + tokenizer=tokenizer, + data_collator=data_collator, + ) + + # Training + if training_args.do_train: + checkpoint = None + if last_checkpoint is not None: + checkpoint = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + # Check the config from that potential checkpoint has the right number of labels before using it as a + # checkpoint. + if AutoConfig.from_pretrained(model_args.model_name_or_path).num_labels == num_labels: + checkpoint = model_args.model_name_or_path + + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.save_model() # Saves the tokenizer too for easy upload + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + eval_datasets = [eval_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + eval_datasets.append(datasets["validation_mismatched"]) + + for eval_dataset, task in zip(eval_datasets, tasks): + metrics = trainer.evaluate(eval_dataset=eval_dataset) + + max_val_samples = ( + data_args.max_val_samples + if data_args.max_val_samples is not None + else len(eval_dataset) + ) + metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.do_predict: + logger.info("*** Test ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + test_datasets = [test_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + test_datasets.append(datasets["test_mismatched"]) + + for test_dataset, task in zip(test_datasets, tasks): + # Removing the `label` columns because it contains -1 and Trainer won't like that. 
+ test_dataset.remove_columns_("label") + predictions = trainer.predict(test_dataset=test_dataset).predictions + predictions = ( + np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) + ) + + output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt") + if trainer.is_world_process_zero(): + with open(output_test_file, "w") as writer: + logger.info(f"***** Test results {task} *****") + writer.write("index\tprediction\n") + for index, item in enumerate(predictions): + if is_regression: + writer.write(f"{index}\t{item:3.3f}\n") + else: + item = label_list[item] + writer.write(f"{index}\t{item}\n") + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/tests/data/huggingface_byoc/train/dummy.csv b/tests/data/huggingface_byoc/train/dummy.csv new file mode 100644 index 0000000000..fb1539d552 --- /dev/null +++ b/tests/data/huggingface_byoc/train/dummy.csv @@ -0,0 +1 @@ +# dummy data \ No newline at end of file diff --git a/tests/integ/__init__.py b/tests/integ/__init__.py index 00ed09577b..9133fc8904 100644 --- a/tests/integ/__init__.py +++ b/tests/integ/__init__.py @@ -158,7 +158,7 @@ "ap-northeast-1", "eu-central-1", ] -# TODO: SM Training Compiler team to add all supported regions. + TRAINING_COMPILER_SUPPORTED_REGIONS = [ "af-south-1", "ap-east-1", diff --git a/tests/integ/test_training_compiler.py b/tests/integ/test_training_compiler.py index 67de050ed1..724cd8890c 100644 --- a/tests/integ/test_training_compiler.py +++ b/tests/integ/test_training_compiler.py @@ -20,6 +20,8 @@ from sagemaker.huggingface import TrainingCompilerConfig as HFTrainingCompilerConfig from sagemaker.tensorflow import TensorFlow from sagemaker.tensorflow import TrainingCompilerConfig as TFTrainingCompilerConfig +from sagemaker.pytorch import PyTorch +from sagemaker.pytorch import TrainingCompilerConfig as PTTrainingCompilerConfig from tests import integ from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES @@ -48,8 +50,7 @@ def imagenet_val_set(request, sagemaker_session, tmpdir_factory): key_prefix="Imagenet/TFRecords/validation", ) train_input = sagemaker_session.upload_data( - path=local_path, - key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val", + path=local_path, key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val" ) return train_input @@ -84,8 +85,8 @@ def skip_if_incompatible(gpu_instance_type, request): @pytest.mark.parametrize( "gpu_instance_type,instance_count", [ - ("ml.p3.2xlarge", 1), - ("ml.p3.16xlarge", 2), + pytest.param("ml.p3.2xlarge", 1, marks=pytest.mark.release), + pytest.param("ml.p3.16xlarge", 2), ], ) def test_huggingface_pytorch( @@ -129,27 +130,32 @@ def test_huggingface_pytorch( hf.fit(huggingface_dummy_dataset) -@pytest.mark.release -def test_huggingface_pytorch_release( +@pytest.mark.parametrize( + "gpu_instance_type,instance_count", + [ + pytest.param("ml.p3.2xlarge", 1, marks=pytest.mark.release), + pytest.param("ml.p3.16xlarge", 2), + ], +) +def test_pytorch( sagemaker_session, gpu_instance_type, - huggingface_training_compiler_latest_version, - huggingface_training_compiler_pytorch_latest_version, + instance_count, + pytorch_training_compiler_latest_version, huggingface_dummy_dataset, ): """ - Test the HuggingFace estimator with PyTorch + Test the PyTorch estimator """ with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, "huggingface") - hf = HuggingFace( + hf = PyTorch( py_version="py38", - 
entry_point=os.path.join(data_path, "run_glue.py"), + source_dir=os.path.join(DATA_DIR, "huggingface_byoc"), + entry_point="run_glue.py", role="SageMakerRole", - transformers_version=huggingface_training_compiler_latest_version, - pytorch_version=huggingface_training_compiler_pytorch_latest_version, - instance_count=1, + framework_version=pytorch_training_compiler_latest_version, + instance_count=instance_count, instance_type=gpu_instance_type, hyperparameters={ "model_name_or_path": "distilbert-base-cased", @@ -163,7 +169,8 @@ def test_huggingface_pytorch_release( }, sagemaker_session=sagemaker_session, disable_profiler=True, - compiler_config=HFTrainingCompilerConfig(), + compiler_config=PTTrainingCompilerConfig(), + distribution={"pytorchxla": {"enabled": True}} if instance_count > 1 else None, ) hf.fit(huggingface_dummy_dataset) @@ -209,10 +216,7 @@ def test_huggingface_tensorflow( @pytest.mark.release def test_tensorflow( - sagemaker_session, - gpu_instance_type, - tensorflow_training_latest_version, - imagenet_val_set, + sagemaker_session, gpu_instance_type, tensorflow_training_latest_version, imagenet_val_set ): """ Test the TensorFlow estimator @@ -264,8 +268,4 @@ def test_tensorflow( compiler_config=TFTrainingCompilerConfig(), ) - tf.fit( - inputs=imagenet_val_set, - logs=True, - wait=True, - ) + tf.fit(inputs=imagenet_val_set, logs=True, wait=True) diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py new file mode 100644 index 0000000000..0fe2402695 --- /dev/null +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -0,0 +1,616 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import +import logging + +import json +import os + +import pytest +from mock import MagicMock, Mock, patch, ANY +from packaging.version import Version + +from sagemaker import image_uris +from sagemaker.pytorch import PyTorch, TrainingCompilerConfig +from sagemaker.pytorch.model import PyTorchModel +from sagemaker.instance_group import InstanceGroup + +from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES + + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "..", "data") +SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") +SERVING_SCRIPT_FILE = "another_dummy_script.py" +MODEL_DATA = "s3://some/data.tar.gz" +ENV = {"DUMMY_ENV_VAR": "dummy_value"} +TIMESTAMP = "2017-11-06-14:14:15.672" +TIME = 1510006209.073025 +BUCKET_NAME = "mybucket" +INSTANCE_COUNT = 1 +INSTANCE_TYPE = "ml.p3.2xlarge" +IMAGE_URI = "pytorch" +JOB_NAME = "{}-{}".format(IMAGE_URI, TIMESTAMP) +ROLE = "Dummy" +REGION = "us-east-1" +GPU = "ml.p3.2xlarge" +SUPPORTED_GPU_INSTANCE_CLASSES = {"p3", "p3dn", "g4dn", "p4d", "g5"} +UNSUPPORTED_GPU_INSTANCE_CLASSES = EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES + +LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]} + +EXPERIMENT_CONFIG = { + "ExperimentName": "exp", + "TrialName": "trial", + "TrialComponentDisplayName": "tc", +} + + +@pytest.fixture(scope="module") +def cpu_instance_type(): + return "ml.m5.xlarge" + + +@pytest.fixture(name="sagemaker_session", scope="function") +def fixture_sagemaker_session(): + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + s3_resource=None, + s3_client=None, + ) + + describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} + session.sagemaker_client.describe_training_job = Mock(return_value=describe) + session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + session.expand_role = Mock(name="expand_role", return_value=ROLE) + return session + + +def _get_full_gpu_image_uri(version, instance_type, training_compiler_config): + return image_uris.retrieve( + "pytorch-training-compiler", + REGION, + version=version, + py_version="py38", + instance_type=instance_type, + image_scope="training", + container_version=None, + training_compiler_config=training_compiler_config, + ) + + +def _create_train_job(version, instance_type, training_compiler_config, instance_count=1): + return { + "image_uri": _get_full_gpu_image_uri(version, instance_type, training_compiler_config), + "input_mode": "File", + "input_config": [ + { + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + } + }, + } + ], + "role": ROLE, + "job_name": JOB_NAME, + "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + "resource_config": { + "InstanceType": instance_type, + "InstanceCount": instance_count, + "VolumeSizeInGB": 30, + }, + "hyperparameters": { + "sagemaker_program": json.dumps("dummy_script.py"), + "sagemaker_container_log_level": str(logging.INFO), + "sagemaker_job_name": json.dumps(JOB_NAME), + "sagemaker_submit_directory": json.dumps( + "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME) + ), + "sagemaker_region": '"us-east-1"', + }, + "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + 
"tags": None, + "vpc_config": None, + "metric_definitions": None, + "environment": None, + "retry_strategy": None, + "experiment_config": EXPERIMENT_CONFIG, + "debugger_hook_config": { + "CollectionConfigurations": [], + "S3OutputPath": "s3://{}/".format(BUCKET_NAME), + }, + "profiler_rule_configs": [ + { + "RuleConfigurationName": "ProfilerReport-1510006209", + "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", + "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, + } + ], + "profiler_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + } + + +def test_unsupported_BYOC( + pytorch_training_compiler_version, +): + byoc = ( + "1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:" + "1.12.0-" + "gpu-" + "py38-cu113-ubuntu20.04" + ) + with pytest.raises(ValueError): + PyTorch( + image_uri=byoc, + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def test_unsupported_cpu_instance(cpu_instance_type, pytorch_training_compiler_version): + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=cpu_instance_type, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES) +def test_unsupported_gpu_instance( + unsupported_gpu_instance_class, pytorch_training_compiler_version +): + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=f"ml.{unsupported_gpu_instance_class}.xlarge", + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +@pytest.mark.xfail(reason="With only 1 supported version, user input is ignored.") +def test_unsupported_framework_version(): + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version="99.99.99", + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def test_unsupported_python_2( + pytorch_training_compiler_version, +): + with pytest.raises(ValueError): + PyTorch( + py_version="py27", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def test_unsupported_instance_group( + pytorch_training_compiler_version, +): + if Version(pytorch_training_compiler_version) < Version("1.12"): + pytest.skip("This test is intended for PyTorch 1.12 and above") + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_groups=[ + InstanceGroup("ml.p3dn.24xlarge", "ml.p3dn.24xlarge", 16), + InstanceGroup("ml.p4d.24xlarge", "ml.p4d.24xlarge", 16), + ], + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def 
test_unsupported_distribution( + pytorch_training_compiler_version, +): + if Version(pytorch_training_compiler_version) < Version("1.12"): + pytest.skip("This test is intended for PyTorch 1.12 and above") + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=2, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"smdistributed": {"dataparallel": {"enabled": True}}}, + ).fit() + + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=2, + instance_type=INSTANCE_TYPE, + transformers_version="4.17", + pytorch_version="1.10", + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"pytorchxla": {"enabled": True}}, + ).fit() + + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=2, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"mpi": {"enabled": True}}, + ).fit() + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) +def test_pytorchxla_distribution( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class +): + if Version(pytorch_training_compiler_version) < Version("1.12"): + pytest.skip("This test is intended for PyTorch 1.12 and above") + compiler_config = TrainingCompilerConfig() + instance_type = f"ml.{instance_class}.xlarge" + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=2, + instance_type=instance_type, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"pytorchxla": {"enabled": True}}, + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, instance_type, compiler_config, instance_count=2 + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + True + ) + expected_train_args["hyperparameters"][PyTorch.LAUNCH_PT_XLA_ENV_NAME] = json.dumps(True) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + False + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) 
+@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) +def test_default_compiler_config( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class +): + compiler_config = TrainingCompilerConfig() + instance_type = f"ml.{instance_class}.xlarge" + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=instance_type, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=compiler_config, + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, instance_type, compiler_config + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + True + ) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + False + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +def test_debug_compiler_config( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version +): + compiler_config = TrainingCompilerConfig(debug=True) + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=compiler_config, + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + True + ) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + True + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@patch("sagemaker.utils.repack_model", 
MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +def test_disable_compiler_config( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version +): + compiler_config = TrainingCompilerConfig(enabled=False) + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(enabled=False), + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + False + ) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + False + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@pytest.mark.parametrize( + ["compiler_enabled", "debug_enabled"], [(True, False), (True, True), (False, False)] +) +def test_attach(sagemaker_session, compiler_enabled, debug_enabled): + training_image = ( + "1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:" + "1.12.0-" + "gpu-" + "py38-cu113-ubuntu20.04" + ) + returned_job_description = { + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"', + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"trcomp"', + "training_steps": "100", + "sagemaker_region": '"us-east-1"', + TrainingCompilerConfig.HP_ENABLE_COMPILER: json.dumps(compiler_enabled), + TrainingCompilerConfig.HP_ENABLE_DEBUG: json.dumps(debug_enabled), + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.p3.2xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "trcomp", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/trcomp", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/trcomp"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + + estimator = PyTorch.attach(training_job_name="trcomp", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "trcomp" + assert estimator.py_version == "py38" 
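    # Everything asserted below is reconstructed purely from the mocked
    # DescribeTrainingJob response: framework_version and py_version are
    # parsed back out of the pytorch-trcomp-training image URI, and the
    # compiler hyperparameters round-trip through hyperparameters().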
+ assert estimator.framework_version == "1.12.0" + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" + assert estimator.instance_count == 1 + assert estimator.max_run == 24 * 60 * 60 + assert estimator.input_mode == "File" + assert estimator.base_job_name == "trcomp" + assert estimator.output_path == "s3://place/output/trcomp" + assert estimator.output_kms_key == "" + assert estimator.hyperparameters()["training_steps"] == "100" + assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_COMPILER] == json.dumps( + compiler_enabled + ) + assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_DEBUG] == json.dumps( + debug_enabled + ) + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "iris-dnn-classifier.py" + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +def test_register_pytorch_model_auto_infer_framework( + sagemaker_session, pytorch_training_compiler_version +): + + model_package_group_name = "test-pt-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarge"] + image_uri = "fakeimage" + + pt_model = PyTorchModel( + model_data="s3://some/data.tar.gz", + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=pytorch_training_compiler_version, + py_version="py38", + sagemaker_session=sagemaker_session, + ) + + pt_model.register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": ANY, + "Framework": "PYTORCH", + "FrameworkVersion": pytorch_training_compiler_version, + } + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": inference_instances, + "transform_instances": transform_instances, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) From d779d1b8296242eb15637e85272a1a50a7ee897b Mon Sep 17 00:00:00 2001 From: HappyAmazonian <91216626+HappyAmazonian@users.noreply.github.com> Date: Tue, 6 Dec 2022 16:37:16 -0800 Subject: [PATCH 16/58] feature: Add Neo image uri config for Pytorch 1.12 (#3507) --- .../image_uri_config/neo-pytorch.json | 36 ++++++++++++++++++- tests/data/pytorch_neo/code/inference.py | 4 +-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/image_uri_config/neo-pytorch.json b/src/sagemaker/image_uri_config/neo-pytorch.json index bd15a6450e..c46dd3de5d 100644 --- a/src/sagemaker/image_uri_config/neo-pytorch.json +++ b/src/sagemaker/image_uri_config/neo-pytorch.json @@ -11,7 +11,9 @@ "1.7.0": "1.7", "1.7.1": "1.7", "1.8.0": "1.8", - "1.8.1": "1.8" + "1.8.1": "1.8", + "1.12.0": "1.12", + "1.12.1": "1.12" }, "versions": { "1.4": { @@ -173,6 +175,38 @@ "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" + }, + "1.12": { + "py_versions": ["py3"], + "registries": { + "af-south-1": "774647643957", + "ap-east-1": "110948597952", + "ap-northeast-1": "941853720454", + "ap-northeast-2": "151534178276", + "ap-northeast-3": "925152966179", + "ap-south-1": "763008648453", + "ap-southeast-1": "324986816169", + 
"ap-southeast-2": "355873309152", + "ca-central-1": "464438896020", + "cn-north-1": "472730292857", + "cn-northwest-1": "474822919863", + "eu-central-1": "746233611703", + "eu-north-1": "601324751636", + "eu-south-1": "966458181534", + "eu-west-1": "802834080501", + "eu-west-2": "205493899709", + "eu-west-3": "254080097072", + "me-south-1": "836785723513", + "sa-east-1": "756306329178", + "us-east-1": "785573368785", + "us-east-2": "007439368137", + "us-gov-west-1": "263933020539", + "us-iso-east-1": "167761179201", + "us-isob-east-1": "406031935815", + "us-west-1": "710691900526", + "us-west-2": "301217895009" + }, + "repository": "sagemaker-inference-pytorch" } } } diff --git a/tests/data/pytorch_neo/code/inference.py b/tests/data/pytorch_neo/code/inference.py index 5b89c2bebc..79fe66d716 100644 --- a/tests/data/pytorch_neo/code/inference.py +++ b/tests/data/pytorch_neo/code/inference.py @@ -71,8 +71,8 @@ def model_fn(model_dir): logger.info("model_fn") neopytorch.config(model_dir=model_dir, neo_runtime=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # The compiled model is saved as "model.pth" - model = torch.jit.load(os.path.join(model_dir, "model.pth"), map_location=device) + # The compiled model is saved as "model.pth" or "model.pt" + model = torch.jit.load(os.path.join(model_dir, "model.pt"), map_location=device) # It is recommended to run warm-up inference during model load sample_input_path = os.path.join(model_dir, "sample_input.pkl") From 83327fb9ef5eb5f44c9fd3f8925c7791576c9a37 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 7 Dec 2022 03:20:15 +0000 Subject: [PATCH 17/58] prepare release v2.120.0 --- CHANGELOG.md | 13 +++++++++++++ VERSION | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8b3155231..71894ff29d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## v2.120.0 (2022-12-07) + +### Features + + * Add Neo image uri config for Pytorch 1.12 + * Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 + * Update registries with new region account number mappings. 
+ * Add DXB region to frameworks by DLC + +### Bug Fixes and Other Changes + + * support idempotency for framework and spark processors + ## v2.119.0 (2022-12-03) ### Features diff --git a/VERSION b/VERSION index dda4128cf2..7de9d18b4e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.119.1.dev0 +2.120.0 From 5bffb04b78e8cd6422654008511aa61ca6f66efb Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 7 Dec 2022 03:20:17 +0000 Subject: [PATCH 18/58] update development version to v2.120.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 7de9d18b4e..73c4cd6968 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.120.0 +2.120.1.dev0 From b828396c55082bc5f06092be41555729d775874a Mon Sep 17 00:00:00 2001 From: Malav Shastri <57682969+malav-shastri@users.noreply.github.com> Date: Wed, 7 Dec 2022 20:58:37 +0530 Subject: [PATCH 19/58] feature: Algorithms Region Expansion OSU/DXB (#3508) Co-authored-by: Malav Shastri --- .../image_uri_config/blazingtext.json | 2 ++ .../factorization-machines.json | 2 ++ .../image_uri_config/forecasting-deepar.json | 2 ++ .../image-classification.json | 2 ++ .../image_uri_config/ipinsights.json | 2 ++ src/sagemaker/image_uri_config/kmeans.json | 2 ++ src/sagemaker/image_uri_config/knn.json | 2 ++ .../image_uri_config/linear-learner.json | 2 ++ src/sagemaker/image_uri_config/ntm.json | 2 ++ .../image_uri_config/object-detection.json | 2 ++ .../image_uri_config/object2vec.json | 2 ++ src/sagemaker/image_uri_config/pca.json | 2 ++ .../image_uri_config/randomcutforest.json | 2 ++ .../semantic-segmentation.json | 2 ++ src/sagemaker/image_uri_config/seq2seq.json | 2 ++ src/sagemaker/image_uri_config/sklearn.json | 14 ++++++++ src/sagemaker/image_uri_config/xgboost.json | 36 +++++++++++++++++++ tests/unit/sagemaker/image_uris/test_algos.py | 4 +++ .../unit/sagemaker/image_uris/test_sklearn.py | 2 ++ .../unit/sagemaker/image_uris/test_xgboost.py | 4 +++ 20 files changed, 90 insertions(+) diff --git a/src/sagemaker/image_uri_config/blazingtext.json b/src/sagemaker/image_uri_config/blazingtext.json index c588d65c73..ae4295c59a 100644 --- a/src/sagemaker/image_uri_config/blazingtext.json +++ b/src/sagemaker/image_uri_config/blazingtext.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/factorization-machines.json b/src/sagemaker/image_uri_config/factorization-machines.json index 0f9930357f..8fb1895707 100644 --- a/src/sagemaker/image_uri_config/factorization-machines.json +++ b/src/sagemaker/image_uri_config/factorization-machines.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/forecasting-deepar.json b/src/sagemaker/image_uri_config/forecasting-deepar.json index 1acc96ed3e..e9beb7acb6 100644 --- a/src/sagemaker/image_uri_config/forecasting-deepar.json +++ 
b/src/sagemaker/image_uri_config/forecasting-deepar.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "522234722520", "us-east-2": "566113047672", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "156387875391" diff --git a/src/sagemaker/image_uri_config/image-classification.json b/src/sagemaker/image_uri_config/image-classification.json index 44ccb3f08d..61e037da08 100644 --- a/src/sagemaker/image_uri_config/image-classification.json +++ b/src/sagemaker/image_uri_config/image-classification.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/ipinsights.json b/src/sagemaker/image_uri_config/ipinsights.json index 4e56c149dc..cf3c70194f 100644 --- a/src/sagemaker/image_uri_config/ipinsights.json +++ b/src/sagemaker/image_uri_config/ipinsights.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/kmeans.json b/src/sagemaker/image_uri_config/kmeans.json index 952724ce11..e8e947f094 100644 --- a/src/sagemaker/image_uri_config/kmeans.json +++ b/src/sagemaker/image_uri_config/kmeans.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/knn.json b/src/sagemaker/image_uri_config/knn.json index 79b239966d..89e8ef6224 100644 --- a/src/sagemaker/image_uri_config/knn.json +++ b/src/sagemaker/image_uri_config/knn.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/linear-learner.json b/src/sagemaker/image_uri_config/linear-learner.json index bb027284ab..606edd3791 100644 --- a/src/sagemaker/image_uri_config/linear-learner.json +++ b/src/sagemaker/image_uri_config/linear-learner.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": 
"237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/ntm.json b/src/sagemaker/image_uri_config/ntm.json index 115264b346..16f9565405 100644 --- a/src/sagemaker/image_uri_config/ntm.json +++ b/src/sagemaker/image_uri_config/ntm.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/object-detection.json b/src/sagemaker/image_uri_config/object-detection.json index 6a7ba03695..67b60fe587 100644 --- a/src/sagemaker/image_uri_config/object-detection.json +++ b/src/sagemaker/image_uri_config/object-detection.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/object2vec.json b/src/sagemaker/image_uri_config/object2vec.json index 39614d1273..b166cc96ff 100644 --- a/src/sagemaker/image_uri_config/object2vec.json +++ b/src/sagemaker/image_uri_config/object2vec.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/pca.json b/src/sagemaker/image_uri_config/pca.json index 5f87d8528c..11982e2197 100644 --- a/src/sagemaker/image_uri_config/pca.json +++ b/src/sagemaker/image_uri_config/pca.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/randomcutforest.json b/src/sagemaker/image_uri_config/randomcutforest.json index ae7a3574be..15dc84dfc5 100644 --- a/src/sagemaker/image_uri_config/randomcutforest.json +++ b/src/sagemaker/image_uri_config/randomcutforest.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/semantic-segmentation.json b/src/sagemaker/image_uri_config/semantic-segmentation.json index 866dd606b4..f49bc43109 100644 --- a/src/sagemaker/image_uri_config/semantic-segmentation.json +++ 
b/src/sagemaker/image_uri_config/semantic-segmentation.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/seq2seq.json b/src/sagemaker/image_uri_config/seq2seq.json index bb3daf93b6..87810ad09d 100644 --- a/src/sagemaker/image_uri_config/seq2seq.json +++ b/src/sagemaker/image_uri_config/seq2seq.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/sklearn.json b/src/sagemaker/image_uri_config/sklearn.json index 7961fde282..4d093f5f62 100644 --- a/src/sagemaker/image_uri_config/sklearn.json +++ b/src/sagemaker/image_uri_config/sklearn.json @@ -24,10 +24,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -57,10 +59,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -90,10 +94,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -127,10 +133,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -160,10 +168,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -193,10 +203,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": 
"237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -230,10 +242,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" diff --git a/src/sagemaker/image_uri_config/xgboost.json b/src/sagemaker/image_uri_config/xgboost.json index a809083c4a..946e78ecc4 100644 --- a/src/sagemaker/image_uri_config/xgboost.json +++ b/src/sagemaker/image_uri_config/xgboost.json @@ -25,10 +25,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" @@ -58,10 +60,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -91,10 +95,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -124,10 +130,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -155,10 +163,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -186,10 +196,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -217,10 +229,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -248,10 +262,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": 
"801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -286,10 +302,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" @@ -319,10 +337,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -352,10 +372,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -385,10 +407,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -416,10 +440,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -447,10 +473,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -478,10 +506,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -509,10 +539,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -544,10 +576,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": 
"272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -575,10 +609,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" diff --git a/tests/unit/sagemaker/image_uris/test_algos.py b/tests/unit/sagemaker/image_uris/test_algos.py index 454d375b4b..443727094a 100644 --- a/tests/unit/sagemaker/image_uris/test_algos.py +++ b/tests/unit/sagemaker/image_uris/test_algos.py @@ -68,10 +68,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107", @@ -155,10 +157,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032", diff --git a/tests/unit/sagemaker/image_uris/test_sklearn.py b/tests/unit/sagemaker/image_uris/test_sklearn.py index d0fcbdb300..8563753e8c 100644 --- a/tests/unit/sagemaker/image_uris/test_sklearn.py +++ b/tests/unit/sagemaker/image_uris/test_sklearn.py @@ -37,10 +37,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249", diff --git a/tests/unit/sagemaker/image_uris/test_xgboost.py b/tests/unit/sagemaker/image_uris/test_xgboost.py index 78ab7e10ee..4d0f9f1dc3 100644 --- a/tests/unit/sagemaker/image_uris/test_xgboost.py +++ b/tests/unit/sagemaker/image_uris/test_xgboost.py @@ -35,10 +35,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032", @@ -67,10 +69,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249", From 357f73226c9c5fe651ea74169cafe585e1092ad0 Mon Sep 17 00:00:00 2001 From: Navin Soni Date: Wed, 7 Dec 2022 10:36:33 -0800 Subject: [PATCH 20/58] fix: Add constraints file for apache-airflow 
(#3510)
---
 requirements/extras/test_requirements.txt | 1 +
 tox.ini                                   | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt
index b52f394bd0..fe93fd4d0e 100644
--- a/requirements/extras/test_requirements.txt
+++ b/requirements/extras/test_requirements.txt
@@ -11,6 +11,7 @@ contextlib2==21.6.0
 awslogs==0.14.0
 black==22.3.0
 stopit==1.1.2
+# Update tox.ini to have correct version of airflow constraints file
 apache-airflow==2.4.1
 apache-airflow-providers-amazon==4.0.0
 attrs==22.1.0
diff --git a/tox.ini b/tox.ini
index 2d5fdf0b40..3a398ca51d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -73,6 +73,8 @@ passenv =
 # Can be used to specify which tests to run, e.g.: tox -- -s
 commands =
     python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')"
+    pip install 'apache-airflow==2.4.1' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.4.1/constraints-3.10.txt"
+
     pytest --cov=sagemaker --cov-append {posargs}
     {env:IGNORE_COVERAGE:} coverage report -i --fail-under=86
 deps = .[test]

From a28d1dd129ecceb612d5e8927b6be72937711722 Mon Sep 17 00:00:00 2001
From: Brock Wade
Date: Wed, 7 Dec 2022 19:14:12 -0800
Subject: [PATCH 21/58] fix: FrameworkProcessor S3 uploads (#3493)

Co-authored-by: Brock Wade
Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com>
---
 src/sagemaker/processing.py                       |  47 +++-
 .../data/pipeline/test_source_dir/script_1.py     |  11 +
 .../data/pipeline/test_source_dir/script_2.py     |   9 +
 .../pipeline/test_source_dir_2/script_2.py        |   9 +
 .../workflow/test_processing_steps.py             | 249 +++++++++++++++++-
 .../integ/sagemaker/workflow/test_workflow.py     |   8 +-
 6 files changed, 322 insertions(+), 11 deletions(-)
 create mode 100644 tests/data/pipeline/test_source_dir/script_1.py
 create mode 100644 tests/data/pipeline/test_source_dir/script_2.py
 create mode 100644 tests/data/pipeline/test_source_dir_2/script_2.py

diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py
index 81e3d34b1d..01d4361197 100644
--- a/src/sagemaker/processing.py
+++ b/src/sagemaker/processing.py
@@ -1741,13 +1741,7 @@ def _pack_and_upload_code(
             raise RuntimeError("S3 source_dir file must be named `sourcedir.tar.gz.`")

         script = estimator.uploaded_code.script_name
-        s3_runproc_sh = S3Uploader.upload_string_as_file_body(
-            self._generate_framework_script(script),
-            desired_s3_uri=entrypoint_s3_uri,
-            kms_key=kms_key,
-            sagemaker_session=self.sagemaker_session,
-        )
-        logger.info("runproc.sh uploaded to %s", s3_runproc_sh)
+        s3_runproc_sh = self._create_and_upload_runproc(script, kms_key, entrypoint_s3_uri)

         return s3_runproc_sh, inputs, job_name

@@ -1857,3 +1851,42 @@ def _set_entrypoint(self, command, user_script_name):
                 )
             )
         self.entrypoint = self.framework_entrypoint_command + [user_script_location]
+
+    def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
+        """Create runproc shell script and upload to S3 bucket.
+
+        If leveraging a pipeline session with optimized S3 artifact paths,
+        the runproc.sh file is hashed and uploaded to a separate S3 location.
+
+        Args:
+            user_script (str): Relative path to ``code`` in the source bundle,
+                e.g. 'process.py'.
+            kms_key (str): The KMS key used for encryption.
+            entrypoint_s3_uri (str): The S3 upload path for the runproc script.
+        """
+        from sagemaker.workflow.utilities import _pipeline_config, hash_object
+
+        if _pipeline_config and _pipeline_config.pipeline_name:
+            runproc_file_str = self._generate_framework_script(user_script)
+            runproc_file_hash = hash_object(runproc_file_str)
+            s3_uri = (
+                f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/"
+                f"code/{runproc_file_hash}/runproc.sh"
+            )
+            s3_runproc_sh = S3Uploader.upload_string_as_file_body(
+                runproc_file_str,
+                desired_s3_uri=s3_uri,
+                kms_key=kms_key,
+                sagemaker_session=self.sagemaker_session,
+            )
+        else:
+            s3_runproc_sh = S3Uploader.upload_string_as_file_body(
+                self._generate_framework_script(user_script),
+                desired_s3_uri=entrypoint_s3_uri,
+                kms_key=kms_key,
+                sagemaker_session=self.sagemaker_session,
+            )
+        logger.info("runproc.sh uploaded to %s", s3_runproc_sh)
+
+        return s3_runproc_sh
diff --git a/tests/data/pipeline/test_source_dir/script_1.py b/tests/data/pipeline/test_source_dir/script_1.py
new file mode 100644
index 0000000000..4a427b1898
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir/script_1.py
@@ -0,0 +1,11 @@
+"""
+Integ test file script_1.py
+"""
+import pathlib
+
+if __name__ == "__main__":
+
+    print("writing file to /opt/ml/processing/test/test.py...")
+    pathlib.Path("/opt/ml/processing/test").mkdir(parents=True, exist_ok=True)
+    with open("/opt/ml/processing/test/test.py", "w") as f:
+        f.write('print("test...")')
diff --git a/tests/data/pipeline/test_source_dir/script_2.py b/tests/data/pipeline/test_source_dir/script_2.py
new file mode 100644
index 0000000000..6245dac987
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir/script_2.py
@@ -0,0 +1,9 @@
+"""
+Integ test file script_2.py
+"""
+
+if __name__ == "__main__":
+
+    print("reading file: /opt/ml/processing/test/test.py")
+    with open("/opt/ml/processing/test/test.py", "r") as f:
+        print(f.read())
diff --git a/tests/data/pipeline/test_source_dir_2/script_2.py b/tests/data/pipeline/test_source_dir_2/script_2.py
new file mode 100644
index 0000000000..6245dac987
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir_2/script_2.py
@@ -0,0 +1,9 @@
+"""
+Integ test file script_2.py
+"""
+
+if __name__ == "__main__":
+
+    print("reading file: /opt/ml/processing/test/test.py")
+    with open("/opt/ml/processing/test/test.py", "r") as f:
+        print(f.read())
diff --git a/tests/integ/sagemaker/workflow/test_processing_steps.py b/tests/integ/sagemaker/workflow/test_processing_steps.py
index 781bce85a7..238eff6123 100644
--- a/tests/integ/sagemaker/workflow/test_processing_steps.py
+++ b/tests/integ/sagemaker/workflow/test_processing_steps.py
@@ -17,15 +17,18 @@
 import re
 import subprocess
 from datetime import datetime
+from pathlib import Path

 import pytest
 from botocore.exceptions import WaiterError

+from sagemaker.workflow.utilities import hash_files_or_dirs, hash_object
 from sagemaker import image_uris, get_execution_role, utils
 from sagemaker.dataset_definition import DatasetDefinition, AthenaDatasetDefinition
-from sagemaker.processing import ProcessingInput, ProcessingOutput
-from sagemaker.s3 import S3Uploader
-from sagemaker.sklearn import SKLearnProcessor
+from sagemaker.processing import ProcessingInput, ProcessingOutput, FrameworkProcessor
+from sagemaker.s3 import S3Uploader, S3Downloader
+from sagemaker.sklearn import SKLearnProcessor, SKLearn
+from sagemaker.tensorflow import TensorFlow
 from sagemaker.workflow.parameters import ParameterInteger, ParameterString
 from sagemaker.workflow.pipeline import Pipeline
 from sagemaker.workflow.steps import
( @@ -379,6 +382,203 @@ def test_one_step_framework_processing_pipeline( pass +def test_multi_step_framework_processing_pipeline_same_source_dir( + pipeline_session, role, pipeline_name +): + default_bucket = pipeline_session.default_bucket() + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + + SOURCE_DIR = "/pipeline/test_source_dir" + + framework_processor_tf = FrameworkProcessor( + role=role, + instance_type="ml.m5.xlarge", + instance_count=1, + estimator_cls=TensorFlow, + framework_version="2.9", + py_version="py39", + sagemaker_session=pipeline_session, + ) + + framework_processor_sk = FrameworkProcessor( + framework_version="1.0-1", + instance_type="ml.m5.xlarge", + instance_count=1, + base_job_name="my-job", + role=role, + estimator_cls=SKLearn, + sagemaker_session=pipeline_session, + ) + + step_1 = ProcessingStep( + name="Step-1", + step_args=framework_processor_tf.run( + code="script_1.py", + source_dir=DATA_DIR + SOURCE_DIR, + outputs=[ProcessingOutput(output_name="test", source="/opt/ml/processing/test")], + ), + cache_config=cache_config, + ) + + step_2 = ProcessingStep( + name="Step-2", + step_args=framework_processor_sk.run( + code="script_2.py", + source_dir=DATA_DIR + SOURCE_DIR, + inputs=[ + ProcessingInput( + source=step_1.properties.ProcessingOutputConfig.Outputs["test"].S3Output.S3Uri, + destination="/opt/ml/processing/test", + ), + ], + ), + cache_config=cache_config, + ) + + pipeline = Pipeline( + name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session + ) + try: + pipeline.create(role) + definition = json.loads(pipeline.definition()) + + source_dir_1_s3_uri, entry_point_1 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_tf, + default_bucket, + pipeline_name, + definition["Steps"][0], + SOURCE_DIR, + "script_1.py", + ) + source_dir_2_s3_uri, entry_point_2 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_sk, + default_bucket, + pipeline_name, + definition["Steps"][1], + SOURCE_DIR, + "script_2.py", + ) + + # the same local source_dirs should have the same s3 paths + assert source_dir_1_s3_uri == source_dir_2_s3_uri + + # verify different entry_point paths + assert entry_point_1 != entry_point_2 + + execution = pipeline.start(parameters={}) + try: + execution.wait(delay=540, max_attempts=3) + except WaiterError: + pass + + execution_steps = execution.list_steps() + assert len(execution_steps) == 2 + for step in execution_steps: + assert step["StepStatus"] == "Succeeded" + + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_multi_step_framework_processing_pipeline_different_source_dir( + pipeline_session, role, pipeline_name +): + default_bucket = pipeline_session.default_bucket() + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + + SOURCE_DIR_1 = "/pipeline/test_source_dir" + SOURCE_DIR_2 = "/pipeline/test_source_dir_2" + + framework_processor_tf = FrameworkProcessor( + role=role, + instance_type="ml.m5.xlarge", + instance_count=1, + estimator_cls=TensorFlow, + framework_version="2.9", + py_version="py39", + sagemaker_session=pipeline_session, + ) + + step_1 = ProcessingStep( + name="Step-1", + step_args=framework_processor_tf.run( + code="script_1.py", + source_dir=DATA_DIR + SOURCE_DIR_1, + outputs=[ProcessingOutput(output_name="test", source="/opt/ml/processing/test")], + ), + cache_config=cache_config, + ) + + step_2 = ProcessingStep( + name="Step-2", + 
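        # step_2 reads step_1's output through a property reference rather than
        # a hard-coded S3 URI: Outputs["test"].S3Output.S3Uri is resolved by
        # SageMaker Pipelines at execution time, and the reference is also how
        # the service infers that Step-2 depends on Step-1.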
step_args=framework_processor_tf.run(
+            code="script_2.py",
+            source_dir=DATA_DIR + SOURCE_DIR_2,
+            inputs=[
+                ProcessingInput(
+                    source=step_1.properties.ProcessingOutputConfig.Outputs["test"].S3Output.S3Uri,
+                    destination="/opt/ml/processing/test",
+                ),
+            ],
+        ),
+        cache_config=cache_config,
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session
+    )
+    try:
+        pipeline.create(role)
+        definition = json.loads(pipeline.definition())
+
+        source_dir_1_s3_uri, entry_point_1 = _verify_code_artifacts_of_framework_processing_step(
+            pipeline_session,
+            framework_processor_tf,
+            default_bucket,
+            pipeline_name,
+            definition["Steps"][0],
+            SOURCE_DIR_1,
+            "script_1.py",
+        )
+        source_dir_2_s3_uri, entry_point_2 = _verify_code_artifacts_of_framework_processing_step(
+            pipeline_session,
+            framework_processor_tf,
+            default_bucket,
+            pipeline_name,
+            definition["Steps"][1],
+            SOURCE_DIR_2,
+            "script_2.py",
+        )
+
+        # different local source_dirs should have different s3 paths
+        assert source_dir_1_s3_uri != source_dir_2_s3_uri
+
+        # verify different entry_point paths
+        assert entry_point_1 != entry_point_2
+
+        execution = pipeline.start(parameters={})
+        try:
+            execution.wait(delay=540, max_attempts=3)
+        except WaiterError:
+            pass
+
+        execution_steps = execution.list_steps()
+        assert len(execution_steps) == 2
+        for step in execution_steps:
+            assert step["StepStatus"] == "Succeeded"
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
 def test_one_step_pyspark_processing_pipeline(
     sagemaker_session,
     role,
@@ -796,3 +996,46 @@ def test_two_processing_job_depends_on(
             pipeline.delete()
         except Exception:
             pass
+
+
+def _verify_code_artifacts_of_framework_processing_step(
+    pipeline_session, processor, bucket, pipeline_name, step_definition, source_dir, entry_point
+):
+
+    source_dir_s3_uri = (
+        f"s3://{bucket}/{pipeline_name}" f"/code/{hash_files_or_dirs([f'{DATA_DIR}/{source_dir}'])}"
+    )
+
+    # verify runproc.sh prefix is different from code artifact prefix
+    runprocs = []
+    for input_obj in step_definition["Arguments"]["ProcessingInputs"]:
+        if input_obj["InputName"] == "entrypoint":
+            s3_uri = input_obj["S3Input"]["S3Uri"]
+            runprocs.append(s3_uri)
+
+            assert Path(s3_uri).parent != Path(source_dir_s3_uri)
+
+    # verify only one entrypoint generated per step
+    assert len(runprocs) == 1
+
+    # the expected tarball key must be derived from this step's own source_dir
+    expected_source_dir_tar = (
+        f"{pipeline_name}"
+        f"/code/{hash_files_or_dirs([f'{DATA_DIR}/{source_dir}'])}/sourcedir.tar.gz"
+    )
+
+    step_script = processor._generate_framework_script(entry_point)
+    expected_step_artifact = f"{pipeline_name}/code/{hash_object(step_script)}/runproc.sh"
+
+    expected_prefix = f"{pipeline_name}/code"
+    s3_code_objects = pipeline_session.list_s3_files(bucket=bucket, key_prefix=expected_prefix)
+
+    # verify all distinct artifacts were uploaded
+    assert expected_source_dir_tar in s3_code_objects
+    assert expected_step_artifact in s3_code_objects
+
+    # verify runprocs contain the correct commands
+    step_runproc = S3Downloader.read_file(
+        f"s3://{bucket}/{expected_step_artifact}", pipeline_session
+    )
+    assert f"python {entry_point}" in step_runproc
+
+    # return the source bundle's S3 prefix so callers can compare it across steps
+    return source_dir_s3_uri, expected_step_artifact
diff --git a/tests/integ/sagemaker/workflow/test_workflow.py b/tests/integ/sagemaker/workflow/test_workflow.py
index 634ef752d6..44f4e2d26e 100644
--- a/tests/integ/sagemaker/workflow/test_workflow.py
+++ b/tests/integ/sagemaker/workflow/test_workflow.py
@@ -1168,7 +1168,13 @@ def walk():


 def test_caching_behavior(
-
pipeline_session, role, cpu_instance_type, pipeline_name, script_dir, athena_dataset_definition + pipeline_session, + role, + cpu_instance_type, + pipeline_name, + script_dir, + athena_dataset_definition, + region_name, ): default_bucket = pipeline_session.default_bucket() data_path = os.path.join(DATA_DIR, "workflow") From 11d24754b0a8228893f6663ac1ca5048b8a6e794 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 8 Dec 2022 06:16:54 +0000 Subject: [PATCH 22/58] prepare release v2.121.0 --- CHANGELOG.md | 11 +++++++++++ VERSION | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71894ff29d..29dad5f19f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## v2.121.0 (2022-12-08) + +### Features + + * Algorithms Region Expansion OSU/DXB + +### Bug Fixes and Other Changes + + * FrameworkProcessor S3 uploads + * Add constraints file for apache-airflow + ## v2.120.0 (2022-12-07) ### Features diff --git a/VERSION b/VERSION index 73c4cd6968..7f1e14b5a9 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.120.1.dev0 +2.121.0 From 24171b5efcb9c528f159334d6252835ef10bbcb2 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 8 Dec 2022 06:16:55 +0000 Subject: [PATCH 23/58] update development version to v2.121.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 7f1e14b5a9..28b52ee8d5 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.0 +2.121.1.dev0 From d5847d5ebad840c5f47204742302d91064904be8 Mon Sep 17 00:00:00 2001 From: Loki Date: Fri, 9 Dec 2022 03:10:14 +0530 Subject: [PATCH 24/58] Fix: Differentiate SageMaker Training Compiler's PT DLCs from base PT DLC (#3515) --- src/sagemaker/image_uri_config/pytorch-training-compiler.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/image_uri_config/pytorch-training-compiler.json b/src/sagemaker/image_uri_config/pytorch-training-compiler.json index 892ff4237d..fd7df875a3 100644 --- a/src/sagemaker/image_uri_config/pytorch-training-compiler.json +++ b/src/sagemaker/image_uri_config/pytorch-training-compiler.json @@ -34,7 +34,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "pytorch-training" + "repository": "pytorch-trcomp-training" } } } From 3f6ea884a564090f826fab46270429db553c7b3b Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Thu, 8 Dec 2022 17:17:44 -0500 Subject: [PATCH 25/58] fix: Fix failing jumpstart cache unit tests (#3514) --- setup.py | 2 +- src/sagemaker/jumpstart/cache.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 4327045760..f366b147b8 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ def read_requirements(filename): "protobuf3-to-dict>=0.1.5,<1.0", "smdebug_rulesconfig==1.0.1", "importlib-metadata>=1.4.0,<5.0", - "packaging>=20.0", + "packaging==20.9", "pandas", "pathos", "schema", diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 202edff9ad..db607770a7 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -20,7 +20,7 @@ import boto3 import botocore from packaging.version import Version -from packaging.specifiers import SpecifierSet +from packaging.specifiers import SpecifierSet, InvalidSpecifier from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -371,7 +371,10 @@ def 
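To make the FrameworkProcessor fix above concrete: under a pipeline session, the processor now derives content-hashed S3 keys for both the source bundle and the generated runproc.sh entrypoint. The following is a minimal sketch, assuming default AWS credentials; the role ARN is a placeholder, and the S3 layout in the comments paraphrases the patched _create_and_upload_runproc rather than quoting it:

    from sagemaker.processing import FrameworkProcessor
    from sagemaker.sklearn import SKLearn
    from sagemaker.workflow.pipeline_context import PipelineSession

    session = PipelineSession()  # pipeline-aware session enables hashed code paths

    processor = FrameworkProcessor(
        estimator_cls=SKLearn,
        framework_version="1.0-1",
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder ARN
        instance_type="ml.m5.xlarge",
        instance_count=1,
        sagemaker_session=session,
    )

    # Used inside a pipeline definition, processor.run(code=..., source_dir=...)
    # now uploads two independently hashed artifacts:
    #
    #   s3://<default-bucket>/<pipeline-name>/code/<hash(source_dir)>/sourcedir.tar.gz
    #   s3://<default-bucket>/<pipeline-name>/code/<hash(runproc.sh)>/runproc.sh
    #
    # so steps sharing a source_dir reuse one tarball, each step keeps its own
    # entrypoint, and unchanged steps keep hitting the same keys, which is what
    # keeps ProcessingStep caching effective.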
_select_version( return None return str(max(available_versions)) - spec = SpecifierSet(f"=={semantic_version_str}") + try: + spec = SpecifierSet(f"=={semantic_version_str}") + except InvalidSpecifier: + raise KeyError(f"Bad semantic version: {semantic_version_str}") available_versions_filtered = list(spec.filter(available_versions)) return ( str(max(available_versions_filtered)) if available_versions_filtered != [] else None From 4570aa6078e75ba0d259f8196891b7856790a435 Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Thu, 8 Dec 2022 19:00:48 -0800 Subject: [PATCH 26/58] fix: Pop out ModelPackageName from pipeline definition (#3472) Co-authored-by: Dewen Qi --- src/sagemaker/workflow/_utils.py | 12 ++ .../sagemaker/workflow/test_model_steps.py | 1 + tests/unit/sagemaker/workflow/conftest.py | 75 +++++++++ .../sagemaker/workflow/test_model_step.py | 147 +++++++----------- tests/unit/sagemaker/workflow/test_utils.py | 54 +------ 5 files changed, 150 insertions(+), 139 deletions(-) create mode 100644 tests/unit/sagemaker/workflow/conftest.py diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 8ba65f1eee..cdef9537c1 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -13,6 +13,7 @@ """Scrapper utilities to support repacking of models.""" from __future__ import absolute_import +import logging import os import shutil import tarfile @@ -37,6 +38,8 @@ if TYPE_CHECKING: from sagemaker.workflow.step_collections import StepCollection +logger = logging.getLogger(__name__) + FRAMEWORK_VERSION = "0.23-1" INSTANCE_TYPE = "ml.m5.large" REPACK_SCRIPT = "_repack_model.py" @@ -479,10 +482,19 @@ def arguments(self) -> RequestType: request_dict = get_create_model_package_request(**model_package_args) # these are not available in the workflow service and will cause rejection + warn_msg_template = ( + "Popping out '%s' from the pipeline definition " + "since it will be overridden in pipeline execution time." + ) if "CertifyForMarketplace" in request_dict: request_dict.pop("CertifyForMarketplace") + logger.warning(warn_msg_template, "CertifyForMarketplace") if "Description" in request_dict: request_dict.pop("Description") + logger.warning(warn_msg_template, "Description") + if "ModelPackageName" in request_dict: + request_dict.pop("ModelPackageName") + logger.warning(warn_msg_template, "ModelPackageName") return request_dict diff --git a/tests/integ/sagemaker/workflow/test_model_steps.py b/tests/integ/sagemaker/workflow/test_model_steps.py index 31c518b100..f25723c440 100644 --- a/tests/integ/sagemaker/workflow/test_model_steps.py +++ b/tests/integ/sagemaker/workflow/test_model_steps.py @@ -112,6 +112,7 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen inference_instances=["ml.m5.xlarge"], transform_instances=["ml.m5.xlarge"], description="test-description", + model_package_name="model-pkg-name-will-be-popped-out", ) step_model_regis = ModelStep( name="pytorch-register-model", diff --git a/tests/unit/sagemaker/workflow/conftest.py b/tests/unit/sagemaker/workflow/conftest.py new file mode 100644 index 0000000000..9ea3d0bcac --- /dev/null +++ b/tests/unit/sagemaker/workflow/conftest.py @@ -0,0 +1,75 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest.mock import Mock, PropertyMock + +import pytest + +from sagemaker import Session +from sagemaker.workflow.pipeline_context import PipelineSession + +REGION = "us-west-2" +BUCKET = "my-bucket" +ROLE = "DummyRole" +IMAGE_URI = "fakeimage" + + +@pytest.fixture(scope="module") +def client(): + """Mock client. + + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture(scope="module") +def boto_session(client): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value=ROLE) + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name=REGION) + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client + + return session_mock + + +@pytest.fixture(scope="module") +def pipeline_session(boto_session, client): + return PipelineSession( + boto_session=boto_session, + sagemaker_client=client, + default_bucket=BUCKET, + ) + + +@pytest.fixture(scope="module") +def sagemaker_session(boto_session, client): + return Session( + boto_session=boto_session, + sagemaker_client=client, + sagemaker_runtime_client=client, + default_bucket=BUCKET, + ) diff --git a/tests/unit/sagemaker/workflow/test_model_step.py b/tests/unit/sagemaker/workflow/test_model_step.py index 080e70ca62..2216299d3b 100644 --- a/tests/unit/sagemaker/workflow/test_model_step.py +++ b/tests/unit/sagemaker/workflow/test_model_step.py @@ -15,7 +15,7 @@ import json import os -from mock import Mock, PropertyMock, patch +from mock import patch import pytest @@ -43,7 +43,6 @@ ) from sagemaker.workflow.parameters import ParameterString, ParameterInteger from sagemaker.workflow.pipeline import Pipeline, PipelineGraph -from sagemaker.workflow.pipeline_context import PipelineSession from sagemaker.workflow.retry import ( StepRetryPolicy, StepExceptionTypeEnum, @@ -55,11 +54,9 @@ from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum from tests.unit import DATA_DIR from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered +from tests.unit.sagemaker.workflow.conftest import BUCKET, ROLE _IMAGE_URI = "fakeimage" -_REGION = "us-west-2" -_BUCKET = "my-bucket" -_ROLE = "DummyRole" _INSTANCE_TYPE = "ml.m4.xlarge" _SAGEMAKER_PROGRAM = SCRIPT_PARAM_NAME.upper() @@ -69,60 +66,10 @@ _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone") _TENSORFLOW_PATH = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies") _REPACK_OUTPUT_KEY_PREFIX = "code-output" -_MODEL_CODE_LOCATION = f"s3://{_BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}" +_MODEL_CODE_LOCATION = f"s3://{BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}" _MODEL_CODE_LOCATION_TRAILING_SLASH = _MODEL_CODE_LOCATION + "/" -@pytest.fixture -def client(): - """Mock client. 
- - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def boto_session(client): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=_ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=_REGION) - session_mock.resource.return_value = resource_mock - session_mock.client.return_value = client - - return session_mock - - -@pytest.fixture -def pipeline_session(boto_session, client): - return PipelineSession( - boto_session=boto_session, - sagemaker_client=client, - default_bucket=_BUCKET, - ) - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=_BUCKET, - ) - - @pytest.fixture def model_data_param(): return ParameterString(name="ModelData", default_value="s3://my-bucket/file") @@ -137,7 +84,7 @@ def model(pipeline_session, model_data_param): sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", source_dir=f"{DATA_DIR}", - role=_ROLE, + role=ROLE, ) @@ -322,13 +269,13 @@ def test_create_pipeline_model_with_runtime_repack(pipeline_session, model_data_ sparkml_model = SparkMLModel( name="MySparkMLModel", model_data=model_data_param, - role=_ROLE, + role=ROLE, sagemaker_session=pipeline_session, env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"}, ) # The model need to runtime repack ppl_model = PipelineModel( - models=[sparkml_model, model], role=_ROLE, sagemaker_session=pipeline_session + models=[sparkml_model, model], role=ROLE, sagemaker_session=pipeline_session ) step_args = ppl_model.create( instance_type="c4.4xlarge", @@ -417,7 +364,7 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat # The model no need to runtime repack, since source_dir is missing sparkml_model = SparkMLModel( model_data=model_data_param, - role=_ROLE, + role=ROLE, sagemaker_session=pipeline_session, env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"}, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", @@ -429,11 +376,11 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", source_dir=f"{DATA_DIR}", - role=_ROLE, + role=ROLE, env={"k": "v"}, ) model = PipelineModel( - models=[sparkml_model, model], role=_ROLE, sagemaker_session=pipeline_session + models=[sparkml_model, model], role=ROLE, sagemaker_session=pipeline_session ) step_args = model.register( content_types=["text/csv"], @@ -516,7 +463,7 @@ def test_register_model_without_repack(pipeline_session): model_data=model_data, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", sagemaker_session=pipeline_session, - role=_ROLE, + role=ROLE, ) step_args = model.register( content_types=["text/csv"], @@ -547,7 +494,7 @@ def test_register_model_without_repack(pipeline_session): assert containers[0]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME assert ( containers[0]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] - == f"s3://{_BUCKET}/{model_name}/sourcedir.tar.gz" + == f"s3://{BUCKET}/{model_name}/sourcedir.tar.gz" ) adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list assert ordered(adjacency_list) == 
ordered({"MyModelStep-RegisterModel": []}) @@ -560,11 +507,11 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session): model = Model( name=model_name, image_uri=_IMAGE_URI, - model_data=f"s3://{_BUCKET}/model.tar.gz", + model_data=f"s3://{BUCKET}/model.tar.gz", sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", source_dir=f"{DATA_DIR}", - role=_ROLE, + role=ROLE, ) step_args = model.create( instance_type="c4.4xlarge", @@ -582,7 +529,7 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session): arguments = step_dsl_list[0]["Arguments"] assert arguments["PrimaryContainer"]["Image"] == _IMAGE_URI assert ( - arguments["PrimaryContainer"]["ModelDataUrl"] == f"s3://{_BUCKET}/{model_name}/model.tar.gz" + arguments["PrimaryContainer"]["ModelDataUrl"] == f"s3://{BUCKET}/{model_name}/model.tar.gz" ) assert arguments["PrimaryContainer"]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME assert arguments["PrimaryContainer"]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] == _DIR_NAME @@ -700,7 +647,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, enable_network_isolation=True, code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH, ), @@ -713,7 +660,7 @@ def test_conditional_model_create_and_regis( framework_version="1.11.0", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, enable_network_isolation=False, ), 1, @@ -724,7 +671,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, framework_version="1.5.0", code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH, ), @@ -736,7 +683,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, framework_version="1.2.0", ), 1, @@ -747,7 +694,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, ), 2, ), @@ -757,7 +704,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH, ), 2, @@ -768,7 +715,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, ), 1, ), @@ -789,7 +736,7 @@ def assert_test_result(steps: list): ) else: assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == ( - f"s3://{_BUCKET}/{model.name}" + f"s3://{BUCKET}/{model.name}" ) model, expected_step_num = test_input @@ -828,7 +775,7 @@ def assert_test_result(steps: list): XGBoostModel( model_data="dummy_model_step", framework_version="1.3-1", - role=_ROLE, + role=ROLE, entry_point=os.path.join(_XGBOOST_PATH, "inference.py"), enable_network_isolation=True, ), @@ -845,7 +792,7 @@ def assert_test_result(steps: list): XGBoostModel( model_data="dummy_model_step", framework_version="1.3-1", - role=_ROLE, + role=ROLE, entry_point=os.path.join(_XGBOOST_PATH, "inference.py"), ), { @@ -861,7 +808,7 @@ def assert_test_result(steps: list): XGBoostModel( model_data="dummy_model_step", framework_version="1.3-1", - role=_ROLE, + role=ROLE, entry_point=None, 
), { @@ -876,9 +823,8 @@ def assert_test_result(steps: list): ( TensorFlowModel( model_data="dummy_model_step", - role=_ROLE, + role=ROLE, image_uri=_IMAGE_URI, - sagemaker_session=pipeline_session, entry_point=os.path.join(_TENSORFLOW_PATH, "inference.py"), ), { @@ -893,9 +839,8 @@ def assert_test_result(steps: list): ( TensorFlowModel( model_data="dummy_model_step", - role=_ROLE, + role=ROLE, image_uri=_IMAGE_URI, - sagemaker_session=pipeline_session, ), { "expected_step_num": 1, @@ -941,7 +886,7 @@ def test_request_compare_of_register_model_under_different_sessions( _verify_register_model_container_definition(regis_step_arg, expect, dict) # Get create model package request under Session - model.model_data = f"s3://{_BUCKET}" + model.model_data = f"s3://{BUCKET}" model.sagemaker_session = sagemaker_session with patch.object( Session, "_intercept_create_request", return_value=dict(ModelPackageArn="arn:aws") @@ -996,7 +941,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): model_data=lambda_step.properties.Outputs["model_artifact"], sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, ) step_create_model = ModelStep(name="mymodelstep", step_args=model.create()) @@ -1031,7 +976,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): ( Processor( image_uri=_IMAGE_URI, - role=_ROLE, + role=ROLE, instance_count=1, instance_type=_INSTANCE_TYPE, ), @@ -1052,7 +997,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): ( HyperparameterTuner( estimator=Estimator( - role=_ROLE, + role=ROLE, instance_count=1, instance_type=_INSTANCE_TYPE, image_uri=_IMAGE_URI, @@ -1064,7 +1009,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): ), ( Estimator( - role=_ROLE, + role=ROLE, instance_count=1, instance_type=_INSTANCE_TYPE, image_uri=_IMAGE_URI, @@ -1128,3 +1073,31 @@ def test_pass_in_wrong_type_of_retry_policies(pipeline_session, model): ), ) assert "SageMakerJobStepRetryPolicy is not allowed for a create/registe" in str(error.value) + + +def test_register_model_step_with_model_package_name(pipeline_session): + model = Model( + name="MyModel", + image_uri="my-image", + model_data="s3://", + sagemaker_session=pipeline_session, + ) + step_args = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_name="model-pkg-name-will-be-popped-out", + ) + regis_model_step = ModelStep( + name="MyModelStep", + step_args=step_args, + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[regis_model_step], + sagemaker_session=pipeline_session, + ) + steps = json.loads(pipeline.definition())["Steps"] + assert len(steps) == 1 + assert "ModelPackageName" not in steps[0]["Arguments"] diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index dcbf5a6421..c8d86c5866 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -18,12 +18,6 @@ import tempfile import pytest -import sagemaker - -from mock import ( - Mock, - PropertyMock, -) from sagemaker.estimator import Estimator from sagemaker.workflow._utils import ( @@ -35,51 +29,7 @@ from sagemaker.workflow.properties import Properties from tests.unit.test_utils import FakeS3, list_tar_files from tests.unit import DATA_DIR - -REGION = "us-west-2" -BUCKET = "my-bucket" -IMAGE_URI = "fakeimage" -ROLE = 
"DummyRole" - - -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. - - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=BUCKET, - ) +from tests.unit.sagemaker.workflow.conftest import ROLE, IMAGE_URI, BUCKET @pytest.fixture @@ -171,7 +121,7 @@ def test_repack_model_step(estimator): } -def test_repack_model_step_with_invalid_input(): +def test_register_model_step_with_invalid_input(): # without both step_args and any of the old required arguments with pytest.raises(ValueError) as error: _RegisterModelStep( From 959ea1a485db702f361ddebda2e80779bfd20e43 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 9 Dec 2022 06:20:46 +0000 Subject: [PATCH 27/58] prepare release v2.121.1 --- CHANGELOG.md | 7 +++++++ VERSION | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29dad5f19f..472a25feb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## v2.121.1 (2022-12-09) + +### Bug Fixes and Other Changes + + * Pop out ModelPackageName from pipeline definition + * Fix failing jumpstart cache unit tests + ## v2.121.0 (2022-12-08) ### Features diff --git a/VERSION b/VERSION index 28b52ee8d5..f73c7f057e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.1.dev0 +2.121.1 From b2e8b66016c09a3898123725bf1c01d1a87b05d0 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 9 Dec 2022 06:20:47 +0000 Subject: [PATCH 28/58] update development version to v2.121.2.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index f73c7f057e..d866b235cc 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.1 +2.121.2.dev0 From 355975d4d2d45088eeb13681f8d99e48a00909c9 Mon Sep 17 00:00:00 2001 From: amzn-choeric <105388439+amzn-choeric@users.noreply.github.com> Date: Fri, 9 Dec 2022 13:53:28 -0500 Subject: [PATCH 29/58] fix: Skip Bad Transform Test (#3521) --- tests/integ/test_inference_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py index 53d966fe9b..a26d8c9101 100644 --- a/tests/integ/test_inference_pipeline.py +++ b/tests/integ/test_inference_pipeline.py @@ -50,6 +50,7 @@ ) +@pytest.mark.skip(reason="Test has likely been failing for a while. 
Suspected bad XGB model.") def test_inference_pipeline_batch_transform(sagemaker_session, cpu_instance_type): sparkml_model_data = sagemaker_session.upload_data( path=os.path.join(SPARKML_DATA_PATH, "mleap_model.tar.gz"), From fadc817c7557f5fea5e414d51b500a6b7cd02065 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Fri, 9 Dec 2022 12:07:32 -0800 Subject: [PATCH 30/58] fix: Revert "fix: type hint of PySparkProcessor __init__" (#3524) From c5fc93feea798df1713db6707737a2f24738c4c7 Mon Sep 17 00:00:00 2001 From: hballuru <113142824+hballuru@users.noreply.github.com> Date: Fri, 9 Dec 2022 16:36:12 -0600 Subject: [PATCH 31/58] change: Update for Tensorflow Serving 2.11 inference DLCs (#3509) --- .../image_uri_config/tensorflow.json | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index a0f2bba014..aaca927ba4 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -303,7 +303,8 @@ "2.7": "2.7.0", "2.8": "2.8.0", "2.9": "2.9.2", - "2.10": "2.10.0" + "2.10": "2.10.0", + "2.11": "2.11.0" }, "versions": { "1.10.0": { @@ -1611,6 +1612,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1618,8 +1620,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1642,6 +1646,41 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, + "2.11.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1649,8 +1688,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", From ec8da98a9a7cae848e8bf1af06bdaaabd1ebb382 Mon Sep 17 00:00:00 
2001 From: ci Date: Mon, 12 Dec 2022 18:18:58 +0000 Subject: [PATCH 32/58] prepare release v2.121.2 --- CHANGELOG.md | 8 ++++++++ VERSION | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 472a25feb8..8b66e85f54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## v2.121.2 (2022-12-12) + +### Bug Fixes and Other Changes + + * Update for Tensorflow Serving 2.11 inference DLCs + * Revert "fix: type hint of PySparkProcessor __init__" + * Skip Bad Transform Test + ## v2.121.1 (2022-12-09) ### Bug Fixes and Other Changes diff --git a/VERSION b/VERSION index d866b235cc..3b02379cd3 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.2.dev0 +2.121.2 From 03521222d324ed752174038309828ed8183c5aea Mon Sep 17 00:00:00 2001 From: ci Date: Mon, 12 Dec 2022 18:19:00 +0000 Subject: [PATCH 33/58] update development version to v2.121.3.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 3b02379cd3..8fde5e282f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.2 +2.121.3.dev0 From d6c021404586d4df601a6115add87fcbf75b6d65 Mon Sep 17 00:00:00 2001 From: Kristopher Siman Date: Mon, 12 Dec 2022 17:21:49 -0500 Subject: [PATCH 34/58] feature: Add OSU region to frameworks for DLC (#3532) --- src/sagemaker/image_uri_config/autogluon.json | 12 ++++ .../image_uri_config/huggingface-neuron.json | 1 + .../image_uri_config/huggingface.json | 31 ++++++++ src/sagemaker/image_uri_config/mxnet.json | 13 ++++ .../image_uri_config/pytorch-neuron.json | 1 + src/sagemaker/image_uri_config/pytorch.json | 31 ++++++++ .../image_uri_config/tensorflow.json | 70 +++++++++++++++++++ 7 files changed, 159 insertions(+) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 3a9f02142c..590b6e5f82 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -30,6 +30,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -61,6 +62,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -92,6 +94,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -123,6 +126,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -154,6 +158,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -185,6 +190,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -230,6 +236,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", 
"us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -267,6 +274,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -304,6 +312,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -341,6 +350,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -378,6 +388,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -415,6 +426,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-neuron.json b/src/sagemaker/image_uri_config/huggingface-neuron.json index 47d6dbd1dc..980dceed17 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuron.json +++ b/src/sagemaker/image_uri_config/huggingface-neuron.json @@ -33,6 +33,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 5b98fc0d02..a0caa59a55 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -42,6 +42,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -75,6 +76,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -114,6 +116,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -147,6 +150,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -188,6 +192,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -222,6 +227,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -256,6 +262,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", 
"us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -290,6 +297,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -332,6 +340,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -366,6 +375,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -400,6 +410,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -434,6 +445,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -474,6 +486,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -508,6 +521,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -548,6 +562,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -582,6 +597,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -622,6 +638,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -656,6 +673,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -712,6 +730,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -749,6 +768,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -786,6 +806,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -831,6 +852,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", 
"us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -868,6 +890,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -905,6 +928,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -942,6 +966,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -985,6 +1010,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1022,6 +1048,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1065,6 +1092,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1102,6 +1130,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1145,6 +1174,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1182,6 +1212,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 8d8733e480..588a03a76e 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -249,6 +249,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -282,6 +283,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -315,6 +317,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -348,6 +351,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -381,6 +385,7 @@ "sa-east-1": "763104351884", "us-east-1": 
"763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -644,6 +649,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -680,6 +686,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -716,6 +723,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -752,6 +760,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -788,6 +797,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -897,6 +907,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -933,6 +944,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -969,6 +981,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch-neuron.json b/src/sagemaker/image_uri_config/pytorch-neuron.json index b116a8a36b..5b29406955 100644 --- a/src/sagemaker/image_uri_config/pytorch-neuron.json +++ b/src/sagemaker/image_uri_config/pytorch-neuron.json @@ -28,6 +28,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 18a382e591..85681a3423 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -208,6 +208,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -247,6 +248,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -285,6 +287,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -323,6 +326,7 @@ 
"sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -362,6 +366,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -401,6 +406,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -440,6 +446,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -479,6 +486,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -517,6 +525,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -555,6 +564,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -593,6 +603,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -631,6 +642,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -669,6 +681,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -707,6 +720,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -744,6 +758,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -791,6 +806,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -951,6 +967,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -987,6 +1004,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1023,6 +1041,7 
@@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1058,6 +1077,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1094,6 +1114,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1130,6 +1151,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1166,6 +1188,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1202,6 +1225,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1237,6 +1261,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1272,6 +1297,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1307,6 +1333,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1342,6 +1369,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1377,6 +1405,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1412,6 +1441,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1446,6 +1476,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index aaca927ba4..a900aa4fe5 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -161,6 +161,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", 
"us-west-1": "763104351884", @@ -196,6 +197,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -231,6 +233,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -266,6 +269,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -425,6 +429,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -460,6 +465,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -495,6 +501,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -530,6 +537,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -565,6 +573,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -600,6 +609,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -635,6 +645,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -862,6 +873,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -897,6 +909,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -932,6 +945,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -967,6 +981,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1002,6 +1017,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": 
"886529160074", "us-west-1": "763104351884", @@ -1037,6 +1053,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1072,6 +1089,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1107,6 +1125,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1142,6 +1161,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1177,6 +1197,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1212,6 +1233,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1247,6 +1269,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1282,6 +1305,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1317,6 +1341,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1352,6 +1377,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1387,6 +1413,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1422,6 +1449,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1457,6 +1485,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1492,6 +1521,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1527,6 +1557,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": 
"442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1562,6 +1593,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1597,6 +1629,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1631,6 +1664,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1665,6 +1699,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1699,6 +1734,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1746,6 +1782,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1940,6 +1977,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1977,6 +2015,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2013,6 +2052,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2050,6 +2090,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2087,6 +2128,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2124,6 +2166,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2161,6 +2204,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2389,6 +2433,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2425,6 +2470,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": 
"446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2461,6 +2507,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2496,6 +2543,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2531,6 +2579,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2567,6 +2616,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2603,6 +2653,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2638,6 +2689,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2673,6 +2725,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2708,6 +2761,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2743,6 +2797,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2778,6 +2833,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2813,6 +2869,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2848,6 +2905,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2883,6 +2941,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2918,6 +2977,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2953,6 +3013,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": 
"763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2988,6 +3049,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3023,6 +3085,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3058,6 +3121,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3093,6 +3157,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3128,6 +3193,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3163,6 +3229,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3198,6 +3265,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3233,6 +3301,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3267,6 +3336,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", From 5af4feb57d950358dcf5dd15aad7f7d59ae11b31 Mon Sep 17 00:00:00 2001 From: Xiaoguang Chen <68292680+xgchena@users.noreply.github.com> Date: Mon, 12 Dec 2022 15:59:33 -0800 Subject: [PATCH 35/58] fix: Remove content type image/jpg from analysis configuration schema (#3530) Currently the analysis configuration schema of SageMaker Clarify API allows the content_type configuration "image/jpeg" and "image/jpg", but the service side validation only accepts the former which is the registered MIME type for JPEG (see rfc3745 and JPEG specification). The commit removes the latter from the schema to avoid confusion and enable early API validation. 
--- src/sagemaker/clarify.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 4765630ce8..f082679401 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -282,7 +282,6 @@ "text/csv", "application/jsonlines", "image/jpeg", - "image/jpg", "image/png", "application/x-npy", ), From 438984754a8f44b34d70154197a3bbeb0272f052 Mon Sep 17 00:00:00 2001 From: Clayton Parnell <42805768+claytonparnell@users.noreply.github.com> Date: Mon, 12 Dec 2022 22:37:35 -0500 Subject: [PATCH 36/58] fix: unpin packaging version (#3533) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f366b147b8..4327045760 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ def read_requirements(filename): "protobuf3-to-dict>=0.1.5,<1.0", "smdebug_rulesconfig==1.0.1", "importlib-metadata>=1.4.0,<5.0", - "packaging==20.9", + "packaging>=20.0", "pandas", "pathos", "schema", From a3efddf6d6a4e89861f2ae1eca9d7fd7712a691b Mon Sep 17 00:00:00 2001 From: Anton Repushko Date: Tue, 13 Dec 2022 20:45:06 +0100 Subject: [PATCH 37/58] fix: the Hyperband support fix for the HPO (#3516) Co-authored-by: Anton Repushko --- src/sagemaker/session.py | 9 +++++++ src/sagemaker/tuner.py | 14 +++++------ tests/unit/test_session.py | 48 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 00797c9ea0..3fc4fc1256 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2121,6 +2121,7 @@ def tune( # noqa: C901 stop_condition, tags, warm_start_config, + strategy_config=None, enable_network_isolation=False, image_uri=None, algorithm_arn=None, @@ -2136,6 +2137,8 @@ def tune( # noqa: C901 Args: job_name (str): Name of the tuning job being created. strategy (str): Strategy to be used for hyperparameter estimations. + strategy_config (dict): A configuration for the hyperparameter tuning + job optimisation strategy. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize'. objective_metric_name (str): Name of the metric for evaluating training jobs. @@ -2220,6 +2223,7 @@ def tune( # noqa: C901 objective_metric_name=objective_metric_name, parameter_ranges=parameter_ranges, early_stopping_type=early_stopping_type, + strategy_config=strategy_config, ), "TrainingJobDefinition": self._map_training_config( static_hyperparameters=static_hyperparameters, @@ -2375,6 +2379,7 @@ def _map_tuning_config( objective_type=None, objective_metric_name=None, parameter_ranges=None, + strategy_config=None, ): """Construct tuning job configuration dictionary. @@ -2392,6 +2397,8 @@ def _map_tuning_config( objective_metric_name (str): Name of the metric for evaluating training jobs. parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be one of three types: Continuous, Integer, or Categorical. + strategy_config (dict): A configuration for the hyperparameter tuning job optimisation + strategy. Returns: A dictionary of tuning job configuration. 
For format details, please refer to @@ -2415,6 +2422,8 @@ def _map_tuning_config( if parameter_ranges is not None: tuning_config["ParameterRanges"] = parameter_ranges + if strategy_config is not None: + tuning_config["StrategyConfig"] = strategy_config return tuning_config @classmethod diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 52b9d81d0d..9a694cbec9 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -282,8 +282,8 @@ def from_job_desc(cls, hyperband_strategy_config): Returns: sagemaker.tuner.HyperbandStrategyConfig: De-serialized instance of - HyperbandStrategyConfig containing the max_resource and min_resource provided as part of - ``hyperband_strategy_config``. + ``HyperbandStrategyConfig`` containing the max_resource + and min_resource provided as part of ``hyperband_strategy_config``. """ return cls( min_resource=hyperband_strategy_config[HYPERBAND_MIN_RESOURCE], @@ -306,7 +306,7 @@ def to_input_req(self): Returns: dict: Containing the "MaxResource" and - "MinResource" as the first class fields. + "MinResource" as the first class fields. """ return { HYPERBAND_MIN_RESOURCE: self.min_resource, @@ -330,7 +330,7 @@ def __init__( Args: hyperband_strategy_config (sagemaker.tuner.HyperbandStrategyConfig): The configuration - for the object that specifies the Hyperband strategy. + for the object that specifies the Hyperband strategy. This parameter is only supported for the Hyperband selection for Strategy within the HyperParameterTuningJobConfig. """ @@ -461,7 +461,7 @@ def __init__( ``WarmStartConfig`` object that has been initialized with the configuration defining the nature of warm start tuning job. strategy_config (sagemaker.tuner.StrategyConfig): A configuration for "Hyperparameter" - tuning job optimisation strategy. + tuning job optimisation strategy. early_stopping_type (str or PipelineVariable): Specifies whether early stopping is enabled for the job. Can be either 'Auto' or 'Off' (default: 'Off'). If set to 'Off', early stopping will not be attempted. @@ -1569,7 +1569,7 @@ def create( strategy (str): Strategy to be used for hyperparameter estimations (default: 'Bayesian'). strategy_config (dict): The configuration for a training job launched by a - hyperparameter tuning job. + hyperparameter tuning job. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize' (default: 'Maximize'). 
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter @@ -1776,7 +1776,7 @@ def _get_tuner_args(cls, tuner, inputs): } if tuner.strategy_config is not None: - tuning_config["strategy_config"] = tuner.strategy_config + tuning_config["strategy_config"] = tuner.strategy_config.to_input_req() if tuner.objective_metric_name is not None: tuning_config["objective_type"] = tuner.objective_type diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8958210092..bf81283177 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -941,6 +941,13 @@ def test_train_pack_to_request(sagemaker_session): ], } +SAMPLE_HYPERBAND_STRATEGY_CONFIG = { + "HyperbandStrategyConfig": { + "MinResource": 1, + "MaxResource": 10, + } +} + @pytest.mark.parametrize( "warm_start_type, parents", @@ -1167,6 +1174,47 @@ def assert_create_tuning_job_request(**kwrags): ) +def test_tune_with_strategy_config(sagemaker_session): + def assert_create_tuning_job_request(**kwrags): + assert ( + kwrags["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][ + "MinResource" + ] + == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MinResource"] + ) + assert ( + kwrags["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][ + "MaxResource" + ] + == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MaxResource"] + ) + + sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = ( + assert_create_tuning_job_request + ) + sagemaker_session.tune( + job_name="dummy-tuning-1", + strategy="Bayesian", + objective_type="Maximize", + objective_metric_name="val-score", + max_jobs=100, + max_parallel_jobs=5, + parameter_ranges=SAMPLE_PARAM_RANGES, + static_hyperparameters=STATIC_HPs, + image_uri="dummy-image-1", + input_mode="File", + metric_definitions=SAMPLE_METRIC_DEF, + role=EXPANDED_ROLE, + input_config=SAMPLE_INPUT, + output_config=SAMPLE_OUTPUT, + resource_config=RESOURCE_CONFIG, + stop_condition=SAMPLE_STOPPING_CONDITION, + tags=None, + warm_start_config=None, + strategy_config=SAMPLE_HYPERBAND_STRATEGY_CONFIG, + ) + + def test_tune_with_encryption_flag(sagemaker_session): def assert_create_tuning_job_request(**kwrags): assert ( From bd96ec5c585217bdec31951d632247f4b0d9f91b Mon Sep 17 00:00:00 2001 From: Md Mizanur Rahman <105268921+mizanfiu@users.noreply.github.com> Date: Tue, 13 Dec 2022 16:06:08 -0800 Subject: [PATCH 38/58] feature: Feature Store dataset builder, delete_record, get_record, list_feature_group (#3534) Co-authored-by: Eric Zou Co-authored-by: Yiming Zou Co-authored-by: Brandon Chatham Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com> --- .../feature_store/dataset_builder.py | 990 ++++++++++++++++++ src/sagemaker/feature_store/feature_group.py | 45 +- src/sagemaker/feature_store/feature_store.py | 130 +++ src/sagemaker/session.py | 94 +- tests/integ/test_feature_store.py | 400 +++++++ .../feature_store/test_dataset_builder.py | 612 +++++++++++ .../feature_store/test_feature_group.py | 580 ++++++++++ .../feature_store/test_feature_store.py | 687 ++---------- tests/unit/test_session.py | 29 + 9 files changed, 2979 insertions(+), 588 deletions(-) create mode 100644 src/sagemaker/feature_store/dataset_builder.py create mode 100644 src/sagemaker/feature_store/feature_store.py create mode 100644 tests/unit/sagemaker/feature_store/test_dataset_builder.py create mode 100644 tests/unit/sagemaker/feature_store/test_feature_group.py diff --git 
a/src/sagemaker/feature_store/dataset_builder.py b/src/sagemaker/feature_store/dataset_builder.py new file mode 100644 index 0000000000..fc82997379 --- /dev/null +++ b/src/sagemaker/feature_store/dataset_builder.py @@ -0,0 +1,990 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Dataset Builder + +A Dataset Builder is a builder class for generating a dataset by providing conditions. +""" +from __future__ import absolute_import + +import datetime +from enum import Enum +import os +from typing import Any, Dict, List, Tuple, Union + +import attr +import pandas as pd + +from sagemaker import Session, s3, utils +from sagemaker.feature_store.feature_group import FeatureDefinition, FeatureGroup, FeatureTypeEnum + + +_DEFAULT_CATALOG = "AwsDataCatalog" +_DEFAULT_DATABASE = "sagemaker_featurestore" + + +@attr.s +class TableType(Enum): + """Enum of Table types. + + The data type of a table can be FeatureGroup or DataFrame. + """ + + FEATURE_GROUP = "FeatureGroup" + DATA_FRAME = "DataFrame" + + +@attr.s +class FeatureGroupToBeMerged: + """FeatureGroup metadata which will be used for the SQL join. + + This class instantiates a FeatureGroupToBeMerged object that comprises a list of feature names, + a list of feature names to be included in the SQL query, a database, an Athena table name, + the record identifier feature name, the event time identifier feature name, and the feature + name in base that is used as the target join key. + + Attributes: + features (List[str]): A list of strings representing feature names of this FeatureGroup. + included_feature_names (List[str]): A list of strings representing features to be + included in the SQL join. + projected_feature_names (List[str]): A list of strings representing features to be + included for final projection in output. + catalog (str): A string representing the catalog. + database (str): A string representing the database. + table_name (str): A string representing the Athena table name of this FeatureGroup. + record_identifier_feature_name (str): A string representing the record identifier feature. + event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the + event time identifier feature. + target_feature_name_in_base (str): A string representing the feature name in base which will + be used as target join key (default: None). + table_type (TableType): A TableType representing the type of the table, either FeatureGroup + or pandas DataFrame (default: None).
+ """ + + features: List[str] = attr.ib() + included_feature_names: List[str] = attr.ib() + projected_feature_names: List[str] = attr.ib() + catalog: str = attr.ib() + database: str = attr.ib() + table_name: str = attr.ib() + record_identifier_feature_name: str = attr.ib() + event_time_identifier_feature: FeatureDefinition = attr.ib() + target_feature_name_in_base: str = attr.ib(default=None) + table_type: TableType = attr.ib(default=None) + + +def construct_feature_group_to_be_merged( + feature_group: FeatureGroup, + included_feature_names: List[str], + target_feature_name_in_base: str = None, +) -> FeatureGroupToBeMerged: + """Construct a FeatureGroupToBeMerged object by provided parameters. + + Args: + feature_group (FeatureGroup): A FeatureGroup object. + included_feature_names (List[str]): A list of strings representing features to be + included in the output. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as target join key (default: None). + Returns: + A FeatureGroupToBeMerged object. + + Raises: + ValueError: Invalid feature name(s) in included_feature_names. + """ + feature_group_metadata = feature_group.describe() + data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", {}).get( + "DataCatalogConfig", None + ) + if not data_catalog_config: + raise RuntimeError(f"No metastore is configured with FeatureGroup {feature_group.name}.") + + record_identifier_feature_name = feature_group_metadata.get("RecordIdentifierFeatureName", None) + feature_definitions = feature_group_metadata.get("FeatureDefinitions", []) + event_time_identifier_feature_name = feature_group_metadata.get("EventTimeFeatureName", None) + event_time_identifier_feature_type = FeatureTypeEnum( + next( + filter( + lambda f: f.get("FeatureName", None) == event_time_identifier_feature_name, + feature_definitions, + ), + {}, + ).get("FeatureType", None) + ) + table_name = data_catalog_config.get("TableName", None) + database = data_catalog_config.get("Database", None) + disable_glue = feature_group_metadata.get("DisableGlueTableCreation", False) + catalog = data_catalog_config.get("Catalog", None) if disable_glue else _DEFAULT_CATALOG + features = [feature.get("FeatureName", None) for feature in feature_definitions] + + for included_feature in included_feature_names or []: + if included_feature not in features: + raise ValueError( + f"Feature {included_feature} not found in FeatureGroup {feature_group.name}" + ) + if not included_feature_names: + included_feature_names = features + projected_feature_names = features.copy() + else: + projected_feature_names = included_feature_names.copy() + if record_identifier_feature_name not in included_feature_names: + included_feature_names.append(record_identifier_feature_name) + if event_time_identifier_feature_name not in included_feature_names: + included_feature_names.append(event_time_identifier_feature_name) + return FeatureGroupToBeMerged( + features, + included_feature_names, + projected_feature_names, + catalog, + database, + table_name, + record_identifier_feature_name, + FeatureDefinition(event_time_identifier_feature_name, event_time_identifier_feature_type), + target_feature_name_in_base, + TableType.FEATURE_GROUP, + ) + + +@attr.s +class DatasetBuilder: + """DatasetBuilder definition. + + This class instantiates a DatasetBuilder object that comprises a base, a list of feature names, + an output path and a KMS key ID. 
+ + Attributes: + _sagemaker_session (Session): Session instance to perform boto calls. + _base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a + pandas.DataFrame and will be used to merge other FeatureGroups and generate a Dataset. + _output_path (str): An S3 URI which stores the output .csv file. + _record_identifier_feature_name (str): A string representing the record identifier feature + if base is a DataFrame (default: None). + _event_time_identifier_feature_name (str): A string representing the event time identifier + feature if base is a DataFrame (default: None). + _included_feature_names (List[str]): A list of strings representing features to be + included in the output (default: None). + _kms_key_id (str): A KMS key ID. If set, it will be used to encrypt the result file + (default: None). + _point_in_time_accurate_join (bool): Whether to use a point-in-time accurate join + (default: False). + _include_duplicated_records (bool): Whether to include duplicated records + (default: False). + _include_deleted_records (bool): Whether to include deleted records + (default: False). + _number_of_recent_records (int): The number of recent records to return for each + record identifier (default: 1). + _number_of_records (int): The number of records to return (default: None). + _write_time_ending_timestamp (datetime.datetime): A datetime; all records in the dataset + will have a write time before it (default: None). + _event_time_starting_timestamp (datetime.datetime): A datetime; all records in the dataset + will have an event time after it (default: None). + _event_time_ending_timestamp (datetime.datetime): A datetime; all records in the dataset + will have an event time before it (default: None). + _feature_groups_to_be_merged (List[FeatureGroupToBeMerged]): A list of + FeatureGroupToBeMerged which will be joined to base (default: []). + _event_time_identifier_feature_type (FeatureTypeEnum): A FeatureTypeEnum representing the + type of the event time identifier feature (default: None).
+ """ + + _sagemaker_session: Session = attr.ib() + _base: Union[FeatureGroup, pd.DataFrame] = attr.ib() + _output_path: str = attr.ib() + _record_identifier_feature_name: str = attr.ib(default=None) + _event_time_identifier_feature_name: str = attr.ib(default=None) + _included_feature_names: List[str] = attr.ib(default=None) + _kms_key_id: str = attr.ib(default=None) + + _point_in_time_accurate_join: bool = attr.ib(init=False, default=False) + _include_duplicated_records: bool = attr.ib(init=False, default=False) + _include_deleted_records: bool = attr.ib(init=False, default=False) + _number_of_recent_records: int = attr.ib(init=False, default=None) + _number_of_records: int = attr.ib(init=False, default=None) + _write_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None) + _event_time_starting_timestamp: datetime.datetime = attr.ib(init=False, default=None) + _event_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None) + _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = attr.ib(init=False, factory=list) + _event_time_identifier_feature_type: FeatureTypeEnum = attr.ib(default=None) + + _DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP = { + "object": "STRING", + "int64": "INT", + "float64": "DOUBLE", + "bool": "BOOLEAN", + "datetime64[ns]": "TIMESTAMP", + } + + def with_feature_group( + self, + feature_group: FeatureGroup, + target_feature_name_in_base: str = None, + included_feature_names: List[str] = None, + ): + """Join FeatureGroup with base. + + Args: + feature_group (FeatureGroup): A FeatureGroup which will be joined to base. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as target join key (default: None). + included_feature_names (List[str]): A list of strings representing features to be + included in the output (default: None). + Returns: + This DatasetBuilder object. + """ + self._feature_groups_to_be_merged.append( + construct_feature_group_to_be_merged( + feature_group, included_feature_names, target_feature_name_in_base + ) + ) + return self + + def point_in_time_accurate_join(self): + """Set join type as point in time accurate join. + + Returns: + This DatasetBuilder object. + """ + self._point_in_time_accurate_join = True + return self + + def include_duplicated_records(self): + """Include duplicated records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_duplicated_records = True + return self + + def include_deleted_records(self): + """Include deleted records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_deleted_records = True + return self + + def with_number_of_recent_records_by_record_identifier(self, number_of_recent_records: int): + """Set number_of_recent_records field with provided input. + + Args: + number_of_recent_records (int): An int that how many recent records will be returned for + each record identifier. + Returns: + This DatasetBuilder object. + """ + self._number_of_recent_records = number_of_recent_records + return self + + def with_number_of_records_from_query_results(self, number_of_records: int): + """Set number_of_records field with provided input. + + Args: + number_of_records (int): An int that how many records will be returned. + Returns: + This DatasetBuilder object. + """ + self._number_of_records = number_of_records + return self + + def as_of(self, timestamp: datetime.datetime): + """Set write_time_ending_timestamp field with provided input. 
+ + Args: + timestamp (datetime.datetime): A datetime that all records' write time in dataset will + be before it. + Returns: + This DatasetBuilder object. + """ + self._write_time_ending_timestamp = timestamp + return self + + def with_event_time_range( + self, + starting_timestamp: datetime.datetime = None, + ending_timestamp: datetime.datetime = None, + ): + """Set event_time_starting_timestamp and event_time_ending_timestamp with provided inputs. + + Args: + starting_timestamp (datetime.datetime): A datetime that all records' event time in + dataset will be after it (default: None). + ending_timestamp (datetime.datetime): A datetime that all records' event time in dataset + will be before it (default: None). + Returns: + This DatasetBuilder object. + """ + self._event_time_starting_timestamp = starting_timestamp + self._event_time_ending_timestamp = ending_timestamp + return self + + def to_csv_file(self) -> Tuple[str, str]: + """Get query string and result in .csv format file + + Returns: + The S3 path of the .csv file. + The query string executed. + """ + if isinstance(self._base, pd.DataFrame): + temp_id = utils.unique_name_from_base("dataframe-base") + local_file_name = f"{temp_id}.csv" + desired_s3_folder = f"{self._output_path}/{temp_id}" + self._base.to_csv(local_file_name, index=False, header=False) + s3.S3Uploader.upload( + local_path=local_file_name, + desired_s3_uri=desired_s3_folder, + sagemaker_session=self._sagemaker_session, + kms_key=self._kms_key_id, + ) + os.remove(local_file_name) + temp_table_name = f'dataframe_{temp_id.replace("-", "_")}' + self._create_temp_table(temp_table_name, desired_s3_folder) + base_features = list(self._base.columns) + event_time_identifier_feature_dtype = self._base[ + self._event_time_identifier_feature_name + ].dtypes + self._event_time_identifier_feature_type = ( + FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get( + str(event_time_identifier_feature_dtype), None + ) + ) + query_string = self._construct_query_string( + FeatureGroupToBeMerged( + base_features, + self._included_feature_names if self._included_feature_names else base_features, + self._included_feature_names if self._included_feature_names else base_features, + _DEFAULT_CATALOG, + _DEFAULT_DATABASE, + temp_table_name, + self._record_identifier_feature_name, + FeatureDefinition( + self._event_time_identifier_feature_name, + self._event_time_identifier_feature_type, + ), + None, + TableType.DATA_FRAME, + ) + ) + query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + # TODO: cleanup temp table, need more clarification, keep it for now + return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get( + "OutputLocation", None + ), query_result.get("QueryExecution", {}).get("Query", None) + if isinstance(self._base, FeatureGroup): + base_feature_group = construct_feature_group_to_be_merged( + self._base, self._included_feature_names + ) + self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name + self._event_time_identifier_feature_name = ( + base_feature_group.event_time_identifier_feature.feature_name + ) + self._event_time_identifier_feature_type = ( + base_feature_group.event_time_identifier_feature.feature_type + ) + query_string = self._construct_query_string(base_feature_group) + query_result = self._run_query( + query_string, + base_feature_group.catalog, + base_feature_group.database, + ) + return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get( + "OutputLocation", 
None + ), query_result.get("QueryExecution", {}).get("Query", None) + raise ValueError("Base must be either a FeatureGroup or a DataFrame.") + + def to_dataframe(self) -> Tuple[pd.DataFrame, str]: + """Get query string and result in pandas.Dataframe + + Returns: + The pandas.DataFrame object. + The query string executed. + """ + csv_file, query_string = self.to_csv_file() + s3.S3Downloader.download( + s3_uri=csv_file, + local_path="./", + kms_key=self._kms_key_id, + sagemaker_session=self._sagemaker_session, + ) + local_file_name = csv_file.split("/")[-1] + df = pd.read_csv(local_file_name) + os.remove(local_file_name) + + local_metadata_file_name = local_file_name + ".metadata" + if os.path.exists(local_metadata_file_name): + os.remove(local_file_name + ".metadata") + + if "row_recent" in df: + df = df.drop("row_recent", axis="columns") + return df, query_string + + def _construct_event_time_conditions( + self, + table_name: str, + event_time_identifier_feature: FeatureDefinition, + ) -> List[str]: + """Internal method for constructing event time range sql range as string. + + Args: + table_name (str): name of the table. + event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the + event time identifier feature. + Returns: + The list of query strings. + """ + event_time_conditions = [] + timestamp_cast_function_name = "from_unixtime" + if event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING: + timestamp_cast_function_name = "from_iso8601_timestamp" + if self._event_time_starting_timestamp: + event_time_conditions.append( + f"{timestamp_cast_function_name}({table_name}." + + f'"{event_time_identifier_feature.feature_name}") >= ' + + f"from_unixtime({self._event_time_starting_timestamp.timestamp()})" + ) + if self._event_time_ending_timestamp: + event_time_conditions.append( + f"{timestamp_cast_function_name}({table_name}." + + f'"{event_time_identifier_feature.feature_name}") <= ' + + f"from_unixtime({self._event_time_ending_timestamp.timestamp()})" + ) + return event_time_conditions + + def _construct_write_time_condition( + self, + table_name: str, + ) -> str: + """Internal method for constructing write time condition. + + Args: + table_name (str): name of the table. + Returns: + string of write time condition. + """ + write_time_condition = ( + f'{table_name}."write_time" <= ' + f"to_timestamp('{self._write_time_ending_timestamp.replace(microsecond=0)}', " + f"'yyyy-mm-dd hh24:mi:ss')" + ) + return write_time_condition + + def _construct_where_query_string( + self, + suffix: str, + event_time_identifier_feature: FeatureDefinition, + where_conditions: List[str], + ) -> str: + """Internal method for constructing SQL WHERE query string by parameters. + + Args: + suffix (str): A temp identifier of the FeatureGroup. + event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the + event time identifier feature. + where_conditions (List[str]): A list of strings representing existing where clauses. + Returns: + The WHERE query string. + + Raises: + ValueError: FeatureGroup not provided while using as_of(). Only found pandas.DataFrame. + """ + if self._number_of_recent_records: + if self._number_of_recent_records < 0: + raise ValueError( + "Please provide non-negative integer for number_of_recent_records." 
+ ) + if self._number_of_records: + if self._number_of_records < 0: + raise ValueError("Please provide non-negative integer for number_of_records.") + if self._include_deleted_records: + if isinstance(self._base, pd.DataFrame): + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "include_deleted_records() only works for FeatureGroup," + " if there is no join operation." + ) + if self._include_duplicated_records: + if isinstance(self._base, pd.DataFrame): + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "include_duplicated_records() only works for FeatureGroup," + " if there is no join operation." + ) + if self._point_in_time_accurate_join: + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "point_in_time_accurate_join() only works when there is " + "at least one feature group to join." + ) + if self._write_time_ending_timestamp: + if isinstance(self._base, pd.DataFrame): + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "as_of() only works for FeatureGroup," " if there is no join operation." + ) + if isinstance(self._base, FeatureGroup): + if self._write_time_ending_timestamp: + where_conditions.append(self._construct_write_time_condition(f"table_{suffix}")) + + event_time_conditions = self._construct_event_time_conditions( + f"table_{suffix}", event_time_identifier_feature + ) + where_conditions.extend(event_time_conditions) + + if len(where_conditions) == 0: + return "" + return "WHERE " + "\nAND ".join(where_conditions) + + def _construct_dedup_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing the SQL query string that removes duplicate records. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The SQL query string.
+ """ + record_feature_name = feature_group.record_identifier_feature_name + event_time_identifier_feature = feature_group.event_time_identifier_feature + event_time_feature_name = feature_group.event_time_identifier_feature.feature_name + rank_query_string = "" + where_conditions = [] + where_conditions_str = "" + is_dedup_enabled = False + + if feature_group.table_type is TableType.FEATURE_GROUP: + is_dedup_enabled = True + rank_query_string = ( + f'ORDER BY origin_{suffix}."api_invocation_time" DESC, ' + + f'origin_{suffix}."write_time" DESC\n' + ) + + if self._write_time_ending_timestamp: + where_conditions.append(self._construct_write_time_condition(f"origin_{suffix}")) + + event_time_conditions = self._construct_event_time_conditions( + f"origin_{suffix}", event_time_identifier_feature + ) + where_conditions.extend(event_time_conditions) + + if len(where_conditions) != 0: + where_conditions_str = "WHERE " + "\nAND ".join(where_conditions) + "\n" + + dedup_where_clause = f"WHERE dedup_row_{suffix} = 1\n" if is_dedup_enabled else "" + return ( + f"table_{suffix} AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + f'PARTITION BY origin_{suffix}."{record_feature_name}", ' + + f'origin_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + f") AS dedup_row_{suffix}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n' + + where_conditions_str + + ")\n" + + dedup_where_clause + + ")" + ) + + def _construct_deleted_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing removing deleted records SQL query string. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The SQL query string. 
+ """ + record_feature_name = feature_group.record_identifier_feature_name + event_time_identifier_feature = feature_group.event_time_identifier_feature + event_time_feature_name = feature_group.event_time_identifier_feature.feature_name + rank_query_string = f'ORDER BY origin_{suffix}."{event_time_feature_name}" DESC' + write_time_condition = "\n" + event_time_starting_condition = "" + event_time_ending_condition = "" + + if feature_group.table_type is TableType.FEATURE_GROUP: + rank_query_string += ( + f', origin_{suffix}."api_invocation_time" DESC, ' + + f'origin_{suffix}."write_time" DESC\n' + ) + + if self._write_time_ending_timestamp: + write_time_condition += " AND " + write_time_condition += self._construct_write_time_condition(f"origin_{suffix}") + write_time_condition += "\n" + + if self._event_time_starting_timestamp and self._event_time_ending_timestamp: + event_time_conditions = self._construct_event_time_conditions( + f"origin_{suffix}", event_time_identifier_feature + ) + event_time_starting_condition = "AND " + event_time_conditions[0] + "\n" + event_time_ending_condition = "AND " + event_time_conditions[1] + "\n" + + return ( + f"deleted_{suffix} AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + f'PARTITION BY origin_{suffix}."{record_feature_name}"\n' + + rank_query_string + + f") AS deleted_row_{suffix}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n' + + "WHERE is_deleted" + + write_time_condition + + event_time_starting_condition + + event_time_ending_condition + + ")\n" + + f"WHERE deleted_row_{suffix} = 1\n" + + ")" + ) + + def _construct_table_included_features( + self, feature_group: FeatureGroupToBeMerged, suffix: str + ) -> str: + """Internal method for constructing included features string of table. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object + which has the metadata. + suffix (str): A temp identifier of the table. + Returns: + The string that includes all feature to be included of table. + """ + + included_features = ", ".join( + [ + f'table_{suffix}."{include_feature_name}"' + for include_feature_name in feature_group.included_feature_names + ] + ) + return included_features + + def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing SQL query string by parameters. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The query string. 
+ """ + included_features = self._construct_table_included_features(feature_group, suffix) + + # If base is a FeatureGroup then included_features_write_time will have a write_time column + # Or included_features_write_time is same as included_features + included_features_write_time = included_features + + if feature_group.table_type is TableType.FEATURE_GROUP: + included_features_write_time += f', table_{suffix}."write_time"' + record_feature_name = feature_group.record_identifier_feature_name + event_time_feature_name = feature_group.event_time_identifier_feature.feature_name + if self._include_duplicated_records and self._include_deleted_records: + return ( + f"SELECT {included_features}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" table_{suffix}\n' + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, ["NOT is_deleted"] + ) + ) + if feature_group.table_type is TableType.FEATURE_GROUP and self._include_deleted_records: + rank_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + rank_query_string = ( + f'ORDER BY origin_{suffix}."api_invocation_time" DESC, ' + + f'origin_{suffix}."write_time" DESC\n' + ) + return ( + f"SELECT {included_features}\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + f'PARTITION BY origin_{suffix}."{record_feature_name}", ' + + f'origin_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + f") AS row_{suffix}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n' + + "WHERE NOT is_deleted" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, + feature_group.event_time_identifier_feature, + [f"row_{suffix} = 1"], + ) + ) + rank_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + rank_query_string = ( + f'OR (table_{suffix}."{event_time_feature_name}" = ' + + f'deleted_{suffix}."{event_time_feature_name}" ' + + f'AND table_{suffix}."api_invocation_time" > ' + + f'deleted_{suffix}."api_invocation_time")\n' + + f'OR (table_{suffix}."{event_time_feature_name}" = ' + + f'deleted_{suffix}."{event_time_feature_name}" ' + + f'AND table_{suffix}."api_invocation_time" = ' + + f'deleted_{suffix}."api_invocation_time" ' + + f'AND table_{suffix}."write_time" > deleted_{suffix}."write_time")\n' + ) + + final_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + if self._include_duplicated_records: + final_query_string = ( + f"WITH {self._construct_deleted_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}"' + + f" table_{suffix}\n" + + f"LEFT JOIN deleted_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + f'WHERE deleted_{suffix}."{record_feature_name}" IS NULL\n' + + "UNION ALL\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM deleted_{suffix}\n" + + f'JOIN "{feature_group.database}"."{feature_group.table_name}"' + + f" table_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + "AND (\n" + + f'table_{suffix}."{event_time_feature_name}" > ' + + f'deleted_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + ")\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + else: + 
final_query_string = ( + f"WITH {self._construct_dedup_query(feature_group, suffix)},\n" + + f"{self._construct_deleted_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM table_{suffix}\n" + + f"LEFT JOIN deleted_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + f'WHERE deleted_{suffix}."{record_feature_name}" IS NULL\n' + + "UNION ALL\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM deleted_{suffix}\n" + + f"JOIN table_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + "AND (\n" + + f'table_{suffix}."{event_time_feature_name}" > ' + + f'deleted_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + ")\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + else: + final_query_string = ( + f"WITH {self._construct_dedup_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM table_{suffix}\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + return final_query_string + + def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str: + """Internal method for constructing SQL query string by parameters. + + Args: + base (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the metadata. + Returns: + The query string. + + Raises: + ValueError: target_feature_name_in_base is an invalid feature name. + """ + base_table_query_string = self._construct_table_query(base, "base") + query_string = f"WITH fg_base AS ({base_table_query_string})" + if len(self._feature_groups_to_be_merged) > 0: + with_subquery_string = "".join( + [ + f",\nfg_{i} AS ({self._construct_table_query(feature_group, str(i))})" + for i, feature_group in enumerate(self._feature_groups_to_be_merged) + ] + ) + query_string += with_subquery_string + + selected_features = "" + selected_features += ", ".join(map("fg_base.{0}".format, base.projected_feature_names)) + if len(self._feature_groups_to_be_merged) > 0: + for i, feature_group in enumerate(self._feature_groups_to_be_merged): + selected_features += ", " + selected_features += ", ".join( + [ + f'fg_{i}."{feature_name}" as "{feature_name}.{(i+1)}"' + for feature_name in feature_group.projected_feature_names + ] + ) + + selected_features_final = "" + selected_features_final += ", ".join(base.projected_feature_names) + if len(self._feature_groups_to_be_merged) > 0: + for i, feature_group in enumerate(self._feature_groups_to_be_merged): + selected_features_final += ", " + selected_features_final += ", ".join( + [ + '"{0}.{1}"'.format(feature_name, (i + 1)) + for feature_name in feature_group.projected_feature_names + ] + ) + + query_string += ( + f"\nSELECT {selected_features_final}\n" + + "FROM (\n" + + f"SELECT {selected_features}, row_number() OVER (\n" + + f'PARTITION BY fg_base."{base.record_identifier_feature_name}"\n' + + f'ORDER BY fg_base."{base.event_time_identifier_feature.feature_name}" DESC' + ) + + recent_record_where_clause = "" + if self._number_of_recent_records is not None and self._number_of_recent_records >= 0: + recent_record_where_clause = f"WHERE row_recent <= {self._number_of_recent_records}" + + join_subquery_strings = [] + if 
len(self._feature_groups_to_be_merged) > 0: + for i, feature_group in enumerate(self._feature_groups_to_be_merged): + if not feature_group.target_feature_name_in_base: + feature_group.target_feature_name_in_base = self._record_identifier_feature_name + else: + if feature_group.target_feature_name_in_base not in base.features: + raise ValueError( + f"Feature {feature_group.target_feature_name_in_base} not found in base" + ) + query_string += ( + f', fg_{i}."{feature_group.event_time_identifier_feature.feature_name}" DESC' + ) + join_subquery_strings.append(self._construct_join_condition(feature_group, str(i))) + + query_string += ( + "\n) AS row_recent\n" + + "FROM fg_base" + + "".join(join_subquery_strings) + + "\n)\n" + + f"{recent_record_where_clause}" + ) + + if self._number_of_records is not None and self._number_of_records >= 0: + query_string += f"\nLIMIT {self._number_of_records}" + return query_string + + def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing SQL JOIN query string by parameters. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The JOIN query string. + """ + join_condition_string = ( + f"\nJOIN fg_{suffix}\n" + + f'ON fg_base."{feature_group.target_feature_name_in_base}" = ' + + f'fg_{suffix}."{feature_group.record_identifier_feature_name}"' + ) + base_timestamp_cast_function_name = "from_unixtime" + if self._event_time_identifier_feature_type == FeatureTypeEnum.STRING: + base_timestamp_cast_function_name = "from_iso8601_timestamp" + timestamp_cast_function_name = "from_unixtime" + if feature_group.event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING: + timestamp_cast_function_name = "from_iso8601_timestamp" + if self._point_in_time_accurate_join: + join_condition_string += ( + f"\nAND {base_timestamp_cast_function_name}(fg_base." + + f'"{self._event_time_identifier_feature_name}") >= ' + + f"{timestamp_cast_function_name}(fg_{suffix}." + + f'"{feature_group.event_time_identifier_feature.feature_name}")' + ) + return join_condition_string + + def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str): + """Internal method for creating a temp Athena table for the base pandas.Dataframe. + + Args: + temp_table_name (str): The Athena table name of base pandas.DataFrame. + desired_s3_folder (str): The S3 URI of the folder of the data. + """ + columns_string = ", ".join( + [self._construct_athena_table_column_string(column) for column in self._base.columns] + ) + serde_properties = '"separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\"' + query_string = ( + f"CREATE EXTERNAL TABLE {temp_table_name} ({columns_string}) " + + "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' " + + f"WITH SERDEPROPERTIES ({serde_properties}) " + + f"LOCATION '{desired_s3_folder}';" + ) + self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + + def _construct_athena_table_column_string(self, column: str) -> str: + """Internal method for constructing string of Athena column. + + Args: + column (str): The column name from pandas.Dataframe. + Returns: + The Athena column string. + + Raises: + RuntimeError: The type of pandas.Dataframe column is not support yet. 
+ """ + dataframe_type = self._base[column].dtypes + if str(dataframe_type) not in self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.keys(): + raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.") + return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}" + + def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]: + """Internal method for execute Athena query, wait for query finish and get query result. + + Args: + query_string (str): The SQL query statements to be executed. + catalog (str): The name of the data catalog used in the query execution. + database (str): The name of the database used in the query execution. + Returns: + The query result. + + Raises: + RuntimeError: Athena query failed. + """ + query = self._sagemaker_session.start_query_execution( + catalog=catalog, + database=database, + query_string=query_string, + output_location=self._output_path, + kms_key=self._kms_key_id, + ) + query_id = query.get("QueryExecutionId", None) + self._sagemaker_session.wait_for_athena_query(query_execution_id=query_id) + query_result = self._sagemaker_session.get_query_execution(query_execution_id=query_id) + query_state = query_result.get("QueryExecution", {}).get("Status", {}).get("State", None) + + if query_state != "SUCCEEDED": + raise RuntimeError(f"Failed to execute query {query_id}.") + return query_result diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index d486ab8a01..855e11488f 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -435,13 +435,14 @@ class FeatureGroup: "uint64", ] _FLOAT_TYPES = ["float_", "float16", "float32", "float64"] - _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = { + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = { type: FeatureTypeEnum.INTEGRAL for type in _INTEGER_TYPES } - _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update( + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update( {type: FeatureTypeEnum.FRACTIONAL for type in _FLOAT_TYPES} ) - _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["object"] = FeatureTypeEnum.STRING _FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP = { FeatureTypeEnum.INTEGRAL.value: "INT", @@ -629,7 +630,7 @@ def load_feature_definitions( """ feature_definitions = [] for column in data_frame: - feature_type = self._DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get( + feature_type = self.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get( str(data_frame[column].dtype), None ) if feature_type: @@ -644,6 +645,23 @@ def load_feature_definitions( self.feature_definitions = feature_definitions return self.feature_definitions + def get_record( + self, record_identifier_value_as_string: str, feature_names: Sequence[str] = None + ) -> Sequence[Dict[str, str]]: + """Get a single record in a FeatureGroup + + Args: + record_identifier_value_as_string (String): + a String representing the value of the record identifier. + feature_names (Sequence[String]): + a list of Strings representing feature names. + """ + return self.sagemaker_session.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + feature_group_name=self.name, + feature_names=feature_names, + ).get("Record") + def put_record(self, record: Sequence[FeatureValue]): """Put a single record in the FeatureGroup. 
@@ -654,6 +672,25 @@ def put_record(self, record: Sequence[FeatureValue]): feature_group_name=self.name, record=[value.to_dict() for value in record] ) + def delete_record( + self, + record_identifier_value_as_string: str, + event_time: str, + ): + """Delete a single record from a FeatureGroup. + + Args: + record_identifier_value_as_string (String): + a String representing the value of the record identifier. + event_time (String): + a timestamp format String indicating when the deletion event occurred. + """ + return self.sagemaker_session.delete_record( + feature_group_name=self.name, + record_identifier_value_as_string=record_identifier_value_as_string, + event_time=event_time, + ) + def ingest( self, data_frame: DataFrame, diff --git a/src/sagemaker/feature_store/feature_store.py b/src/sagemaker/feature_store/feature_store.py new file mode 100644 index 0000000000..def8b2b2da --- /dev/null +++ b/src/sagemaker/feature_store/feature_store.py @@ -0,0 +1,130 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Feature Store. + +Amazon SageMaker Feature Store is a fully managed, purpose-built repository to store, share, and +manage features for machine learning (ML) models. +""" +from __future__ import absolute_import + +import datetime +from typing import Any, Dict, Sequence, Union + +import attr +import pandas as pd + +from sagemaker import Session +from sagemaker.feature_store.dataset_builder import DatasetBuilder +from sagemaker.feature_store.feature_group import FeatureGroup + + +@attr.s +class FeatureStore: + """FeatureStore definition. + + This class instantiates a FeatureStore object that comprises a SageMaker session instance. + + Attributes: + sagemaker_session (Session): session instance to perform boto calls. + """ + + sagemaker_session: Session = attr.ib(default=Session) + + def create_dataset( + self, + base: Union[FeatureGroup, pd.DataFrame], + output_path: str, + record_identifier_feature_name: str = None, + event_time_identifier_feature_name: str = None, + included_feature_names: Sequence[str] = None, + kms_key_id: str = None, + ) -> DatasetBuilder: + """Create a Dataset Builder for generating a Dataset. + + Args: + base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a + pandas.DataFrame and will be used to merge other FeatureGroups and generate a + Dataset. + output_path (str): An S3 URI which stores the output .csv file. + record_identifier_feature_name (str): A string representing the record identifier + feature if base is a DataFrame (default: None). + event_time_identifier_feature_name (str): A string representing the event time + identifier feature if base is a DataFrame (default: None). + included_feature_names (List[str]): A list of features to be included in the output + (default: None). + kms_key_id (str): An KMS key id. If set, will be used to encrypt the result file + (default: None). 
+ + Raises: + ValueError: Base is a Pandas DataFrame but no record identifier feature name nor event + time identifier feature name is provided. + """ + if isinstance(base, pd.DataFrame): + if record_identifier_feature_name is None or event_time_identifier_feature_name is None: + raise ValueError( + "You must provide a record identifier feature name and an event time " + + "identifier feature name if specify DataFrame as base." + ) + return DatasetBuilder( + self.sagemaker_session, + base, + output_path, + record_identifier_feature_name, + event_time_identifier_feature_name, + included_feature_names, + kms_key_id, + ) + + def list_feature_groups( + self, + name_contains: str = None, + feature_group_status_equals: str = None, + offline_store_status_equals: str = None, + creation_time_after: datetime.datetime = None, + creation_time_before: datetime.datetime = None, + sort_order: str = None, + sort_by: str = None, + max_results: int = None, + next_token: str = None, + ) -> Dict[str, Any]: + """List all FeatureGroups satisfying given filters. + + Args: + name_contains (str): A string that partially matches one or more FeatureGroups' names. + Filters FeatureGroups by name. + feature_group_status_equals (str): A FeatureGroup status. + Filters FeatureGroups by FeatureGroup status. + offline_store_status_equals (str): An OfflineStore status. + Filters FeatureGroups by OfflineStore status. + creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups + created after a specific date and time. + creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups + created before a specific date and time. + sort_order (str): The order in which FeatureGroups are listed. + sort_by (str): The value on which the FeatureGroup list is sorted. + max_results (int): The maximum number of results returned by ListFeatureGroups. + next_token (str): A token to resume pagination of ListFeatureGroups results. + Returns: + Response dict from service. + """ + return self.sagemaker_session.list_feature_groups( + name_contains=name_contains, + feature_group_status_equals=feature_group_status_equals, + offline_store_status_equals=offline_store_status_equals, + creation_time_after=creation_time_after, + creation_time_before=creation_time_before, + sort_order=sort_order, + sort_by=sort_by, + max_results=max_results, + next_token=next_token, + ) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 3fc4fc1256..72df570496 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -312,7 +312,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): # For each object key, create the directory on the local machine if needed, and then # download the file. for key in keys: - tail_s3_uri_path = os.path.basename(key_prefix) + tail_s3_uri_path = os.path.basename(key) if not os.path.splitext(key_prefix)[1]: tail_s3_uri_path = os.path.relpath(key, key_prefix) destination_path = os.path.join(path, tail_s3_uri_path) @@ -4341,6 +4341,56 @@ def update_feature_group( FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions ) + def list_feature_groups( + self, + name_contains, + feature_group_status_equals, + offline_store_status_equals, + creation_time_after, + creation_time_before, + sort_order, + sort_by, + max_results, + next_token, + ) -> Dict[str, Any]: + """List all FeatureGroups satisfying given filters. + + Args: + name_contains (str): A string that partially matches one or more FeatureGroups' names. 
+ Filters FeatureGroups by name. + feature_group_status_equals (str): A FeatureGroup status. + Filters FeatureGroups by FeatureGroup status. + offline_store_status_equals (str): An OfflineStore status. + Filters FeatureGroups by OfflineStore status. + creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups + created after a specific date and time. + creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups + created before a specific date and time. + sort_order (str): The order in which FeatureGroups are listed. + sort_by (str): The value on which the FeatureGroup list is sorted. + max_results (int): The maximum number of results returned by ListFeatureGroups. + next_token (str): A token to resume pagination of ListFeatureGroups results. + Returns: + Response dict from service. + """ + list_feature_groups_args = {} + + def check_object(key, value): + if value is not None: + list_feature_groups_args[key] = value + + check_object("NameContains", name_contains) + check_object("FeatureGroupStatusEquals", feature_group_status_equals) + check_object("OfflineStoreStatusEquals", offline_store_status_equals) + check_object("CreationTimeAfter", creation_time_after) + check_object("CreationTimeBefore", creation_time_before) + check_object("SortOrder", sort_order) + check_object("SortBy", sort_by) + check_object("MaxResults", max_results) + check_object("NextToken", next_token) + + return self.sagemaker_client.list_feature_groups(**list_feature_groups_args) + def update_feature_metadata( self, feature_group_name: str, @@ -4408,6 +4458,48 @@ def put_record( Record=record, ) + def delete_record( + self, + feature_group_name: str, + record_identifier_value_as_string: str, + event_time: str, + ): + """Deletes a single record from the FeatureGroup. + + Args: + feature_group_name (str): name of the FeatureGroup. + record_identifier_value_as_string (str): the value of the record identifier as a string. + event_time (str): a timestamp indicating when the deletion event occurred. + """ + return self.sagemaker_featurestore_runtime_client.delete_record( + FeatureGroupName=feature_group_name, + RecordIdentifierValueAsString=record_identifier_value_as_string, + EventTime=event_time, + ) + + def get_record( + self, + record_identifier_value_as_string: str, + feature_group_name: str, + feature_names: Sequence[str], + ) -> Dict[str, Sequence[Dict[str, str]]]: + """Gets a single record in the FeatureGroup. + + Args: + record_identifier_value_as_string (str): the value of the record identifier as a string. + feature_group_name (str): name of the FeatureGroup. + feature_names (Sequence[str]): list of feature names.
+ """ + get_record_args = { + "FeatureGroupName": feature_group_name, + "RecordIdentifierValueAsString": record_identifier_value_as_string, + } + + if feature_names: + get_record_args["FeatureNames"] = feature_names + + return self.sagemaker_featurestore_runtime_client.get_record(**get_record_args) + def start_query_execution( self, catalog: str, diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index c1b84117c3..e19cebdca4 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -14,6 +14,7 @@ import json import time +import datetime from contextlib import contextmanager import boto3 @@ -24,6 +25,7 @@ from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition from sagemaker.feature_store.feature_group import FeatureGroup +from sagemaker.feature_store.feature_store import FeatureStore from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter, TableFormatEnum from sagemaker.session import get_execution_role, Session from tests.integ.timeout import timeout @@ -80,6 +82,11 @@ def feature_group_name(): return f"my-feature-group-{int(time.time() * 10**7)}" +@pytest.fixture +def base_name(): + return f"my-base-{int(time.time() * 10**7)}" + + @pytest.fixture def offline_store_s3_uri(feature_store_session, region_name): bucket = f"sagemaker-test-featurestore-{region_name}-{feature_store_session.account_id()}" @@ -107,6 +114,32 @@ def pandas_data_frame(): return df +@pytest.fixture +def base_dataframe(): + base_data = [ + [1, 187512346.0, 123, 128], + [2, 187512347.0, 168, 258], + [3, 187512348.0, 125, 184], + [1, 187512349.0, 195, 206], + ] + return pd.DataFrame( + base_data, columns=["base_id", "base_time", "base_feature_1", "base_feature_2"] + ) + + +@pytest.fixture +def feature_group_dataframe(): + feature_group_data = [ + [1, 187512246.0, 456, 325], + [2, 187512247.0, 729, 693], + [3, 187512348.0, 129, 901], + [1, 187512449.0, 289, 286], + ] + return pd.DataFrame( + feature_group_data, columns=["fg_id", "fg_time", "fg_feature_1", "fg_feature_2"] + ) + + @pytest.fixture def pandas_data_frame_without_string(): df = pd.DataFrame( @@ -288,6 +321,92 @@ def test_create_feature_group_glue_table_format( assert table_format == "Glue" +def test_get_record( + feature_store_session, + role, + feature_group_name, + pandas_data_frame, + record, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + record_identifier_value_as_string = record[0].value_as_string + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + # Ingest data + feature_group.put_record(record=record) + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + ) + record_names = list(map(lambda r: r.feature_name, record)) + assert len(retrieved_record) == len(record_names) + for feature in retrieved_record: + assert feature["FeatureName"] in record_names + removed_feature_name = record_names.pop() + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=record_names, + ) + assert len(retrieved_record) == len(record_names) + for feature in 
retrieved_record:
+            assert feature["FeatureName"] in record_names
+            assert feature["FeatureName"] != removed_feature_name
+        # Retrieve data
+        retrieved_record = feature_group.get_record(
+            record_identifier_value_as_string="1.0",
+        )
+        assert retrieved_record is None
+
+
+def test_delete_record(
+    feature_store_session,
+    role,
+    feature_group_name,
+    pandas_data_frame,
+    record,
+):
+    feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
+    feature_group.load_feature_definitions(data_frame=pandas_data_frame)
+
+    record_identifier_value_as_string = record[0].value_as_string
+    with cleanup_feature_group(feature_group):
+        feature_group.create(
+            s3_uri=False,
+            record_identifier_name="feature1",
+            event_time_feature_name="feature3",
+            role_arn=role,
+            enable_online_store=True,
+        )
+        _wait_for_feature_group_create(feature_group)
+        # Ingest data
+        feature_group.put_record(record=record)
+        # Retrieve data
+        retrieved_record = feature_group.get_record(
+            record_identifier_value_as_string=record_identifier_value_as_string,
+        )
+        assert retrieved_record is not None
+        # Delete data
+        feature_group.delete_record(
+            record_identifier_value_as_string=record_identifier_value_as_string,
+            event_time=datetime.datetime.now().replace(microsecond=0).isoformat() + "Z",
+        )
+        # Retrieve data
+        retrieved_record = feature_group.get_record(
+            record_identifier_value_as_string=record_identifier_value_as_string,
+        )
+        assert retrieved_record is None
+
+
 def test_update_feature_group(
     feature_store_session,
     role,
@@ -316,6 +435,25 @@ def test_update_feature_group(
     assert any([True for elem in feature_definitions if new_feature_name in elem.values()])
 
 
+def test_list_feature_groups(feature_store_session, role, feature_group_name, pandas_data_frame):
+    feature_store = FeatureStore(sagemaker_session=feature_store_session)
+    feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
+    feature_group.load_feature_definitions(data_frame=pandas_data_frame)
+
+    with cleanup_feature_group(feature_group):
+        feature_group.create(
+            s3_uri=False,
+            record_identifier_name="feature1",
+            event_time_feature_name="feature3",
+            role_arn=role,
+            enable_online_store=True,
+        )
+        _wait_for_feature_group_create(feature_group)
+        output = feature_store.list_feature_groups(name_contains=feature_group_name)
+
+    assert output["FeatureGroupSummaries"][0]["FeatureGroupName"] == feature_group_name
+
+
 def test_feature_metadata(
     feature_store_session,
     role,
@@ -420,6 +558,242 @@ def test_ingest_multi_process(
     assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")
 
 
+def test_create_dataset_with_feature_group_base(
+    feature_store_session,
+    region_name,
+    role,
+    base_name,
+    feature_group_name,
+    offline_store_s3_uri,
+    base_dataframe,
+    feature_group_dataframe,
+):
+    base = FeatureGroup(name=base_name, sagemaker_session=feature_store_session)
+    feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
+    with cleanup_feature_group(base), cleanup_feature_group(feature_group):
+        _create_feature_group_and_ingest_data(
+            base, base_dataframe, offline_store_s3_uri, "base_id", "base_time", role
+        )
+        _create_feature_group_and_ingest_data(
+            feature_group, feature_group_dataframe, offline_store_s3_uri, "fg_id", "fg_time", role
+        )
+        base_table_name = _get_athena_table_name_after_data_replication(
+            feature_store_session, base, offline_store_s3_uri
+        )
+        feature_group_table_name = 
_get_athena_table_name_after_data_replication(
+            feature_store_session, feature_group, offline_store_s3_uri
+        )
+
+        with timeout(minutes=10), cleanup_offline_store(
+            base_table_name, feature_store_session
+        ), cleanup_offline_store(feature_group_table_name, feature_store_session):
+            feature_store = FeatureStore(sagemaker_session=feature_store_session)
+            df, query_string = (
+                feature_store.create_dataset(base=base, output_path=offline_store_s3_uri)
+                .with_number_of_recent_records_by_record_identifier(4)
+                .with_feature_group(feature_group)
+                .to_dataframe()
+            )
+            sorted_df = df.sort_values(by=list(df.columns)).reset_index(drop=True)
+            merged_df = base_dataframe.merge(
+                feature_group_dataframe, left_on="base_id", right_on="fg_id"
+            )
+
+            expect_df = merged_df.sort_values(by=list(merged_df.columns)).reset_index(drop=True)
+
+            expect_df.rename(
+                columns={
+                    "fg_id": "fg_id.1",
+                    "fg_time": "fg_time.1",
+                    "fg_feature_1": "fg_feature_1.1",
+                    "fg_feature_2": "fg_feature_2.1",
+                },
+                inplace=True,
+            )
+
+            assert sorted_df.equals(expect_df)
+            assert (
+                query_string
+                == "WITH fg_base AS (WITH table_base AS (\n"
+                + "SELECT *\n"
+                + "FROM (\n"
+                + "SELECT *, row_number() OVER (\n"
+                + 'PARTITION BY origin_base."base_id", origin_base."base_time"\n'
+                + 'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n'
+                + ") AS dedup_row_base\n"
+                + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n'
+                + ")\n"
+                + "WHERE dedup_row_base = 1\n"
+                + "),\n"
+                + "deleted_base AS (\n"
+                + "SELECT *\n"
+                + "FROM (\n"
+                + "SELECT *, row_number() OVER (\n"
+                + 'PARTITION BY origin_base."base_id"\n'
+                + 'ORDER BY origin_base."base_time" DESC,'
+                ' origin_base."api_invocation_time" DESC,'
+                ' origin_base."write_time" DESC\n'
+                + ") AS deleted_row_base\n"
+                + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n'
+                + "WHERE is_deleted\n"
+                + ")\n"
+                + "WHERE deleted_row_base = 1\n"
+                + ")\n"
+                + 'SELECT table_base."base_id", table_base."base_time",'
+                ' table_base."base_feature_1", table_base."base_feature_2"\n'
+                + "FROM (\n"
+                + 'SELECT table_base."base_id", table_base."base_time",'
+                ' table_base."base_feature_1", table_base."base_feature_2",'
+                ' table_base."write_time"\n'
+                + "FROM table_base\n"
+                + "LEFT JOIN deleted_base\n"
+                + 'ON table_base."base_id" = deleted_base."base_id"\n'
+                + 'WHERE deleted_base."base_id" IS NULL\n'
+                + "UNION ALL\n"
+                + 'SELECT table_base."base_id", table_base."base_time",'
+                ' table_base."base_feature_1", table_base."base_feature_2",'
+                ' table_base."write_time"\n'
+                + "FROM deleted_base\n"
+                + "JOIN table_base\n"
+                + 'ON table_base."base_id" = deleted_base."base_id"\n'
+                + "AND (\n"
+                + 'table_base."base_time" > deleted_base."base_time"\n'
+                + 'OR (table_base."base_time" = deleted_base."base_time" AND'
+                ' table_base."api_invocation_time" >'
+                ' deleted_base."api_invocation_time")\n'
+                + 'OR (table_base."base_time" = deleted_base."base_time" AND'
+                ' table_base."api_invocation_time" ='
+                ' deleted_base."api_invocation_time" AND'
+                ' table_base."write_time" > deleted_base."write_time")\n'
+                + ")\n"
+                + ") AS table_base\n"
+                + "),\n"
+                + "fg_0 AS (WITH table_0 AS (\n"
+                + "SELECT *\n"
+                + "FROM (\n"
+                + "SELECT *, row_number() OVER (\n"
+                + 'PARTITION BY origin_0."fg_id", origin_0."fg_time"\n'
+                + 'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n'
+                + ") AS dedup_row_0\n"
+                + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n'
+                + ")\n"
+                + "WHERE dedup_row_0 = 1\n"
+                + "),\n"
+ 
+ "deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."fg_id"\n' + + 'ORDER BY origin_0."fg_time" DESC, origin_0."api_invocation_time" DESC,' + ' origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."fg_id", table_0."fg_time", table_0."fg_feature_1",' + ' table_0."fg_feature_2"\n' + + "FROM (\n" + + 'SELECT table_0."fg_id", table_0."fg_time",' + ' table_0."fg_feature_1", table_0."fg_feature_2",' + ' table_0."write_time"\n' + + "FROM table_0\n" + + "LEFT JOIN deleted_0\n" + + 'ON table_0."fg_id" = deleted_0."fg_id"\n' + + 'WHERE deleted_0."fg_id" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_0."fg_id", table_0."fg_time",' + ' table_0."fg_feature_1", table_0."fg_feature_2",' + ' table_0."write_time"\n' + + "FROM deleted_0\n" + + "JOIN table_0\n" + + 'ON table_0."fg_id" = deleted_0."fg_id"\n' + + "AND (\n" + + 'table_0."fg_time" > deleted_0."fg_time"\n' + + 'OR (table_0."fg_time" = deleted_0."fg_time" AND' + ' table_0."api_invocation_time" >' + ' deleted_0."api_invocation_time")\n' + + 'OR (table_0."fg_time" = deleted_0."fg_time" AND' + ' table_0."api_invocation_time" =' + ' deleted_0."api_invocation_time" AND table_0."write_time" >' + ' deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + "SELECT base_id, base_time, base_feature_1, base_feature_2," + ' "fg_id.1", "fg_time.1", "fg_feature_1.1",' + ' "fg_feature_2.1"\n' + "FROM (\n" + "SELECT fg_base.base_id, fg_base.base_time," + " fg_base.base_feature_1, fg_base.base_feature_2," + ' fg_0."fg_id" as "fg_id.1", fg_0."fg_time" as "fg_time.1",' + ' fg_0."fg_feature_1" as "fg_feature_1.1",' + ' fg_0."fg_feature_2" as "fg_feature_2.1", row_number()' + " OVER (\n" + + 'PARTITION BY fg_base."base_id"\n' + + 'ORDER BY fg_base."base_time" DESC, fg_0."fg_time" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."base_id" = fg_0."fg_id"\n' + + ")\n" + + "WHERE row_recent <= 4" + ) + + +def _create_feature_group_and_ingest_data( + feature_group: FeatureGroup, + dataframe: DataFrame, + offline_store_s3_uri: str, + record_identifier_name: str, + event_time_name: str, + role: str, +): + feature_group.load_feature_definitions(data_frame=dataframe) + feature_group.create( + s3_uri=offline_store_s3_uri, + record_identifier_name=record_identifier_name, + event_time_feature_name=event_time_name, + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + + ingestion_manager = feature_group.ingest(data_frame=dataframe, max_workers=3, wait=False) + ingestion_manager.wait() + assert 0 == len(ingestion_manager.failed_rows) + + +def _get_athena_table_name_after_data_replication( + feature_store_session, feature_group: FeatureGroup, offline_store_s3_uri +): + feature_group_metadata = feature_group.describe() + resolved_output_s3_uri = ( + feature_group_metadata.get("OfflineStoreConfig", None) + .get("S3StorageConfig", None) + .get("ResolvedOutputS3Uri", None) + ) + s3_prefix = resolved_output_s3_uri.replace(f"{offline_store_s3_uri}/", "") + region_name = feature_store_session.boto_session.region_name + s3_client = feature_store_session.boto_session.client( + service_name="s3", region_name=region_name + ) + while True: + objects_in_bucket = s3_client.list_objects( + Bucket=offline_store_s3_uri.replace("s3://", ""), Prefix=s3_prefix + ) + 
if "Contents" in objects_in_bucket and len(objects_in_bucket["Contents"]) > 1: + break + else: + print(f"Waiting for {feature_group.name} data in offline store...") + time.sleep(60) + print(f"{feature_group.name} data available.") + return ( + feature_group_metadata.get("OfflineStoreConfig", None) + .get("DataCatalogConfig", None) + .get("TableName", None) + ) + + def _wait_for_feature_group_create(feature_group: FeatureGroup): status = feature_group.describe().get("FeatureGroupStatus") while status == "Creating": @@ -451,5 +825,31 @@ def cleanup_feature_group(feature_group: FeatureGroup): finally: try: feature_group.delete() + print(f"{feature_group.name} is deleted") except Exception: raise RuntimeError(f"Failed to delete feature group with name {feature_group.name}") + + +@contextmanager +def cleanup_offline_store(table_name: str, feature_store_session: Session): + try: + yield + finally: + try: + region_name = feature_store_session.boto_session.region_name + s3_client = feature_store_session.boto_session.client( + service_name="s3", region_name=region_name + ) + account_id = feature_store_session.account_id() + bucket_name = f"sagemaker-test-featurestore-{region_name}-{account_id}" + response = s3_client.list_objects_v2( + Bucket=bucket_name, + Prefix=f"{account_id}/sagemaker/{region_name}/offline-store/{table_name}/", + ) + files_in_folder = response["Contents"] + files_to_delete = [] + for f in files_in_folder: + files_to_delete.append({"Key": f["Key"]}) + s3_client.delete_objects(Bucket=bucket_name, Delete={"Objects": files_to_delete}) + except Exception: + raise RuntimeError(f"Failed to delete data under {table_name}") diff --git a/tests/unit/sagemaker/feature_store/test_dataset_builder.py b/tests/unit/sagemaker/feature_store/test_dataset_builder.py new file mode 100644 index 0000000000..0e55b86bd0 --- /dev/null +++ b/tests/unit/sagemaker/feature_store/test_dataset_builder.py @@ -0,0 +1,612 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import datetime + +import pandas as pd +import pytest +import os +from mock import Mock, patch + +from sagemaker.feature_store.dataset_builder import ( + DatasetBuilder, + FeatureGroupToBeMerged, + TableType, +) +from sagemaker.feature_store.feature_group import ( + FeatureDefinition, + FeatureGroup, + FeatureTypeEnum, +) + + +@pytest.fixture +def sagemaker_session_mock(): + return Mock() + + +@pytest.fixture +def feature_group_mock(): + return Mock() + + +@pytest.fixture +def read_csv_mock(): + return Mock() + + +@pytest.fixture +def to_csv_file_mock(): + return Mock() + + +@pytest.fixture +def remove_mock(): + return Mock() + + +BASE = FeatureGroupToBeMerged( + ["target-feature", "other-feature"], + ["target-feature", "other-feature"], + ["target-feature", "other-feature"], + "catalog", + "database", + "base-table", + "target-feature", + FeatureDefinition("other-feature", FeatureTypeEnum.STRING), + None, + TableType.FEATURE_GROUP, +) +FEATURE_GROUP = FeatureGroupToBeMerged( + ["feature-1", "feature-2"], + ["feature-1", "feature-2"], + ["feature-1", "feature-2"], + "catalog", + "database", + "table-name", + "feature-1", + FeatureDefinition("feature-2", FeatureTypeEnum.FRACTIONAL), + "target-feature", + TableType.FEATURE_GROUP, +) + + +def test_with_feature_group_throw_runtime_error(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + sagemaker_session_mock.describe_feature_group.return_value = {"OfflineStoreConfig": {}} + with pytest.raises(RuntimeError) as error: + dataset_builder.with_feature_group( + feature_group, "target-feature", ["feature-1", "feature-2"] + ) + assert "No metastore is configured with FeatureGroup MyFeatureGroup." 
in str(error) + + +def test_with_feature_group(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]}) + feature_group.load_feature_definitions(dataframe) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + record_identifier_feature_name="target-feature", + ) + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}}, + "RecordIdentifierFeatureName": "feature-1", + "EventTimeFeatureName": "feature-2", + "FeatureDefinitions": [ + {"FeatureName": "feature-1", "FeatureType": "String"}, + {"FeatureName": "feature-2", "FeatureType": "String"}, + ], + } + dataset_builder.with_feature_group(feature_group, "target-feature", ["feature-1", "feature-2"]) + assert len(dataset_builder._feature_groups_to_be_merged) == 1 + assert dataset_builder._feature_groups_to_be_merged[0].features == [ + "feature-1", + "feature-2", + ] + assert dataset_builder._feature_groups_to_be_merged[0].included_feature_names == [ + "feature-1", + "feature-2", + ] + assert dataset_builder._feature_groups_to_be_merged[0].database == "database" + assert dataset_builder._feature_groups_to_be_merged[0].table_name == "table" + assert ( + dataset_builder._feature_groups_to_be_merged[0].record_identifier_feature_name + == "feature-1" + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].event_time_identifier_feature.feature_name + == "feature-2" + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].event_time_identifier_feature.feature_type + == FeatureTypeEnum.STRING + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].target_feature_name_in_base + == "target-feature" + ) + + +def test_point_in_time_accurate_join(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.point_in_time_accurate_join() + assert dataset_builder._point_in_time_accurate_join + + +def test_include_duplicated_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.include_duplicated_records() + assert dataset_builder._include_duplicated_records + + +def test_include_deleted_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.include_deleted_records() + assert dataset_builder._include_deleted_records + + +def test_with_number_of_recent_records_by_record_identifier( + sagemaker_session_mock, feature_group_mock +): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.with_number_of_recent_records_by_record_identifier(5) + assert dataset_builder._number_of_recent_records == 5 + + +def test_with_number_of_records_from_query_results(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + 
dataset_builder.with_number_of_records_from_query_results(100) + assert dataset_builder._number_of_records == 100 + + +def test_with_event_time_range(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + start = datetime.datetime.now() + end = start + datetime.timedelta(minutes=1) + dataset_builder.with_event_time_range(start, end) + assert dataset_builder._event_time_starting_timestamp == start + assert dataset_builder._event_time_ending_timestamp == end + + +def test_to_csv_file_not_support_base_type(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + with pytest.raises(ValueError) as error: + dataset_builder.to_csv_file() + assert "Base must be either a FeatureGroup or a DataFrame." in str(error) + + +def test_to_csv_file_with_feature_group(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}}, + "RecordIdentifierFeatureName": "feature-1", + "EventTimeFeatureName": "feature-2", + "FeatureDefinitions": [ + {"FeatureName": "feature-1", "FeatureType": "String"}, + {"FeatureName": "feature-2", "FeatureType": "String"}, + ], + } + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": { + "Status": {"State": "SUCCEEDED"}, + "ResultConfiguration": {"OutputLocation": "s3-file-path"}, + "Query": "query-string", + } + } + file_path, query_string = dataset_builder.to_csv_file() + assert file_path == "s3-file-path" + assert query_string == "query-string" + + +@patch("pandas.DataFrame.to_csv") +@patch("pandas.read_csv") +@patch("os.remove") +def test_to_dataframe_with_dataframe( + remove_mock, read_csv_mock, to_csv_file_mock, sagemaker_session_mock +): + dataframe = pd.DataFrame({"feature-1": [420, 380.0, 390], "feature-2": [50, 40.0, 45]}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="s3://file/to/path", + event_time_identifier_feature_name="feature-2", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": { + "Status": {"State": "SUCCEEDED"}, + "ResultConfiguration": {"OutputLocation": "s3://s3-file-path"}, + "Query": "query-string", + } + } + to_csv_file_mock.return_value = None + read_csv_mock.return_value = dataframe + os.remove.return_value = None + df, query_string = dataset_builder.to_dataframe() + assert df.equals(dataframe) + assert query_string == "query-string" + + +def test_construct_where_query_string(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + time = datetime.datetime.now().replace(microsecond=0) + start = time + datetime.timedelta(minutes=1) + 
end = start + datetime.timedelta(minutes=1) + dataset_builder._write_time_ending_timestamp = time + dataset_builder._event_time_starting_timestamp = start + dataset_builder._event_time_ending_timestamp = end + query_string = dataset_builder._construct_where_query_string( + "suffix", + FeatureDefinition("event-time", FeatureTypeEnum.STRING), + ["NOT is_deleted"], + ) + assert ( + query_string + == "WHERE NOT is_deleted\n" + + f"AND table_suffix.\"write_time\" <= to_timestamp('{time}', " + + "'yyyy-mm-dd hh24:mi:ss')\n" + + 'AND from_iso8601_timestamp(table_suffix."event-time") >= ' + + f"from_unixtime({start.timestamp()})\n" + + 'AND from_iso8601_timestamp(table_suffix."event-time") <= ' + + f"from_unixtime({end.timestamp()})" + ) + + +def test_construct_query_string_with_duplicated_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder._include_duplicated_records = True + + dataset_builder._feature_groups_to_be_merged = [FEATURE_GROUP] + query_string = dataset_builder._construct_query_string(BASE) + assert ( + query_string + == "WITH fg_base AS (WITH deleted_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature"\n' + + 'ORDER BY origin_base."other-feature" DESC, origin_base."api_invocation_time" DESC, ' + + 'origin_base."write_time" DESC\n' + + ") AS deleted_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_base = 1\n" + + ")\n" + + 'SELECT table_base."target-feature", table_base."other-feature"\n' + + "FROM (\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + 'FROM "database"."base-table" table_base\n' + + "LEFT JOIN deleted_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + 'WHERE deleted_base."target-feature" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM deleted_base\n" + + 'JOIN "database"."base-table" table_base\n' + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + "AND (\n" + + 'table_base."other-feature" > deleted_base."other-feature"\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" > deleted_base."api_invocation_time")\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" = deleted_base."api_invocation_time" AND ' + + 'table_base."write_time" > deleted_base."write_time")\n' + + ")\n" + + ") AS table_base\n" + + "),\n" + + "fg_0 AS (WITH deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1"\n' + + 'ORDER BY origin_0."feature-2" DESC, origin_0."api_invocation_time" DESC, ' + + 'origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."feature-1", table_0."feature-2"\n' + + "FROM (\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + 'FROM "database"."table-name" table_0\n' + + "LEFT JOIN deleted_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + 'WHERE deleted_0."feature-1" IS NULL\n' + + "UNION ALL\n" 
+ + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM deleted_0\n" + + 'JOIN "database"."table-name" table_0\n' + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + "AND (\n" + + 'table_0."feature-2" > deleted_0."feature-2"\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND table_0."api_invocation_time" > ' + + 'deleted_0."api_invocation_time")\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND table_0."api_invocation_time" = ' + + 'deleted_0."api_invocation_time" AND table_0."write_time" > deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + 'SELECT target-feature, other-feature, "feature-1.1", "feature-2.1"\n' + + "FROM (\n" + + 'SELECT fg_base.target-feature, fg_base.other-feature, fg_0."feature-1" as ' + + '"feature-1.1", fg_0."feature-2" as "feature-2.1", row_number() OVER (\n' + + 'PARTITION BY fg_base."target-feature"\n' + + 'ORDER BY fg_base."other-feature" DESC, fg_0."feature-2" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."target-feature" = fg_0."feature-1"\n' + + ")\n" + ) + + +def test_construct_query_string(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + dataset_builder._point_in_time_accurate_join = True + dataset_builder._event_time_identifier_feature_name = "target-feature" + dataset_builder._feature_groups_to_be_merged = [FEATURE_GROUP] + query_string = dataset_builder._construct_query_string(BASE) + assert ( + query_string + == "WITH fg_base AS (WITH table_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature", origin_base."other-feature"\n' + + 'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n' + + ") AS dedup_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + ")\n" + + "WHERE dedup_row_base = 1\n" + + "),\n" + + "deleted_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature"\n' + + 'ORDER BY origin_base."other-feature" DESC, origin_base."api_invocation_time" ' + + 'DESC, origin_base."write_time" DESC\n' + + ") AS deleted_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_base = 1\n" + + ")\n" + + 'SELECT table_base."target-feature", table_base."other-feature"\n' + + "FROM (\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM table_base\n" + + "LEFT JOIN deleted_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + 'WHERE deleted_base."target-feature" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM deleted_base\n" + + "JOIN table_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + "AND (\n" + + 'table_base."other-feature" > deleted_base."other-feature"\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" > deleted_base."api_invocation_time")\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" = deleted_base."api_invocation_time" AND ' + + 
'table_base."write_time" > deleted_base."write_time")\n' + + ")\n" + + ") AS table_base\n" + + "),\n" + + "fg_0 AS (WITH table_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1", origin_0."feature-2"\n' + + 'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n' + + ") AS dedup_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + ")\n" + + "WHERE dedup_row_0 = 1\n" + + "),\n" + + "deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1"\n' + + 'ORDER BY origin_0."feature-2" DESC, origin_0."api_invocation_time" DESC, ' + + 'origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."feature-1", table_0."feature-2"\n' + + "FROM (\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM table_0\n" + + "LEFT JOIN deleted_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + 'WHERE deleted_0."feature-1" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM deleted_0\n" + + "JOIN table_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + "AND (\n" + + 'table_0."feature-2" > deleted_0."feature-2"\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND ' + + 'table_0."api_invocation_time" > deleted_0."api_invocation_time")\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND ' + + 'table_0."api_invocation_time" = deleted_0."api_invocation_time" AND ' + + 'table_0."write_time" > deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + 'SELECT target-feature, other-feature, "feature-1.1", "feature-2.1"\n' + + "FROM (\n" + + 'SELECT fg_base.target-feature, fg_base.other-feature, fg_0."feature-1" as ' + + '"feature-1.1", fg_0."feature-2" as "feature-2.1", row_number() OVER (\n' + + 'PARTITION BY fg_base."target-feature"\n' + + 'ORDER BY fg_base."other-feature" DESC, fg_0."feature-2" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."target-feature" = fg_0."feature-1"\n' + + 'AND from_unixtime(fg_base."target-feature") >= from_unixtime(fg_0."feature-2")\n' + + ")\n" + ) + + +def test_create_temp_table(sagemaker_session_mock): + dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "SUCCEEDED"}} + } + dataset_builder._create_temp_table("table-name", "s3-folder") + assert sagemaker_session_mock.start_query_execution.call_count == 1 + sagemaker_session_mock.start_query_execution.assert_called_once_with( + catalog="AwsDataCatalog", + database="sagemaker_featurestore", + query_string="CREATE EXTERNAL TABLE table-name (feature-1 INT, feature-2 INT) " + + "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' " + + 'WITH SERDEPROPERTIES ("separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\") ' + + "LOCATION 's3-folder';", + output_location="file/to/path", + kms_key=None, + ) + + +@pytest.mark.parametrize( + "column, expected", + [ + ("feature-1", "feature-1 
STRING"), + ("feature-2", "feature-2 INT"), + ("feature-3", "feature-3 DOUBLE"), + ("feature-4", "feature-4 BOOLEAN"), + ("feature-5", "feature-5 TIMESTAMP"), + ], +) +def test_construct_athena_table_column_string(column, expected, sagemaker_session_mock): + dataframe = pd.DataFrame( + { + "feature-1": ["420"], + "feature-2": [50], + "feature-3": [5.0], + "feature-4": [True], + "feature-5": [pd.Timestamp(1513393355)], + } + ) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + query_string = dataset_builder._construct_athena_table_column_string(column) + assert query_string == expected + + +def test_construct_athena_table_column_string_not_support_column_type( + sagemaker_session_mock, +): + dataframe = pd.DataFrame({"feature": pd.Series([1] * 3, dtype="int8")}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + with pytest.raises(RuntimeError) as error: + dataset_builder._construct_athena_table_column_string("feature") + assert "The dataframe type int8 is not supported yet." in str(error) + + +def test_run_query_throw_runtime_error(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "FAILED"}} + } + with pytest.raises(RuntimeError) as error: + dataset_builder._run_query("query-string", "catalog", "database") + assert "Failed to execute query query-id." in str(error) diff --git a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py new file mode 100644 index 0000000000..dce38fe426 --- /dev/null +++ b/tests/unit/sagemaker/feature_store/test_feature_group.py @@ -0,0 +1,580 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + + +import pandas as pd +import pytest +from mock import Mock, patch, MagicMock +from botocore.exceptions import ProfileNotFound + +from sagemaker.feature_store.feature_definition import ( + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + FeatureTypeEnum, +) +from sagemaker.feature_store.feature_group import ( + FeatureGroup, + IngestionManagerPandas, + AthenaQuery, + IngestionError, +) +from sagemaker.feature_store.inputs import FeatureParameter + + +class PicklableMock(Mock): + def __reduce__(self): + return (Mock, ()) + + +@pytest.fixture +def role_arn(): + return "arn:role" + + +@pytest.fixture +def s3_uri(): + return "s3://some/uri" + + +@pytest.fixture +def sagemaker_session_mock(): + return Mock() + + +@pytest.fixture +def fs_runtime_client_config_mock(): + return PicklableMock() + + +@pytest.fixture +def feature_group_dummy_definitions(): + return [ + FractionalFeatureDefinition(feature_name="feature1"), + IntegralFeatureDefinition(feature_name="feature2"), + StringFeatureDefinition(feature_name="feature3"), + ] + + +@pytest.fixture +def create_table_ddl(): + return ( + "CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n" + " feature1 FLOAT\n" + " feature2 INT\n" + " feature3 STRING\n" + " write_time TIMESTAMP\n" + " event_time TIMESTAMP\n" + " is_deleted BOOLEAN\n" + ")\n" + "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n" + " STORED AS\n" + " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n" + " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n" + "LOCATION 's3://resolved_output_s3_uri'" + ) + + +def test_feature_store_create( + sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri +): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + feature_group.create( + s3_uri=s3_uri, + record_identifier_name="feature1", + event_time_feature_name="feature2", + role_arn=role_arn, + enable_online_store=True, + ) + sagemaker_session_mock.create_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], + role_arn=role_arn, + description=None, + tags=None, + online_store_config={"EnableOnlineStore": True}, + offline_store_config={ + "DisableGlueTableCreation": False, + "S3StorageConfig": {"S3Uri": s3_uri}, + }, + ) + + +def test_feature_store_create_online_only( + sagemaker_session_mock, role_arn, feature_group_dummy_definitions +): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature2", + role_arn=role_arn, + enable_online_store=True, + ) + sagemaker_session_mock.create_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], + role_arn=role_arn, + description=None, + tags=None, + online_store_config={"EnableOnlineStore": True}, + ) + + +def test_feature_store_delete(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", 
sagemaker_session=sagemaker_session_mock) + feature_group.delete() + sagemaker_session_mock.delete_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup" + ) + + +def test_feature_store_describe(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.describe() + sagemaker_session_mock.describe_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", next_token=None + ) + + +def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.update(feature_group_dummy_definitions) + sagemaker_session_mock.update_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions], + ) + + +def test_feature_metadata_update(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + + parameter_additions = [FeatureParameter(key="key1", value="value1")] + parameter_removals = ["key2"] + + feature_group.update_feature_metadata( + feature_name="Feature1", + description="TestDescription", + parameter_additions=parameter_additions, + parameter_removals=parameter_removals, + ) + sagemaker_session_mock.update_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_name="Feature1", + description="TestDescription", + parameter_additions=[pa.to_dict() for pa in parameter_additions], + parameter_removals=parameter_removals, + ) + feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription") + sagemaker_session_mock.update_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_name="Feature1", + description="TestDescription", + parameter_additions=[], + parameter_removals=[], + ) + + +def test_feature_metadata_describe(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.describe_feature_metadata(feature_name="Feature1") + sagemaker_session_mock.describe_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", feature_name="Feature1" + ) + + +def test_get_record(sagemaker_session_mock): + feature_group_name = "MyFeatureGroup" + feature_names = ["MyFeature1", "MyFeature2"] + record_identifier_value_as_string = "1.0" + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session_mock) + feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=feature_names, + ) + sagemaker_session_mock.get_record.assert_called_with( + feature_group_name=feature_group_name, + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=feature_names, + ) + + +def test_put_record(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.put_record(record=[]) + sagemaker_session_mock.put_record.assert_called_with( + feature_group_name="MyFeatureGroup", record=[] + ) + + +def test_delete_record(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + record_identifier_value_as_string = "1.0" + event_time = "2022-09-14" + feature_group.delete_record( + 
record_identifier_value_as_string=record_identifier_value_as_string,
+        event_time=event_time,
+    )
+    sagemaker_session_mock.delete_record.assert_called_with(
+        feature_group_name="MyFeatureGroup",
+        record_identifier_value_as_string=record_identifier_value_as_string,
+        event_time=event_time,
+    )
+
+
+def test_load_feature_definition(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock)
+    df = pd.DataFrame(
+        {
+            "float": pd.Series([2.0], dtype="float64"),
+            "int": pd.Series([2], dtype="int64"),
+            "string": pd.Series(["f1"], dtype="string"),
+        }
+    )
+    feature_definitions = feature_group.load_feature_definitions(data_frame=df)
+    names = [fd.feature_name for fd in feature_definitions]
+    types = [fd.feature_type for fd in feature_definitions]
+    assert names == ["float", "int", "string"]
+    assert types == [
+        FeatureTypeEnum.FRACTIONAL,
+        FeatureTypeEnum.INTEGRAL,
+        FeatureTypeEnum.STRING,
+    ]
+
+
+def test_load_feature_definition_unsupported_types(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock)
+    df = pd.DataFrame(
+        {
+            "float": pd.Series([2.0], dtype="float64"),
+            "int": pd.Series([2], dtype="int64"),
+            "bool": pd.Series([True], dtype="bool"),
+        }
+    )
+    with pytest.raises(ValueError) as error:
+        feature_group.load_feature_definitions(data_frame=df)
+    assert "Failed to infer Feature type based on dtype bool for column bool." in str(error)
+
+
+def test_ingest_zero_processes(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+    df = Mock()
+    with pytest.raises(RuntimeError) as error:
+        feature_group.ingest(data_frame=df, max_workers=1, max_processes=0)
+
+    assert "max_processes must be greater than 0." in str(error)
+
+
+def test_ingest_zero_workers(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+    df = Mock()
+    with pytest.raises(RuntimeError) as error:
+        feature_group.ingest(data_frame=df, max_workers=0, max_processes=1)
+
+    assert "max_workers must be greater than 0." 
in str(error) + + +@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") +def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock): + sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( + fs_runtime_client_config_mock + ) + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) + + mock_ingestion_manager_instance = Mock() + ingestion_manager_init.return_value = mock_ingestion_manager_instance + feature_group.ingest(data_frame=df, max_workers=10) + + ingestion_manager_init.assert_called_once_with( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + max_processes=1, + profile_name=None, + ) + mock_ingestion_manager_instance.run.assert_called_once_with( + data_frame=df, wait=True, timeout=None + ) + + +@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") +def test_ingest_with_profile_name( + ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock +): + sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( + fs_runtime_client_config_mock + ) + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) + + mock_ingestion_manager_instance = Mock() + ingestion_manager_init.return_value = mock_ingestion_manager_instance + feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name") + + ingestion_manager_init.assert_called_once_with( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + max_processes=1, + profile_name="profile_name", + ) + mock_ingestion_manager_instance.run.assert_called_once_with( + data_frame=df, wait=True, timeout=None + ) + + +def test_as_hive_ddl_with_default_values( + create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock +): + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://some-bucket", + "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", + } + } + } + sagemaker_session_mock.account_id.return_value = "1234" + sagemaker_session_mock.boto_session.region_name = "us-west-2" + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + assert ( + create_table_ddl.format( + database="sagemaker_featurestore", + table_name="MyGroup", + account="1234", + region="us-west-2", + feature_group_name="MyGroup", + ) + == feature_group.as_hive_ddl() + ) + + +def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock): + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://some-bucket", + "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", + } + } + } + sagemaker_session_mock.account_id.return_value = "1234" + sagemaker_session_mock.boto_session.region_name = "us-west-2" + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + assert create_table_ddl.format( + database="MyDatabase", + table_name="MyTable", + account="1234", + 
region="us-west-2", + feature_group_name="MyGroup", + ) == feature_group.as_hive_ddl(database="MyDatabase", table_name="MyTable") + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_process", + MagicMock(), +) +def test_ingestion_manager_run_success(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + ) + manager.run(df) + + manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None) + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_threaded", + PicklableMock(return_value=[]), +) +def test_ingestion_manager_run_multi_process_with_multi_thread_success( + fs_runtime_client_config_mock, +): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=2, + max_processes=2, + ) + manager.run(df) + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", + MagicMock(return_value=[1]), +) +def test_ingestion_manager_run_failure(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=1, + ) + + with pytest.raises(IngestionError) as error: + manager.run(df) + + assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) + assert error.value.failed_rows == [1] + assert manager.failed_rows == [1] + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", + MagicMock(side_effect=ProfileNotFound(profile="non_exist")), +) +def test_ingestion_manager_with_profile_name_run_failure(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=1, + profile_name="non_exist", + ) + + try: + manager.run(df) + except Exception as e: + assert "The config profile (non_exist) could not be found" in str(e) + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", + PicklableMock(return_value=[1]), +) +def test_ingestion_manager_run_multi_process_failure(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=None, + max_workers=2, + max_processes=2, + ) + + with pytest.raises(IngestionError) as error: + manager.run(df) + + assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) + assert error.value.failed_rows == [1, 1, 1, 1] + assert manager.failed_rows == [1, 1, 1, 1] + + +@pytest.fixture +def query(sagemaker_session_mock): + return AthenaQuery( + catalog="catalog", + database="database", + table_name="table_name", + sagemaker_session=sagemaker_session_mock, + ) + + +def test_athena_query_run(sagemaker_session_mock, query): + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"} + query.run( + query_string="query", output_location="s3://some-bucket/some-path", workgroup="workgroup" + ) + sagemaker_session_mock.start_query_execution.assert_called_with( + 
catalog="catalog", + database="database", + query_string="query", + output_location="s3://some-bucket/some-path", + kms_key=None, + workgroup="workgroup", + ) + assert "some-bucket" == query._result_bucket + assert "some-path" == query._result_file_prefix + assert "query_id" == query._current_query_execution_id + + +def test_athena_query_wait(sagemaker_session_mock, query): + query._current_query_execution_id = "query_id" + query.wait() + sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id") + + +def test_athena_query_get_query_execution(sagemaker_session_mock, query): + query._current_query_execution_id = "query_id" + query.get_query_execution() + sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id") + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +@patch("pandas.read_csv") +def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "SUCCEEDED"}} + } + query._current_query_execution_id = "query_id" + query._result_bucket = "bucket" + query._result_file_prefix = "prefix" + query.as_dataframe() + sagemaker_session_mock.download_athena_query_result.assert_called_with( + bucket="bucket", + prefix="prefix", + query_execution_id="query_id", + filename="tmp/query_id.csv", + ) + read_csv.assert_called_with("tmp/query_id.csv", delimiter=",") + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_failed(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "FAILED"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Failed to execute query query_id" in str(error) + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_queued(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "QUEUED"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Current query query_id is still being executed" in str(error) + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_running(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "RUNNING"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Current query query_id is still being executed" in str(error) diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py index 92ba35573c..073daca9ea 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_store.py +++ b/tests/unit/sagemaker/feature_store/test_feature_store.py @@ -10,46 +10,17 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -# language governing permissions and limitations under the License. 
from __future__ import absolute_import +import datetime import pandas as pd import pytest -from mock import Mock, patch, MagicMock -from botocore.exceptions import ProfileNotFound - -from sagemaker.feature_store.feature_definition import ( - FractionalFeatureDefinition, - IntegralFeatureDefinition, - StringFeatureDefinition, - FeatureTypeEnum, -) -from sagemaker.feature_store.feature_group import ( - FeatureGroup, - IngestionManagerPandas, - AthenaQuery, - IngestionError, -) -from sagemaker.feature_store.inputs import ( - FeatureParameter, - TableFormatEnum, -) - +from mock import Mock -class PicklableMock(Mock): - def __reduce__(self): - return (Mock, ()) +from sagemaker.feature_store.feature_store import FeatureStore - -@pytest.fixture -def role_arn(): - return "arn:role" - - -@pytest.fixture -def s3_uri(): - return "s3://some/uri" +DATAFRAME = pd.DataFrame({"feature_1": [420, 380, 390], "feature_2": [50, 40, 45]}) @pytest.fixture @@ -58,558 +29,108 @@ def sagemaker_session_mock(): @pytest.fixture -def fs_runtime_client_config_mock(): - return PicklableMock() - - -@pytest.fixture -def feature_group_dummy_definitions(): - return [ - FractionalFeatureDefinition(feature_name="feature1"), - IntegralFeatureDefinition(feature_name="feature2"), - StringFeatureDefinition(feature_name="feature3"), - ] - - -@pytest.fixture -def create_table_ddl(): - return ( - "CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n" - " feature1 FLOAT\n" - " feature2 INT\n" - " feature3 STRING\n" - " write_time TIMESTAMP\n" - " event_time TIMESTAMP\n" - " is_deleted BOOLEAN\n" - ")\n" - "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n" - " STORED AS\n" - " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n" - " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n" - "LOCATION 's3://resolved_output_s3_uri'" - ) - - -def test_feature_store_create( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_iceberg_table_format( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - disable_glue_table_creation=False, - table_format=TableFormatEnum.ICEBERG, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - 
feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "TableFormat": "Iceberg", - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_glue_table_format( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - disable_glue_table_creation=False, - table_format=TableFormatEnum.GLUE, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "TableFormat": "Glue", - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_online_only( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=False, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - ) - - -def test_feature_store_delete(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.delete() - sagemaker_session_mock.delete_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup" - ) - - -def test_feature_store_describe(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.describe() - sagemaker_session_mock.describe_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", next_token=None - ) - - -def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.update(feature_group_dummy_definitions) - sagemaker_session_mock.update_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions], - ) - - -def test_feature_metadata_update(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - - parameter_additions = [FeatureParameter(key="key1", value="value1")] - parameter_removals = ["key2"] - - feature_group.update_feature_metadata( - 
feature_name="Feature1", - description="TestDescription", - parameter_additions=parameter_additions, - parameter_removals=parameter_removals, - ) - sagemaker_session_mock.update_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_name="Feature1", - description="TestDescription", - parameter_additions=[pa.to_dict() for pa in parameter_additions], - parameter_removals=parameter_removals, - ) - feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription") - sagemaker_session_mock.update_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_name="Feature1", - description="TestDescription", - parameter_additions=[], - parameter_removals=[], - ) - - -def test_feature_metadata_describe(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.describe_feature_metadata(feature_name="Feature1") - sagemaker_session_mock.describe_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", feature_name="Feature1" - ) - - -def test_put_record(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.put_record(record=[]) - sagemaker_session_mock.put_record.assert_called_with( - feature_group_name="MyFeatureGroup", record=[] - ) - - -def test_load_feature_definition(sagemaker_session_mock): - feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame( - { - "float": pd.Series([2.0], dtype="float64"), - "int": pd.Series([2], dtype="int64"), - "string": pd.Series(["f1"], dtype="string"), - } - ) - feature_definitions = feature_group.load_feature_definitions(data_frame=df) - names = [fd.feature_name for fd in feature_definitions] - types = [fd.feature_type for fd in feature_definitions] - assert names == ["float", "int", "string"] - assert types == [ - FeatureTypeEnum.FRACTIONAL, - FeatureTypeEnum.INTEGRAL, - FeatureTypeEnum.STRING, - ] +def feature_group_mock(): + return Mock() -def test_load_feature_definition_unsupported_types(sagemaker_session_mock): - feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame( - { - "float": pd.Series([2.0], dtype="float64"), - "int": pd.Series([2], dtype="int64"), - "object": pd.Series(["f1"], dtype="object"), - } - ) +def test_minimal_create_dataset(sagemaker_session_mock, feature_group_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + dataset_builder = feature_store.create_dataset( + base=feature_group_mock, + output_path="file/to/path", + ) + assert dataset_builder._sagemaker_session == sagemaker_session_mock + assert dataset_builder._base == feature_group_mock + assert dataset_builder._output_path == "file/to/path" + + +def test_complete_create_dataset(sagemaker_session_mock, feature_group_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + dataset_builder = feature_store.create_dataset( + base=feature_group_mock, + included_feature_names=["feature_1", "feature_2"], + output_path="file/to/path", + kms_key_id="kms-key-id", + ) + assert dataset_builder._sagemaker_session == sagemaker_session_mock + assert dataset_builder._base == feature_group_mock + assert dataset_builder._included_feature_names == ["feature_1", "feature_2"] + assert dataset_builder._output_path == "file/to/path" + assert dataset_builder._kms_key_id == "kms-key-id" + + 
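The tests above pin down the surface of the new `FeatureStore.create_dataset` builder. For orientation, a minimal sketch of the intended call pattern, using only the arguments exercised by these tests (the session setup, feature group name, and S3 output location are assumptions, not taken from this patch):

    from sagemaker import Session
    from sagemaker.feature_store.feature_store import FeatureStore
    from sagemaker.feature_store.feature_group import FeatureGroup

    session = Session()  # assumes AWS credentials and a region are configured
    feature_store = FeatureStore(sagemaker_session=session)
    base_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=session)

    # Build a dataset from an existing feature group, keeping two features and
    # encrypting the output with a customer-managed KMS key.
    builder = feature_store.create_dataset(
        base=base_group,
        included_feature_names=["feature_1", "feature_2"],
        output_path="s3://my-bucket/my-dataset",  # hypothetical location
        kms_key_id="kms-key-id",
    )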
+def test_create_dataset_with_dataframe(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + dataset_builder = feature_store.create_dataset( + base=DATAFRAME, + record_identifier_feature_name="feature_1", + event_time_identifier_feature_name="feature_2", + included_feature_names=["feature_1", "feature_2"], + output_path="file/to/path", + kms_key_id="kms-key-id", + ) + assert dataset_builder._sagemaker_session == sagemaker_session_mock + assert dataset_builder._base.equals(DATAFRAME) + assert dataset_builder._record_identifier_feature_name == "feature_1" + assert dataset_builder._event_time_identifier_feature_name == "feature_2" + assert dataset_builder._included_feature_names == ["feature_1", "feature_2"] + assert dataset_builder._output_path == "file/to/path" + assert dataset_builder._kms_key_id == "kms-key-id" + + +def test_create_dataset_with_dataframe_value_error(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) with pytest.raises(ValueError) as error: - feature_group.load_feature_definitions(data_frame=df) - assert "Failed to infer Feature type based on dtype object for column object." in str(error) - - -def test_ingest_zero_processes(): - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = Mock() - with pytest.raises(RuntimeError) as error: - feature_group.ingest(data_frame=df, max_workers=1, max_processes=0) - - assert "max_processes must be greater than 0." in str(error) - - -def test_ingest_zero_workers(): - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = Mock() - with pytest.raises(RuntimeError) as error: - feature_group.ingest(data_frame=df, max_workers=0, max_processes=1) - - assert "max_workers must be greater than 0." 
in str(error) - - -@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") -def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock): - sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( - fs_runtime_client_config_mock - ) - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) - - mock_ingestion_manager_instance = Mock() - ingestion_manager_init.return_value = mock_ingestion_manager_instance - feature_group.ingest(data_frame=df, max_workers=10) - - ingestion_manager_init.assert_called_once_with( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=10, - max_processes=1, - profile_name=None, - ) - mock_ingestion_manager_instance.run.assert_called_once_with( - data_frame=df, wait=True, timeout=None - ) - - -@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") -def test_ingest_with_profile_name( - ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock -): - sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( - fs_runtime_client_config_mock - ) - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) - - mock_ingestion_manager_instance = Mock() - ingestion_manager_init.return_value = mock_ingestion_manager_instance - feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name") - - ingestion_manager_init.assert_called_once_with( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=10, - max_processes=1, - profile_name="profile_name", - ) - mock_ingestion_manager_instance.run.assert_called_once_with( - data_frame=df, wait=True, timeout=None - ) - - -def test_as_hive_ddl_with_default_values( - create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock -): - sagemaker_session_mock.describe_feature_group.return_value = { - "OfflineStoreConfig": { - "S3StorageConfig": { - "S3Uri": "s3://some-bucket", - "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", - } - } - } - sagemaker_session_mock.account_id.return_value = "1234" - sagemaker_session_mock.boto_session.region_name = "us-west-2" - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - assert ( - create_table_ddl.format( - database="sagemaker_featurestore", - table_name="MyGroup", - account="1234", - region="us-west-2", - feature_group_name="MyGroup", + feature_store.create_dataset( + base=DATAFRAME, + included_feature_names=["feature_1", "feature_2"], + output_path="file/to/path", + kms_key_id="kms-key-id", ) - == feature_group.as_hive_ddl() - ) - - -def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock): - sagemaker_session_mock.describe_feature_group.return_value = { - "OfflineStoreConfig": { - "S3StorageConfig": { - "S3Uri": "s3://some-bucket", - "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", - } - } - } - sagemaker_session_mock.account_id.return_value = "1234" - sagemaker_session_mock.boto_session.region_name = "us-west-2" - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - 
feature_group.feature_definitions = feature_group_dummy_definitions - assert create_table_ddl.format( - database="MyDatabase", - table_name="MyTable", - account="1234", - region="us-west-2", - feature_group_name="MyGroup", - ) == feature_group.as_hive_ddl(database="MyDatabase", table_name="MyTable") - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_process", - MagicMock(), -) -def test_ingestion_manager_run_success(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=10, - ) - manager.run(df) - - manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None) - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_threaded", - PicklableMock(return_value=[]), -) -def test_ingestion_manager_run_multi_process_with_multi_thread_success( - fs_runtime_client_config_mock, -): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=2, - max_processes=2, - ) - manager.run(df) - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", - MagicMock(return_value=[1]), -) -def test_ingestion_manager_run_failure(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=1, - ) - - with pytest.raises(IngestionError) as error: - manager.run(df) - - assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) - assert error.value.failed_rows == [1] - assert manager.failed_rows == [1] - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", - MagicMock(side_effect=ProfileNotFound(profile="non_exist")), -) -def test_ingestion_manager_with_profile_name_run_failure(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=1, - profile_name="non_exist", - ) - - try: - manager.run(df) - except Exception as e: - assert "The config profile (non_exist) could not be found" in str(e) - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", - PicklableMock(return_value=[1]), -) -def test_ingestion_manager_run_multi_process_failure(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=None, - max_workers=2, - max_processes=2, - ) - - with pytest.raises(IngestionError) as error: - manager.run(df) - - assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) - assert error.value.failed_rows == [1, 1, 1, 1] - assert manager.failed_rows == [1, 1, 1, 1] - - -@pytest.fixture -def query(sagemaker_session_mock): - return AthenaQuery( - catalog="catalog", - database="database", - table_name="table_name", - sagemaker_session=sagemaker_session_mock, - ) - - -def test_athena_query_run(sagemaker_session_mock, query): - WORKGROUP = "workgroup" - sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": 
"query_id"} - query.run( - query_string="query", output_location="s3://some-bucket/some-path", workgroup=WORKGROUP - ) - sagemaker_session_mock.start_query_execution.assert_called_with( - catalog="catalog", - database="database", - query_string="query", - output_location="s3://some-bucket/some-path", - kms_key=None, - workgroup=WORKGROUP, - ) - assert "some-bucket" == query._result_bucket - assert "some-path" == query._result_file_prefix - assert "query_id" == query._current_query_execution_id - - -def test_athena_query_wait(sagemaker_session_mock, query): - query._current_query_execution_id = "query_id" - query.wait() - sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id") - - -def test_athena_query_get_query_execution(sagemaker_session_mock, query): - query._current_query_execution_id = "query_id" - query.get_query_execution() - sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id") - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -@patch("pandas.read_csv") -def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "SUCCEEDED"}} - } - query._current_query_execution_id = "query_id" - query._result_bucket = "bucket" - query._result_file_prefix = "prefix" - query.as_dataframe() - sagemaker_session_mock.download_athena_query_result.assert_called_with( - bucket="bucket", - prefix="prefix", - query_execution_id="query_id", - filename="tmp/query_id.csv", + assert ( + "You must provide a record identifier feature name and an event time identifier feature " + + "name if specify DataFrame as base." + in str(error) + ) + + +def test_list_feature_groups_with_no_filter(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + feature_store.list_feature_groups() + sagemaker_session_mock.list_feature_groups.assert_called_with( + name_contains=None, + feature_group_status_equals=None, + offline_store_status_equals=None, + creation_time_after=None, + creation_time_before=None, + sort_order=None, + sort_by=None, + max_results=None, + next_token=None, + ) + + +def test_list_feature_groups_with_all_filters(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + feature_store.list_feature_groups( + name_contains="MyFeatureGroup", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after=datetime.datetime(2020, 12, 1), + creation_time_before=datetime.datetime(2022, 7, 1), + sort_order="Ascending", + sort_by="Name", + max_results=50, + next_token="token", + ) + sagemaker_session_mock.list_feature_groups.assert_called_with( + name_contains="MyFeatureGroup", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after=datetime.datetime(2020, 12, 1), + creation_time_before=datetime.datetime(2022, 7, 1), + sort_order="Ascending", + sort_by="Name", + max_results=50, + next_token="token", ) - read_csv.assert_called_with("tmp/query_id.csv", delimiter=",") - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_failed(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "FAILED"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert 
"Failed to execute query query_id" in str(error) - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_queued(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "QUEUED"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert "Current query query_id is still being executed" in str(error) - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_running(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "RUNNING"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert "Current query query_id is still being executed" in str(error) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index bf81283177..d7c94470f5 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2787,6 +2787,35 @@ def test_feature_metadata_describe(sagemaker_session): ) +def test_list_feature_groups(sagemaker_session): + expected_list_feature_groups_args = { + "NameContains": "MyFeatureGroup", + "FeatureGroupStatusEquals": "Created", + "OfflineStoreStatusEquals": "Active", + "CreationTimeAfter": datetime.datetime(2020, 12, 1), + "CreationTimeBefore": datetime.datetime(2022, 7, 1), + "SortOrder": "Ascending", + "SortBy": "Name", + "MaxResults": 50, + "NextToken": "token", + } + sagemaker_session.list_feature_groups( + name_contains="MyFeatureGroup", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after=datetime.datetime(2020, 12, 1), + creation_time_before=datetime.datetime(2022, 7, 1), + sort_order="Ascending", + sort_by="Name", + max_results=50, + next_token="token", + ) + assert sagemaker_session.sagemaker_client.list_feature_groups.called_once() + assert sagemaker_session.sagemaker_client.list_feature_groups.called_with( + **expected_list_feature_groups_args + ) + + def test_start_query_execution(sagemaker_session): athena_mock = Mock() sagemaker_session.boto_session.client( From fb3880f804854d8456682c4aa17de321cb5a89f9 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 14 Dec 2022 03:40:14 +0000 Subject: [PATCH 39/58] prepare release v2.122.0 --- CHANGELOG.md | 13 +++++++++++++ VERSION | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b66e85f54..de20a8a0df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## v2.122.0 (2022-12-14) + +### Features + + * Feature Store dataset builder, delete_record, get_record, list_feature_group + * Add OSU region to frameworks for DLC + +### Bug Fixes and Other Changes + + * the Hyperband support fix for the HPO + * unpin packaging version + * Remove content type image/jpg from analysis configuration schema + ## v2.121.2 (2022-12-12) ### Bug Fixes and Other Changes diff --git a/VERSION b/VERSION index 8fde5e282f..202f672bab 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.3.dev0 +2.122.0 From a584ea5ff73ea5b6df8eec749069ec86adf2e8fc Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 14 Dec 2022 03:40:15 +0000 Subject: [PATCH 40/58] update development version to v2.122.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 202f672bab..6d7f044fa2 100644 --- a/VERSION +++ 
b/VERSION @@ -1 +1 @@ -2.122.0 +2.122.1.dev0 From 93a846670f57f444b590551b2d67a3c6a95302aa Mon Sep 17 00:00:00 2001 From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> Date: Wed, 14 Dec 2022 09:09:46 -0800 Subject: [PATCH 41/58] feature: Add SageMaker Experiment (#3536) * feature: Add experiment plus Run class (#691) * feature: Add Experiment helper classes (#646) * feature: Add Experiment helper classes feature: Add helper class _RunEnvironment * change: Change sleep retry to backoff retry for get TC * minor fixes in backoff retry Co-authored-by: Dewen Qi * feature: Add helper classes and methods for Run class (#660) * feature: Add helper classes and methods for Run class * Add Parent class to address comment * fix docstyle check * Add arg docstrings in _helper Co-authored-by: Dewen Qi * feature: Add Experiment Run class (#651) Co-authored-by: Dewen Qi * change: Add integ tests for Run (#673) Co-authored-by: Dewen Qi * Update run log metric to use MetricsManager (#678) * Update run.log_metric to use _MetricsManager * fix several metrics issues * Add doc strings to metrics.py Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: Dewen Qi Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> * change: Simplify exp plus integ test configuration (#694) Co-authored-by: Dewen Qi * feature: add RunName to expeirment_config (#696) * change: Update Run init and add Run load and _RunContext (#707) * change: Update Run init and add Run load Add exp name and run group name to load and address comments * Address nit comments Co-authored-by: Dewen Qi * fix: Fix run name uniqueness issue (#730) Co-authored-by: Dewen Qi * change: Update integ tests for Exp Plus M1 changes (#741) Co-authored-by: Dewen Qi * add metrics client to session object (#745) Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> * change: Add integ test for using Run in Transform Job (#749) Co-authored-by: Dewen Qi * Add async metrics sink (#739) Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> * use metrics client provided by session (#754) * fix flaky metrics test (#753) * change: Change Run.init and Run.load to constructor and module method respectively (#752) Co-authored-by: Dewen Qi * feature: Add latest metric service model (#757) Co-authored-by: Dewen Qi Co-authored-by: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com> * fix: lowercase run name (#767) * Change: Minimize use of lower case tc name (#769) * change: Clean up test resources to remove model files (#756) * change: Clean up test resources to remove model files * fix: Change experiment enums to upper case * change: Upgrade boto3 and update test to validate mixed case name * fix: Update as per latest botocore release and backend change Co-authored-by: Dewen Qi * lowercase trial component name (#776) * change: Expose sagemaker experiment doc strings * fix: Fix exp name mixed case in issue Co-authored-by: Dewen Qi Co-authored-by: Dana Benson Co-authored-by: Dana Benson <31262102+danabens@users.noreply.github.com> Co-authored-by: Yifei Zhu <66866419+yzhu0@users.noreply.github.com> --- .gitignore | 5 +- 
doc/experiments/index.rst | 10 + doc/experiments/sagemaker.experiments.rst | 20 + doc/index.rst | 10 + requirements/extras/test_requirements.txt | 1 + setup.py | 2 +- src/sagemaker/amazon/amazon_estimator.py | 7 +- src/sagemaker/apiutils/_base_types.py | 6 +- src/sagemaker/apiutils/_boto_functions.py | 4 +- src/sagemaker/dataset_definition/inputs.py | 6 +- src/sagemaker/estimator.py | 16 +- src/sagemaker/experiments/__init__.py | 20 + src/sagemaker/experiments/_api_types.py | 251 +++++ src/sagemaker/experiments/_environment.py | 132 +++ src/sagemaker/experiments/_helper.py | 266 +++++ src/sagemaker/experiments/_metrics.py | 413 ++++++++ src/sagemaker/experiments/_run_context.py | 58 ++ src/sagemaker/experiments/_utils.py | 218 ++++ src/sagemaker/experiments/experiment.py | 237 +++++ src/sagemaker/experiments/run.py | 882 ++++++++++++++++ src/sagemaker/experiments/trial.py | 289 ++++++ src/sagemaker/experiments/trial_component.py | 341 +++++++ src/sagemaker/lineage/_utils.py | 17 - src/sagemaker/lineage/artifact.py | 3 +- src/sagemaker/processing.py | 9 +- src/sagemaker/session.py | 23 +- src/sagemaker/transformer.py | 7 +- src/sagemaker/utilities/search_expression.py | 133 +++ src/sagemaker/utils.py | 66 ++ tests/data/experiment/inference.py | 85 ++ .../process_job_script_for_run_clz.py | 37 + .../train_job_script_for_run_clz.py | 71 ++ .../transform_job_materials/data.csv | 1 + .../transform_job_materials/xgb_model.tar.gz | Bin 0 -> 35946 bytes tests/integ/sagemaker/experiments/__init__.py | 0 tests/integ/sagemaker/experiments/conftest.py | 177 ++++ tests/integ/sagemaker/experiments/helpers.py | 42 + .../sagemaker/experiments/test_experiment.py | 56 ++ .../sagemaker/experiments/test_metrics.py | 39 + tests/integ/sagemaker/experiments/test_run.py | 662 ++++++++++++ .../integ/sagemaker/experiments/test_trial.py | 75 ++ .../experiments/test_trial_component.py | 144 +++ tests/integ/sagemaker/lineage/conftest.py | 5 +- tests/integ/sagemaker/lineage/helpers.py | 14 - .../integ/sagemaker/lineage/test_artifact.py | 4 +- tests/integ/sagemaker/utilities/__init__.py | 0 .../utilities/test_search_expression.py | 67 ++ tests/integ/test_marketplace.py | 4 +- tests/integ/test_multidatamodel.py | 21 +- tests/integ/utils.py | 20 + tests/unit/conftest.py | 66 ++ tests/unit/sagemaker/experiments/__init__.py | 0 tests/unit/sagemaker/experiments/conftest.py | 86 ++ tests/unit/sagemaker/experiments/helpers.py | 44 + .../sagemaker/experiments/test_environment.py | 107 ++ .../sagemaker/experiments/test_experiment.py | 306 ++++++ .../unit/sagemaker/experiments/test_helper.py | 195 ++++ .../sagemaker/experiments/test_metrics.py | 178 ++++ tests/unit/sagemaker/experiments/test_run.py | 941 ++++++++++++++++++ .../sagemaker/experiments/test_run_context.py | 191 ++++ .../unit/sagemaker/experiments/test_trial.py | 276 +++++ .../experiments/test_trial_component.py | 384 +++++++ .../unit/sagemaker/experiments/test_utils.py | 36 + .../sagemaker/huggingface/test_estimator.py | 1 + .../sagemaker/tensorflow/test_estimator.py | 1 + .../test_huggingface_pytorch_compiler.py | 1 + .../test_huggingface_tensorflow_compiler.py | 1 + .../test_tensorflow_compiler.py | 1 + .../utilities/test_search_expression.py | 80 ++ .../workflow/test_clarify_check_step.py | 44 - .../unit/sagemaker/workflow/test_entities.py | 43 - .../workflow/test_quality_check_step.py | 46 - tests/unit/sagemaker/workflow/test_steps.py | 47 +- tests/unit/test_amazon_estimator.py | 13 +- tests/unit/test_estimator.py | 9 +- tests/unit/test_mxnet.py | 1 + 
tests/unit/test_pytorch.py | 1 + tests/unit/test_rl.py | 1 + tests/unit/test_session.py | 15 + tests/unit/test_sklearn.py | 1 + tests/unit/test_utils.py | 64 +- tests/unit/test_xgboost.py | 1 + 82 files changed, 7894 insertions(+), 263 deletions(-) create mode 100644 doc/experiments/index.rst create mode 100644 doc/experiments/sagemaker.experiments.rst create mode 100644 src/sagemaker/experiments/__init__.py create mode 100644 src/sagemaker/experiments/_api_types.py create mode 100644 src/sagemaker/experiments/_environment.py create mode 100644 src/sagemaker/experiments/_helper.py create mode 100644 src/sagemaker/experiments/_metrics.py create mode 100644 src/sagemaker/experiments/_run_context.py create mode 100644 src/sagemaker/experiments/_utils.py create mode 100644 src/sagemaker/experiments/experiment.py create mode 100644 src/sagemaker/experiments/run.py create mode 100644 src/sagemaker/experiments/trial.py create mode 100644 src/sagemaker/experiments/trial_component.py create mode 100644 src/sagemaker/utilities/search_expression.py create mode 100644 tests/data/experiment/inference.py create mode 100644 tests/data/experiment/process_job_script_for_run_clz.py create mode 100644 tests/data/experiment/train_job_script_for_run_clz.py create mode 100644 tests/data/experiment/transform_job_materials/data.csv create mode 100644 tests/data/experiment/transform_job_materials/xgb_model.tar.gz create mode 100644 tests/integ/sagemaker/experiments/__init__.py create mode 100644 tests/integ/sagemaker/experiments/conftest.py create mode 100644 tests/integ/sagemaker/experiments/helpers.py create mode 100644 tests/integ/sagemaker/experiments/test_experiment.py create mode 100644 tests/integ/sagemaker/experiments/test_metrics.py create mode 100644 tests/integ/sagemaker/experiments/test_run.py create mode 100644 tests/integ/sagemaker/experiments/test_trial.py create mode 100644 tests/integ/sagemaker/experiments/test_trial_component.py create mode 100644 tests/integ/sagemaker/utilities/__init__.py create mode 100644 tests/integ/sagemaker/utilities/test_search_expression.py create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/sagemaker/experiments/__init__.py create mode 100644 tests/unit/sagemaker/experiments/conftest.py create mode 100644 tests/unit/sagemaker/experiments/helpers.py create mode 100644 tests/unit/sagemaker/experiments/test_environment.py create mode 100644 tests/unit/sagemaker/experiments/test_experiment.py create mode 100644 tests/unit/sagemaker/experiments/test_helper.py create mode 100644 tests/unit/sagemaker/experiments/test_metrics.py create mode 100644 tests/unit/sagemaker/experiments/test_run.py create mode 100644 tests/unit/sagemaker/experiments/test_run_context.py create mode 100644 tests/unit/sagemaker/experiments/test_trial.py create mode 100644 tests/unit/sagemaker/experiments/test_trial_component.py create mode 100644 tests/unit/sagemaker/experiments/test_utils.py create mode 100644 tests/unit/sagemaker/utilities/test_search_expression.py diff --git a/.gitignore b/.gitignore index 9829ed9781..cae8f890ea 100644 --- a/.gitignore +++ b/.gitignore @@ -30,5 +30,6 @@ env/ .vscode/ **/tmp .python-version -**/_repack_model.py -**/_repack_script_launcher.sh \ No newline at end of file +**/_repack_script_launcher.sh +tests/data/**/_repack_model.py +tests/data/experiment/sagemaker-dev-1.0.tar.gz diff --git a/doc/experiments/index.rst b/doc/experiments/index.rst new file mode 100644 index 0000000000..8c12f30edc --- /dev/null +++ b/doc/experiments/index.rst @@ -0,0 
+1,10 @@
+############################
+Amazon SageMaker Experiments
+############################
+
+The SageMaker Python SDK lets you track and organize your machine learning workflows across SageMaker jobs, such as Processing, Training and Transform, or locally.
+
+.. toctree::
+    :maxdepth: 2
+
+    sagemaker.experiments
diff --git a/doc/experiments/sagemaker.experiments.rst b/doc/experiments/sagemaker.experiments.rst
new file mode 100644
index 0000000000..f0776ec43b
--- /dev/null
+++ b/doc/experiments/sagemaker.experiments.rst
@@ -0,0 +1,20 @@
+Experiments
+============
+
+Run
+-------------
+
+.. autoclass:: sagemaker.experiments.Run
+    :members:
+
+.. automethod:: sagemaker.experiments.load_run
+
+.. automethod:: sagemaker.experiments.list_runs
+
+.. autoclass:: sagemaker.experiments.SortByType
+    :members:
+    :undoc-members:
+
+.. autoclass:: sagemaker.experiments.SortOrderType
+    :members:
+    :undoc-members:
diff --git a/doc/index.rst b/doc/index.rst
index 2d4ebe32c1..69038056b0 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -60,6 +60,16 @@ Orchestrate your SageMaker training and inference workflows with Airflow and Kub
    workflows/index

+****************************
+Amazon SageMaker Experiments
+****************************
+You can use Amazon SageMaker Experiments to track machine learning experiments.
+
+.. toctree::
+    :maxdepth: 2
+
+    experiments/index
+
 *************************
 Amazon SageMaker Debugger
 *************************
diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt
index fe93fd4d0e..494b6dca11 100644
--- a/requirements/extras/test_requirements.txt
+++ b/requirements/extras/test_requirements.txt
@@ -20,3 +20,4 @@ requests==2.27.1
 sagemaker-experiments==0.1.35
 Jinja2==3.0.3
 pandas>=1.3.5,<1.5
+scikit-learn==1.0.2
diff --git a/setup.py b/setup.py
index 4327045760..e2adb6b433 100644
--- a/setup.py
+++ b/setup.py
@@ -48,7 +48,7 @@ def read_requirements(filename):
 # Declare minimal set for installation
 required_packages = [
     "attrs>=20.3.0,<23",
-    "boto3>=1.26.20,<2.0",
+    "boto3>=1.26.28,<2.0",
     "google-pasta",
     "numpy>=1.9.0,<2.0",
     "protobuf>=3.1,<4.0",
diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py
index b156f2e65f..1abea5e48c 100644
--- a/src/sagemaker/amazon/amazon_estimator.py
+++ b/src/sagemaker/amazon/amazon_estimator.py
@@ -27,7 +27,7 @@
 from sagemaker.deprecations import renamed_warning
 from sagemaker.estimator import EstimatorBase, _TrainingJob
 from sagemaker.inputs import FileSystemInput, TrainingInput
-from sagemaker.utils import sagemaker_timestamp
+from sagemaker.utils import sagemaker_timestamp, check_and_get_run_experiment_config
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
 from sagemaker.workflow import is_pipeline_variable
@@ -242,8 +242,8 @@ def fit(
             generates a default job name, based on the training image name and current
             timestamp.
         experiment_config (dict[str, str]): Experiment management configuration.
-            Optionally, the dict can contain three keys:
-            'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+            Optionally, the dict can contain four keys:
+            'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
             The behavior of setting these keys is as follows:
             * If `ExperimentName` is supplied but `TrialName` is not a Trial will be
             automatically created and the job's Trial Component associated with the Trial.
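To make the docstring change above concrete, a sketch of the four-key `experiment_config` dict as a caller would pass it to `fit` (the `estimator` object and training input are assumed; the names are placeholders):

    estimator.fit(
        inputs,  # e.g. an S3 URI or RecordSet, depending on the estimator
        experiment_config={
            "ExperimentName": "my-experiment",
            "TrialName": "my-trial",
            "TrialComponentDisplayName": "training",
            "RunName": "my-run",  # the new fourth key introduced by this patch
        },
    )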
@@ -255,6 +255,7 @@ def fit(
         """
         self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)

+        experiment_config = check_and_get_run_experiment_config(experiment_config)
         self.latest_training_job = _TrainingJob.start_new(
             self, records, experiment_config=experiment_config
         )
diff --git a/src/sagemaker/apiutils/_base_types.py b/src/sagemaker/apiutils/_base_types.py
index e920797b18..9a7359e12b 100644
--- a/src/sagemaker/apiutils/_base_types.py
+++ b/src/sagemaker/apiutils/_base_types.py
@@ -173,8 +173,10 @@ def _search(
                 search_items = search_method_response.get("Results", [])
                 next_token = search_method_response.get(boto_next_token_name)
                 for item in search_items:
-                    if cls.__name__ in item:
-                        yield search_item_factory(item[cls.__name__])
+                    # _TrialComponent class in experiments module is not public currently
+                    class_name = cls.__name__.lstrip("_")
+                    if class_name in item:
+                        yield search_item_factory(item[class_name])
                 if not next_token:
                     break
         except StopIteration:
diff --git a/src/sagemaker/apiutils/_boto_functions.py b/src/sagemaker/apiutils/_boto_functions.py
index 1e29f2ebea..a227d30ca8 100644
--- a/src/sagemaker/apiutils/_boto_functions.py
+++ b/src/sagemaker/apiutils/_boto_functions.py
@@ -68,7 +68,9 @@ def from_boto(boto_dict, boto_name_to_member_name, member_name_to_type):
             api_type, is_collection = member_name_to_type[member_name]
             if is_collection:
                 if isinstance(boto_value, dict):
-                    member_value = api_type.from_boto(boto_value)
+                    member_value = {
+                        key: api_type.from_boto(value) for key, value in boto_value.items()
+                    }
                 else:
                     member_value = [api_type.from_boto(item) for item in boto_value]
             else:
diff --git a/src/sagemaker/dataset_definition/inputs.py b/src/sagemaker/dataset_definition/inputs.py
index 90a272c4d7..468be22ac3 100644
--- a/src/sagemaker/dataset_definition/inputs.py
+++ b/src/sagemaker/dataset_definition/inputs.py
@@ -124,8 +124,10 @@ class DatasetDefinition(ApiObject):
     """DatasetDefinition input."""

     _custom_boto_types = {
-        "redshift_dataset_definition": (RedshiftDatasetDefinition, True),
-        "athena_dataset_definition": (AthenaDatasetDefinition, True),
+        # RedshiftDatasetDefinition and AthenaDatasetDefinition are not collections.
+        # Instead they are singleton objects. Thus, set the is_collection flag to False.
+        "redshift_dataset_definition": (RedshiftDatasetDefinition, False),
+        "athena_dataset_definition": (AthenaDatasetDefinition, False),
     }

     def __init__(
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 6f729267de..e3b06950aa 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -79,6 +79,7 @@
     get_config_value,
     name_from_base,
     to_string,
+    check_and_get_run_experiment_config,
 )
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable
@@ -1103,8 +1104,8 @@ def fit(
             job_name (str): Training job name. If not specified, the estimator generates
                 a default job name based on the training image name and current timestamp.
             experiment_config (dict[str, str]): Experiment management configuration.
-                Optionally, the dict can contain three keys:
-                'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+                Optionally, the dict can contain four keys:
+                'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
                 The behavior of setting these keys is as follows:
                 * If `ExperimentName` is supplied but `TrialName` is not a Trial will be
                 automatically created and the job's Trial Component associated with the Trial.
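The `check_and_get_run_experiment_config` hook threaded into these `fit` methods is what lets a training job inherit the configuration of an active experiment run. A sketch of the resulting workflow, assuming an `estimator` and training `inputs`; the `Run` context manager and its logging methods come from the `sagemaker.experiments` module added later in this patch:

    from sagemaker.experiments import Run

    with Run(experiment_name="my-experiment", run_name="my-run") as run:
        run.log_parameter("learning_rate", 0.1)
        # No explicit experiment_config is needed here: fit() resolves it
        # from the active Run context via check_and_get_run_experiment_config.
        estimator.fit(inputs)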
@@ -1122,6 +1123,7 @@ def fit( """ self._prepare_for_training(job_name=job_name) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config) self.jobs.append(self.latest_training_job) if wait: @@ -2023,8 +2025,8 @@ def start_new(cls, estimator, inputs, experiment_config): inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -2033,6 +2035,7 @@ def start_new(cls, estimator, inputs, experiment_config): * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. Returns: sagemaker.estimator._TrainingJob: Constructed object that captures all information about the started training job. @@ -2053,8 +2056,8 @@ def _get_train_args(cls, estimator, inputs, experiment_config): inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`. experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -2063,6 +2066,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config): * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. Returns: Dict: dict for `sagemaker.session.Session.train` method diff --git a/src/sagemaker/experiments/__init__.py b/src/sagemaker/experiments/__init__.py new file mode 100644 index 0000000000..b87656b1ab --- /dev/null +++ b/src/sagemaker/experiments/__init__.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Sagemaker Experiment Module""" +from __future__ import absolute_import + +from sagemaker.experiments.run import Run # noqa: F401 +from sagemaker.experiments.run import load_run # noqa: F401 +from sagemaker.experiments.run import list_runs # noqa: F401 +from sagemaker.experiments.run import SortOrderType # noqa: F401 +from sagemaker.experiments.run import SortByType # noqa: F401 diff --git a/src/sagemaker/experiments/_api_types.py b/src/sagemaker/experiments/_api_types.py new file mode 100644 index 0000000000..78f82565aa --- /dev/null +++ b/src/sagemaker/experiments/_api_types.py @@ -0,0 +1,251 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains API objects for SageMaker experiments.""" +from __future__ import absolute_import + +import enum +import numbers + +from sagemaker.apiutils import _base_types + + +class TrialComponentMetricSummary(_base_types.ApiObject): + """Summary model of a trial component. + + Attributes: + metric_name (str): The name of the metric. + source_arn (str): The ARN of the source. + time_stamp (datetime): Metric last updated value. + max (float): The max value of the metric. + min (float): The min value of the metric. + last (float): The last value of the metric. + count (float): The number of samples used to generate the metric. + avg (float): The average value of the metric. + std_dev (float): The standard deviation of the metric. + """ + + metric_name = None + source_arn = None + time_stamp = None + max = None + min = None + last = None + count = None + avg = None + std_dev = None + + def __init__(self, metric_name=None, source_arn=None, **kwargs): + super(TrialComponentMetricSummary, self).__init__( + metric_name=metric_name, source_arn=source_arn, **kwargs + ) + + +class TrialComponentParameters(_base_types.ApiObject): + """A dictionary of TrialComponentParameterValues""" + + @classmethod + def from_boto(cls, boto_dict, **kwargs): + """Converts a boto dict to a dictionary of TrialComponentParameterValues + + Args: + boto_dict (dict): boto response dictionary. + **kwargs: Arbitrary keyword arguments. + + Returns: + dict: Dictionary of parameter values. + """ + return_map = {} + for key, value in boto_dict.items(): + return_map[key] = value.get("NumberValue", value.get("StringValue", None)) + return return_map + + @classmethod + def to_boto(cls, parameters): + """Converts TrialComponentParameters to dict. + + Args: + parameters (TrialComponentParameters): Dictionary to convert. + + Returns: + dict: Dictionary of trial component parameters in boto format. + """ + boto_map = {} + for key, value in parameters.items(): + if isinstance(value, numbers.Number): + boto_map[key] = {"NumberValue": value} + else: + boto_map[key] = {"StringValue": str(value)} + return boto_map + + +class TrialComponentArtifact(_base_types.ApiObject): + """Trial component artifact. + + Attributes: + value (str): The artifact value. + media_type (str): The media type. 
+ """ + + value = None + media_type = None + + def __init__(self, value=None, media_type=None, **kwargs): + super(TrialComponentArtifact, self).__init__(value=value, media_type=media_type, **kwargs) + + +class _TrialComponentStatusType(enum.Enum): + """The type of trial component status""" + + InProgress = "InProgress" + Completed = "Completed" + Failed = "Failed" + + +class TrialComponentStatus(_base_types.ApiObject): + """Status of the trial component. + + Attributes: + primary_status (str): The status of a trial component. + message (str): Status message. + """ + + primary_status = None + message = None + + def __init__(self, primary_status=None, message=None, **kwargs): + super(TrialComponentStatus, self).__init__( + primary_status=primary_status, message=message, **kwargs + ) + + +class TrialComponentSummary(_base_types.ApiObject): + """Summary model of a trial component. + + Attributes: + trial_component_name (str): Name of trial component. + trial_component_arn (str): ARN of the trial component. + display_name (str): Friendly display name in UI. + source_arn (str): ARN of the trial component source. + status (str): Status. + start_time (datetime): Start time. + end_time (datetime): End time. + creation_time (datetime): Creation time. + created_by (str): Created by. + last_modified_time (datetime): Date last modified. + last_modified_by (datetime): User last modified. + """ + + _custom_boto_types = { + "status": (TrialComponentStatus, False), + } + trial_component_name = None + trial_component_arn = None + display_name = None + source_arn = None + status = None + start_time = None + end_time = None + creation_time = None + created_by = None + last_modified_time = None + last_modified_by = None + + +class TrialComponentSource(_base_types.ApiObject): + """Trial Component Source + + Attributes: + source_arn (str): The ARN of the source. + """ + + source_arn = None + + def __init__(self, source_arn=None, **kwargs): + super(TrialComponentSource, self).__init__(source_arn=source_arn, **kwargs) + + +class Parent(_base_types.ApiObject): + """The trial/experiment/run that a trial component is associated with. + + Attributes: + trial_name (str): Name of the trial. + experiment_name (str): Name of the experiment. + run_name (str): Name of the run. + """ + + trial_name = None + experiment_name = None + run_name = None + + +class TrialComponentSearchResult(_base_types.ApiObject): + """Summary model of an Trial Component search result. + + Attributes: + trial_component_arn (str): ARN of the trial component. + trial_component_name (str): Name of the trial component. + display_name (str): Display name of the trial component for UI display. + source (dict): The source of the trial component. + status (dict): The status of the trial component. + start_time (datetime): Start time. + end_time (datetime): End time. + creation_time (datetime): Creation time. + created_by (str): Created by. + last_modified_time (datetime): Date last modified. + last_modified_by (datetime): User last modified. + parameters (dict): The hyperparameters of the component. + input_artifacts (dict): The input artifacts of the component. + output_artifacts (dict): The output artifacts of the component. + metrics (list): The metrics for the component. + source_detail (dict): The source of the trial component. + tags (list): The list of tags that are associated with the trial component. + parents (list[Parent]): The parent of trial component. 
+ """ + + _custom_boto_types = { + "parents": (Parent, True), # parents is a collection (list) of Parent objects + } + trial_component_arn = None + trial_component_name = None + display_name = None + source = None + status = None + start_time = None + end_time = None + creation_time = None + created_by = None + last_modified_time = None + last_modified_by = None + parameters = None + input_artifacts = None + output_artifacts = None + metrics = None + source_detail = None + tags = None + parents = None + + +class TrialSummary(_base_types.ApiObject): + """Summary model of a trial. + + Attributes: + trial_arn (str): The ARN of the trial. + trial_name (str): The name of the trial. + creation_time (datetime): When the trial was created. + last_modified_time (datetime): When the trial was last modified. + """ + + trial_arn = None + trial_name = None + creation_time = None + last_modified_time = None diff --git a/src/sagemaker/experiments/_environment.py b/src/sagemaker/experiments/_environment.py new file mode 100644 index 0000000000..441661ae5a --- /dev/null +++ b/src/sagemaker/experiments/_environment.py @@ -0,0 +1,132 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the _RunEnvironment class.""" +from __future__ import absolute_import + +import enum +import json +import logging +import os + +from sagemaker.experiments import trial_component +from sagemaker.utils import retry_with_backoff + +TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN" +PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json" +TRANSFORM_JOB_ENV_BATCH_VAR = "SAGEMAKER_BATCH" +MAX_RETRY_ATTEMPTS = 7 + +logger = logging.getLogger(__name__) + + +class _EnvironmentType(enum.Enum): + """SageMaker jobs which data can be pulled from the environment.""" + + SageMakerTrainingJob = 1 + SageMakerProcessingJob = 2 + SageMakerTransformJob = 3 + + +class _RunEnvironment(object): + """Retrieves job specific data from the environment.""" + + def __init__(self, environment_type, source_arn): + """Init for _RunEnvironment. + + Args: + environment_type (_EnvironmentType): The environment type. + source_arn (str): The ARN of the current job. + """ + self.environment_type = environment_type + self.source_arn = source_arn + + @classmethod + def load( + cls, + training_job_arn_env=TRAINING_JOB_ARN_ENV, + processing_job_config_path=PROCESSING_JOB_CONFIG_PATH, + transform_job_batch_var=TRANSFORM_JOB_ENV_BATCH_VAR, + ): + """Loads source arn of current job from environment. + + Args: + training_job_arn_env (str): The environment key for training job ARN + (default: `TRAINING_JOB_ARN`). + processing_job_config_path (str): The processing job config path + (default: `/opt/ml/config/processingjobconfig.json`). + transform_job_batch_var (str): The environment variable indicating if + it is a transform job (default: `SAGEMAKER_BATCH`). + + Returns: + _RunEnvironment: Job data loaded from the environment. None if config does not exist. 
+ """ + if training_job_arn_env in os.environ: + environment_type = _EnvironmentType.SageMakerTrainingJob + source_arn = os.environ.get(training_job_arn_env) + return _RunEnvironment(environment_type, source_arn) + if os.path.exists(processing_job_config_path): + environment_type = _EnvironmentType.SageMakerProcessingJob + source_arn = json.loads(open(processing_job_config_path).read())["ProcessingJobArn"] + return _RunEnvironment(environment_type, source_arn) + if transform_job_batch_var in os.environ and os.environ[transform_job_batch_var] == "true": + environment_type = _EnvironmentType.SageMakerTransformJob + # TODO: need to figure out how to get source_arn from job env + # with Transform team's help. + source_arn = "" + return _RunEnvironment(environment_type, source_arn) + + return None + + def get_trial_component(self, sagemaker_session): + """Retrieves the trial component from the job in the environment. + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + _TrialComponent: The trial component created from the job. None if not found. + """ + # TODO: Remove this condition check once we have a way to retrieve source ARN + # from transform job env + if self.environment_type == _EnvironmentType.SageMakerTransformJob: + logger.error( + "Currently getting the job trial component from the transform job environment " + "is not supported. Returning None." + ) + return None + + def _get_trial_component(): + summaries = list( + trial_component._TrialComponent.list( + source_arn=self.source_arn.lower(), sagemaker_session=sagemaker_session + ) + ) + if summaries: + summary = summaries[0] + return trial_component._TrialComponent.load( + trial_component_name=summary.trial_component_name, + sagemaker_session=sagemaker_session, + ) + return None + + job_tc = None + try: + job_tc = retry_with_backoff(_get_trial_component, MAX_RETRY_ATTEMPTS) + except Exception as ex: # pylint: disable=broad-except + logger.error( + "Failed to get trail component in the current environment due to %s", str(ex) + ) + return job_tc diff --git a/src/sagemaker/experiments/_helper.py b/src/sagemaker/experiments/_helper.py new file mode 100644 index 0000000000..0c689b1125 --- /dev/null +++ b/src/sagemaker/experiments/_helper.py @@ -0,0 +1,266 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""Contains the helper classes for SageMaker Experiment.""" +from __future__ import absolute_import + +import json +import logging +import os + +import botocore + +from sagemaker.experiments._utils import is_already_exist_error + +logger = logging.getLogger(__name__) + + +_DEFAULT_ARTIFACT_PREFIX = "trial-component-artifacts" +_DEFAULT_ARTIFACT_TYPE = "Tracker" + + +class _ArtifactUploader(object): + """Artifact uploader""" + + def __init__( + self, + trial_component_name, + sagemaker_session, + artifact_bucket=None, + artifact_prefix=_DEFAULT_ARTIFACT_PREFIX, + ): + """Initialize a `_ArtifactUploader` instance. + + Args: + trial_component_name (str): The name of the trial component, + which is used to generate the S3 path to upload the artifact to. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + artifact_bucket (str): The S3 bucket to upload the artifact to. + If not specified, the default bucket defined in `sagemaker_session` + will be used. + artifact_prefix (str): The S3 key prefix used to generate the S3 path + to upload the artifact to (default: "trial-component-artifacts"). + """ + self.sagemaker_session = sagemaker_session + self.trial_component_name = trial_component_name + self.artifact_bucket = artifact_bucket + self.artifact_prefix = artifact_prefix + self._s3_client = self.sagemaker_session.boto_session.client("s3") + + def upload_artifact(self, file_path): + """Upload an artifact file to S3. + + Args: + file_path (str): the file path of the artifact + + Returns: + (str, str): The s3 URI of the uploaded file and the etag of the file. + + Raises: + ValueError: If file does not exist. + """ + file_path = os.path.expanduser(file_path) + if not os.path.isfile(file_path): + raise ValueError( + "{} does not exist or is not a file. Please supply a file path.".format(file_path) + ) + if not self.artifact_bucket: + self.artifact_bucket = self.sagemaker_session.default_bucket() + artifact_name = os.path.basename(file_path) + artifact_s3_key = "{}/{}/{}".format( + self.artifact_prefix, self.trial_component_name, artifact_name + ) + self._s3_client.upload_file(file_path, self.artifact_bucket, artifact_s3_key) + etag = self._try_get_etag(artifact_s3_key) + return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag + + def upload_object_artifact(self, artifact_name, artifact_object, file_extension=None): + """Upload an artifact object to S3. + + Args: + artifact_name (str): the name of the artifact. + artifact_object (obj): the object of the artifact + file_extension (str): Optional file extension. + + Returns: + str: The s3 URI of the uploaded file and the version of the file. + """ + if not self.artifact_bucket: + self.artifact_bucket = self.sagemaker_session.default_bucket() + if file_extension: + artifact_name = ( + artifact_name + ("" if file_extension.startswith(".") else ".") + file_extension + ) + artifact_s3_key = "{}/{}/{}".format( + self.artifact_prefix, self.trial_component_name, artifact_name + ) + self._s3_client.put_object( + Body=json.dumps(artifact_object), Bucket=self.artifact_bucket, Key=artifact_s3_key + ) + etag = self._try_get_etag(artifact_s3_key) + return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag + + def _try_get_etag(self, key): + """Get ETag of given key and return None if not allowed + + Args: + key (str): The S3 object key. + + Returns: + str: The S3 object ETag if it allows, otherwise return None. 
+ """ + try: + response = self._s3_client.head_object(Bucket=self.artifact_bucket, Key=key) + return response["ETag"] + except botocore.exceptions.ClientError as error: + # requires read permissions + logger.warning("Failed to get ETag of %s due to %s", key, error) + return None + + +class _LineageArtifactManager(object): + """A helper class to manage Lineage Artifacts""" + + def __init__( + self, + name, + source_uri, + etag, + source_arn=None, + dest_arn=None, + artifact_type=_DEFAULT_ARTIFACT_TYPE, + ): + """Initialize a `_LineageArtifactManager` instance. + + Args: + name (str): The name of the Lineage artifact to be created. + source_uri (str): The source URI used to create the Lineage artifact. + etag (str): The S3 Etag used to create the Lineage artifact. + source_arn (str): The source ARN of a trail component to associate + this Lineage artifact with (default: None). + dest_arn (str): The destination ARN of a trial component to associate + this Lineage artifact with (default: None). + artifact_type (str): The type of the Lineage artifact (default: "Tracker"). + """ + self.name = name + self.source_uri = source_uri + self.etag = etag + self.source_arn = source_arn + self.dest_arn = dest_arn + self.artifact_arn = None + self.artifact_type = artifact_type + + def create_artifact(self, sagemaker_session): + """Create the artifact by calling `CreateArtifact` API + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + source_ids = [] + if self.etag: + source_ids.append({"SourceIdType": "S3ETag", "Value": self.etag}) + + try: + response = sagemaker_session.sagemaker_client.create_artifact( + ArtifactName=self.name, + ArtifactType=self.artifact_type, + Source={"SourceUri": self.source_uri, "SourceTypes": source_ids}, + ) + self.artifact_arn = response["ArtifactArn"] + except botocore.exceptions.ClientError as err: + err_info = err.response["Error"] + if not is_already_exist_error(err_info): + raise + logger.warning( + "Skip creating the artifact since it already exists: %s", err_info["Message"] + ) + + def add_association(self, sagemaker_session): + """Associate the artifact with a source/destination ARN (e.g. trial component arn) + + Args: + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + source_arn = self.source_arn if self.source_arn else self.artifact_arn + dest_arn = self.dest_arn if self.dest_arn else self.artifact_arn + # if the trial component (job) is the source then it produced the artifact, + # otherwise the artifact contributed to the trial component (job) + association_edge_type = "Produced" if self.source_arn else "ContributedTo" + try: + sagemaker_session.sagemaker_client.add_association( + SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_edge_type + ) + except botocore.exceptions.ClientError as err: + err_info = err.response["Error"] + if not is_already_exist_error(err_info): + raise + logger.warning( + "Skip associating since the association already exists: %s", err_info["Message"] + ) + + +class _LineageArtifactTracker(object): + """Lineage Artifact Tracker""" + + def __init__(self, trial_component_arn, sagemaker_session): + """Initialize a `_LineageArtifactTracker` instance. + + Args: + trial_component_arn (str): The ARN of the trial component to be + associated with the input/output artifacts. 
+ sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. + """ + self.trial_component_arn = trial_component_arn + self.sagemaker_session = sagemaker_session + self.artifacts = [] + + def add_input_artifact(self, name, source_uri, etag, artifact_type): + """Add a Lineage input artifact locally + + Args: + name (str): The name of the Lineage input artifact to be added. + source_uri (str): The source URI used to create the Lineage input artifact. + etag (str): The S3 Etag used to create the Lineage input artifact. + artifact_type (str): The type of the Lineage input artifact. + """ + artifact = _LineageArtifactManager( + name, source_uri, etag, dest_arn=self.trial_component_arn, artifact_type=artifact_type + ) + self.artifacts.append(artifact) + + def add_output_artifact(self, name, source_uri, etag, artifact_type): + """Add a Lineage output artifact locally + + Args: + name (str): The name of the Lineage output artifact to be added. + source_uri (str): The source URI used to create the Lineage output artifact. + etag (str): The S3 Etag used to create the Lineage output artifact. + artifact_type (str): The type of the Lineage output artifact. + """ + artifact = _LineageArtifactManager( + name, source_uri, etag, source_arn=self.trial_component_arn, artifact_type=artifact_type + ) + self.artifacts.append(artifact) + + def save(self): + """Persist any artifact data saved locally""" + for artifact in self.artifacts: + artifact.create_artifact(self.sagemaker_session) + artifact.add_association(self.sagemaker_session) diff --git a/src/sagemaker/experiments/_metrics.py b/src/sagemaker/experiments/_metrics.py new file mode 100644 index 0000000000..f80c43f337 --- /dev/null +++ b/src/sagemaker/experiments/_metrics.py @@ -0,0 +1,413 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains classes to manage metrics for Sagemaker Experiment""" +from __future__ import absolute_import + +import datetime +import json +import logging +import os +import time +import threading +import queue + +import dateutil.tz + +from sagemaker.session import Session + +METRICS_DIR = os.environ.get("SAGEMAKER_METRICS_DIRECTORY", ".") +METRIC_TS_LOWER_BOUND_TO_NOW = 1209600 # on seconds +METRIC_TS_UPPER_BOUND_FROM_NOW = 7200 # on seconds + +BATCH_SIZE = 10 + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# TODO: remove this _SageMakerFileMetricsWriter class +# when _MetricsManager is fully ready +class _SageMakerFileMetricsWriter(object): + """Write metric data to file.""" + + def __init__(self, metrics_file_path=None): + """Construct a `_SageMakerFileMetricsWriter` object""" + self._metrics_file_path = metrics_file_path + self._file = None + self._closed = False + + def log_metric(self, metric_name, value, timestamp=None, step=None): + """Write a metric to file. + + Args: + metric_name (str): The name of the metric. + value (float): The value of the metric. 
+ timestamp (datetime.datetime): Timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): Iteration number of the metric (default: None). + + Raises: + SageMakerMetricsWriterException: If the metrics file is closed. + AttributeError: If file has been initialized and the writer hasn't been closed. + """ + raw_metric_data = _RawMetricData( + metric_name=metric_name, value=value, timestamp=timestamp, step=step + ) + try: + logger.debug("Writing metric: %s", raw_metric_data) + self._file.write(json.dumps(raw_metric_data.to_record())) + self._file.write("\n") + except AttributeError as attr_err: + if self._closed: + raise SageMakerMetricsWriterException("log_metric called on a closed writer") + if not self._file: + self._file = open(self._get_metrics_file_path(), "a", buffering=1) + self._file.write(json.dumps(raw_metric_data.to_record())) + self._file.write("\n") + else: + raise attr_err + + def close(self): + """Closes the metric file.""" + if not self._closed and self._file: + self._file.close() + self._file = None # invalidate reference, causing subsequent log_metric to fail. + self._closed = True + + def __enter__(self): + """Return self""" + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Execute self.close()""" + self.close() + + def __del__(self): + """Execute self.close()""" + self.close() + + def _get_metrics_file_path(self): + """Get file path to store metrics""" + pid_filename = "{}.json".format(str(os.getpid())) + metrics_file_path = self._metrics_file_path or os.path.join(METRICS_DIR, pid_filename) + logger.debug("metrics_file_path = %s", metrics_file_path) + return metrics_file_path + + +class SageMakerMetricsWriterException(Exception): + """SageMakerMetricsWriterException""" + + def __init__(self, message, errors=None): + """Construct a `SageMakerMetricsWriterException` instance""" + super().__init__(message) + if errors: + self.errors = errors + + +class _RawMetricData(object): + """A Raw Metric Data Object""" + + MetricName = None + Value = None + Timestamp = None + Step = None + + def __init__(self, metric_name, value, timestamp=None, step=None): + """Construct a `_RawMetricData` instance. + + Args: + metric_name (str): The name of the metric. + value (float): The value of the metric. + timestamp (datetime.datetime or float or str): Timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): Iteration number of the metric (default: None). + """ + if timestamp is None: + timestamp = time.time() + elif isinstance(timestamp, datetime.datetime): + # If the input is a datetime then convert it to UTC time. + # Assume a naive datetime is in local timezone + if not timestamp.tzinfo: + timestamp = timestamp.replace(tzinfo=dateutil.tz.tzlocal()) + timestamp = (timestamp - timestamp.utcoffset()).replace(tzinfo=datetime.timezone.utc) + timestamp = timestamp.timestamp() + else: + timestamp = float(timestamp) + + if timestamp < (time.time() - METRIC_TS_LOWER_BOUND_TO_NOW) or timestamp > ( + time.time() + METRIC_TS_UPPER_BOUND_FROM_NOW + ): + raise ValueError( + "Supplied timestamp %f is invalid." + " Timestamps must be between two weeks before and two hours from now." 
% timestamp + ) + value = float(value) + + self.MetricName = metric_name + self.Value = float(value) + self.Timestamp = timestamp + if step is not None: + if not isinstance(step, int): + raise ValueError("step must be int.") + self.Step = step + + def to_record(self): + """Convert the `_RawMetricData` object to dict""" + return self.__dict__ + + def to_raw_metric_data(self): + """Converts the metric data to a BatchPutMetrics RawMetricData item""" + # Convert timestamp from float to timestamp str. + # Otherwise will get ParamValidationError + raw_metric_data = { + "MetricName": self.MetricName, + "Value": self.Value, + "Timestamp": str(int(self.Timestamp)), + } + if self.Step is not None: + raw_metric_data["Step"] = int(self.Step) + return raw_metric_data + + def __str__(self): + """String representation of the `_RawMetricData` object.""" + return repr(self) + + def __repr__(self): + """Return a string representation of this _RawMetricData` object.""" + return "{}({})".format( + type(self).__name__, + ",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]), + ) + + +class _MetricsManager(object): + """Collects metrics and sends them directly to SageMaker Metrics data plane APIs.""" + + def __init__(self, trial_component_name: str, sagemaker_session: Session, sink=None) -> None: + """Initialize a `_MetricsManager` instance + + Args: + trial_component_name (str): The Name of the Trial Component to log metrics to + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + sink (object): The metrics sink to use. + """ + if sink is None: + self.sink = _SyncMetricsSink( + trial_component_name, sagemaker_session.sagemaker_metrics_client + ) + else: + self.sink = sink + + def log_metric(self, metric_name, value, timestamp=None, step=None): + """Sends a metric to metrics service.""" + + metric_data = _RawMetricData(metric_name, value, timestamp, step) + self.sink.log_metric(metric_data) + + def __enter__(self): + """Return self""" + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Execute self.close()""" + self.sink.close() + + def close(self): + """Close the metrics object.""" + self.sink.close() + + +class _SyncMetricsSink(object): + """Collects metrics and sends them directly to metrics service.""" + + def __init__(self, trial_component_name, metrics_client) -> None: + """Initialize a `_SyncMetricsSink` instance + + Args: + trial_component_name (str): The Name of the Trial Component to log metrics. 
+ metrics_client (boto3.client): boto client for metrics service + """ + self._trial_component_name = trial_component_name + self._metrics_client = metrics_client + self._buffer = [] + + def log_metric(self, metric_data): + """Sends a metric to metrics service.""" + + # this is a simplistic solution which calls BatchPutMetrics + # on the same thread as the client code + self._buffer.append(metric_data) + self._drain() + + def _drain(self, close=False): + """Pops off all metrics in the buffer and starts sending them to metrics service.""" + + if not self._buffer: + return + + if len(self._buffer) < BATCH_SIZE and not close: + return + + # pop all the available metrics + available_metrics, self._buffer = self._buffer, [] + + self._send_metrics(available_metrics) + + def _send_metrics(self, metrics): + """Calls BatchPutMetrics directly on the metrics service.""" + while metrics: + batch, metrics = ( + metrics[:BATCH_SIZE], + metrics[BATCH_SIZE:], + ) + request = self._construct_batch_put_metrics_request(batch) + response = self._metrics_client.batch_put_metrics(**request) + errors = response["Errors"] if "Errors" in response else None + if errors: + message = errors[0]["Message"] + raise Exception(f'{len(errors)} errors with message "{message}"') + + def _construct_batch_put_metrics_request(self, batch): + """Creates dictionary object used as request to metrics service.""" + return { + "TrialComponentName": self._trial_component_name.lower(), + "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)), + } + + def close(self): + """Drains any remaining metrics.""" + self._drain(close=True) + + +class _MetricQueue(object): + """A thread safe queue for sending metrics to SageMaker. + + Args: + trial_component_name (str): the ARN of the resource + metric_name (str): the name of the metric + metrics_client (boto_client): the boto client for SageMaker Metrics service + """ + + _CONSUMER_SLEEP_SECONDS = 5 + + def __init__(self, trial_component_name, metric_name, metrics_client): + # infinite queue size + self._queue = queue.Queue() + self._buffer = [] + self._thread = threading.Thread(target=self._run) + self._started = False + self._finished = False + self._trial_component_name = trial_component_name + self._metrics_client = metrics_client + self._metric_name = metric_name + self._logged_metrics = 0 + + def log_metric(self, metric_data): + """Adds a metric data point to the queue""" + self._buffer.append(metric_data) + + if len(self._buffer) < BATCH_SIZE: + return + + self._enqueue_all() + + if not self._started: + self._thread.start() + self._started = True + + def _run(self): + """Starts the metric thread which sends metrics to SageMaker in batches""" + + while not self._queue.empty() or not self._finished: + if self._queue.empty(): + time.sleep(self._CONSUMER_SLEEP_SECONDS) + else: + batch = self._queue.get() + self._send_metrics(batch) + + def _send_metrics(self, metrics_batch): + """Calls BatchPutMetrics directly on the metrics service.""" + request = self._construct_batch_put_metrics_request(metrics_batch) + self._logged_metrics += len(metrics_batch) + self._metrics_client.batch_put_metrics(**request) + + def _construct_batch_put_metrics_request(self, batch): + """Creates dictionary object used as request to metrics service.""" + + return { + "TrialComponentName": self._trial_component_name, + "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)), + } + + def _enqueue_all(self): + """Enqueue all buffered metrics to be sent to SageMaker""" + + available_metrics, self._buffer 
= self._buffer, [] + if available_metrics: + self._queue.put(available_metrics) + + def close(self): + """Flushes any buffered metrics""" + + self._enqueue_all() + self._finished = True + + def is_active(self): + """Is the thread active (still draining metrics to SageMaker)""" + + return self._thread.is_alive() + + +class _AsyncMetricsSink(object): + """Collects metrics and sends them directly to metrics service.""" + + _COMPLETE_SLEEP_SECONDS = 1.0 + + def __init__(self, trial_component_name, metrics_client) -> None: + """Initialize a `_AsyncMetricsSink` instance + + Args: + trial_component_name (str): The Name of the Trial Component to log metrics to. + metrics_client (boto3.client): boto client for metrics service + """ + self._trial_component_name = trial_component_name + self._metrics_client = metrics_client + self._buffer = [] + self._is_draining = False + self._metric_queues = {} + + def log_metric(self, metric_data): + """Sends a metric to metrics service.""" + + if metric_data.MetricName in self._metric_queues: + self._metric_queues[metric_data.MetricName].log_metric(metric_data) + else: + cur_metric_queue = _MetricQueue( + self._trial_component_name, metric_data.MetricName, self._metrics_client + ) + self._metric_queues[metric_data.MetricName] = cur_metric_queue + cur_metric_queue.log_metric(metric_data) + + def close(self): + """Closes the metric file.""" + logging.debug("Closing") + for q in self._metric_queues.values(): + q.close() + + # TODO should probably use join + while any(map(lambda x: x.is_active(), self._metric_queues.values())): + time.sleep(self._COMPLETE_SLEEP_SECONDS) + logging.debug("Closed") diff --git a/src/sagemaker/experiments/_run_context.py b/src/sagemaker/experiments/_run_context.py new file mode 100644 index 0000000000..9a7dada5f4 --- /dev/null +++ b/src/sagemaker/experiments/_run_context.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment _RunContext class.""" +from __future__ import absolute_import + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sagemaker.experiments import Run + + +class _RunContext: + """A static context variable to keep track of the current Run object""" + + _context_run = None + + @classmethod + def add_run_object(cls, run: "Run"): + """Keep track of the current executing Run object + + by adding it to a class static variable. + + Args: + run (Run): The current Run object to be tracked. + """ + cls._context_run = run + + @classmethod + def drop_current_run(cls) -> "Run": + """Drop the Run object tracked in the global static variable + + as its execution finishes (its "with" block ends). + + Return: + Run: the dropped Run object. + """ + current_run = cls._context_run + cls._context_run = None + return current_run + + @classmethod + def get_current_run(cls) -> "Run": + """Return the current Run object without dropping it. + + Return: + Run: the current Run object to be returned. 
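A minimal sketch of how ``_RunContext`` is intended to be driven by ``Run.__enter__`` and ``Run.__exit__``; ``_FakeRun`` is an illustrative stand-in, not the SDK's ``Run`` class:

.. code:: python

    from sagemaker.experiments._run_context import _RunContext


    class _FakeRun:
        """Illustrative stand-in for sagemaker.experiments.Run."""

        def __enter__(self):
            _RunContext.add_run_object(self)
            return self

        def __exit__(self, exc_type, exc_value, exc_traceback):
            _RunContext.drop_current_run()


    with _FakeRun() as run:
        assert _RunContext.get_current_run() is run
    assert _RunContext.get_current_run() is None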
+ """ + return cls._context_run diff --git a/src/sagemaker/experiments/_utils.py b/src/sagemaker/experiments/_utils.py new file mode 100644 index 0000000000..5ef5d99dad --- /dev/null +++ b/src/sagemaker/experiments/_utils.py @@ -0,0 +1,218 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment utility methods.""" +from __future__ import absolute_import + +import logging +import os + +import mimetypes +import urllib +from functools import wraps +from typing import Optional + +from sagemaker import Session +from sagemaker.apiutils import _utils +from sagemaker.experiments._environment import _RunEnvironment, _EnvironmentType +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression +from sagemaker.utils import retry_with_backoff + + +def resolve_artifact_name(file_path): + """Resolve artifact name from given file path. + + If not specified, will auto create one. + + Args: + file_path (str): Path to the file. + + Returns: + str: The resolved artifact name. + """ + _, filename = os.path.split(file_path) + if filename: + return filename + + return _utils.name("artifact") + + +def guess_media_type(file_path): + """Infer the media type of a file based on its file name. + + Args: + file_path (str): Path to the file. + + Returns: + str: The guessed media type. + """ + file_url = urllib.parse.urljoin("file:", urllib.request.pathname2url(file_path)) + guessed_media_type, _ = mimetypes.guess_type(file_url, strict=False) + return guessed_media_type + + +def verify_length_of_true_and_predicted(true_labels, predicted_attrs, predicted_attrs_name): + """Verify if lengths match between lists of true labels and predicted attributes. + + Args: + true_labels (list or array): The list of the true labels. + predicted_attrs (list or array): The list of the predicted labels/probabilities/scores. + predicted_attrs_name (str): The name of the predicted attributes. + + Raises: + ValueError: If lengths mismatch between true labels and predicted attributes. 
+ """ + if len(true_labels) != len(predicted_attrs): + raise ValueError( + "Lengths mismatch between true labels and {}: " + "({} vs {}).".format(predicted_attrs_name, len(true_labels), len(predicted_attrs)) + ) + + +def validate_invoked_inside_run_context(func): + """A Decorator to force the decorated method called under Run context.""" + + @wraps(func) + def wrapper(*args, **kwargs): + self_instance = args[0] + if not self_instance._inside_load_context and not self_instance._inside_init_context: + raise RuntimeError("This method should be called inside context of 'with' statement.") + return func(*args, **kwargs) + + return wrapper + + +def is_already_exist_error(error): + """Check if the error indicates resource already exists + + Args: + error (dict): The "Error" field in the response of the + `botocore.exceptions.ClientError` + """ + return error["Code"] == "ValidationException" and "already exists" in error["Message"] + + +def get_tc_and_exp_config_from_job_env( + environment: _RunEnvironment, + sagemaker_session: Session, +) -> dict: + """Retrieve an experiment config from the job environment. + + Args: + environment (_RunEnvironment): The run environment object with job specific data. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + """ + job_name = environment.source_arn.split("/")[-1] + if environment.environment_type == _EnvironmentType.SageMakerTrainingJob: + job_response = retry_with_backoff( + callable_func=lambda: sagemaker_session.describe_training_job(job_name), + num_attempts=4, + ) + elif environment.environment_type == _EnvironmentType.SageMakerProcessingJob: + job_response = retry_with_backoff( + callable_func=lambda: sagemaker_session.describe_processing_job(job_name), + num_attempts=4, + ) + else: # environment.environment_type == _EnvironmentType.SageMakerTransformJob + raise RuntimeError( + "Failed to load the Run as loading experiment config " + "from transform job environment is not currently supported. " + "As a workaround, please explicitly pass in " + "the experiment_name and run_name in load_run." + ) + + job_exp_config = job_response.get("ExperimentConfig", dict()) + from sagemaker.experiments.run import RUN_NAME + + if job_exp_config.get(RUN_NAME, None): + return job_exp_config + raise RuntimeError( + "Not able to fetch RunName in ExperimentConfig of the sagemaker job. " + "Please make sure the ExperimentConfig is correctly set." + ) + + +def verify_load_input_names( + run_name: Optional[str] = None, + experiment_name: Optional[str] = None, +): + """Verify the run_name and the experiment_name inputs in load_run. + + Args: + run_name (str): The run_name supplied by the user (default: None). + experiment_name (str): The experiment_name supplied by the user + (default: None). + + Raises: + ValueError: If run_name is supplied while experiment_name is not. + """ + if not run_name and experiment_name: + logging.warning( + "No run_name is supplied. Ignoring the provided experiment_name " + "since it only takes effect along with run_name. " + "Will load the Run object from the job environment or current Run context." + ) + if run_name and not experiment_name: + raise ValueError( + "Invalid input: experiment_name is missing when run_name is supplied. " + "Please supply a valid experiment_name when the run_name is not None." 
+ ) + + +def is_run_trial_component(trial_component_name: str, sagemaker_session: Session) -> bool: + """Check if a trial component is generated by `sagemaker.experiments.Run` + + Args: + trial_component_name (str): The name of the trial component. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + bool: Indicate whether the trial component is created by + `sagemaker.experiments.Run` or not. + """ + search_filter = Filter( + name="TrialComponentName", + operator=Operator.EQUALS, + value=trial_component_name, + ) + search_expression = SearchExpression(filters=[search_filter]) + + def search(): + return list( + _TrialComponent.search( + search_expression=search_expression, + max_results=1, # TrialComponentName is unique in an account + sagemaker_session=sagemaker_session, + ) + )[0] + + try: + tc_search_res = retry_with_backoff(search, 4) + from sagemaker.experiments.run import RUN_TC_TAG + + if not tc_search_res.tags or RUN_TC_TAG not in tc_search_res.tags: + return False + return True + except Exception as ex: # pylint: disable=broad-except + logging.warning( + "Failed to inspect the type of the trial component (%s), due to (%s)", + trial_component_name, + str(ex), + ) + return False diff --git a/src/sagemaker/experiments/experiment.py b/src/sagemaker/experiments/experiment.py new file mode 100644 index 0000000000..8f59ff36b3 --- /dev/null +++ b/src/sagemaker/experiments/experiment.py @@ -0,0 +1,237 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment class.""" +from __future__ import absolute_import + +import time + +from sagemaker.apiutils import _base_types +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +class _Experiment(_base_types.Record): + """An Amazon SageMaker experiment, which is a collection of related trials. + + New experiments are created by calling `experiments.experiment._Experiment.create`. + Existing experiments can be reloaded by calling `experiments.experiment._Experiment.load`. + + Attributes: + experiment_name (str): The name of the experiment. The name must be unique + within an account. + display_name (str): Name of the experiment that will appear in UI, + such as SageMaker Studio. + description (str): A description of the experiment. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment. 
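``is_run_trial_component`` above builds its query from the SDK's search-expression helpers; the same construction in isolation, with a hypothetical trial component name:

.. code:: python

    from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression

    search_filter = Filter(
        name="TrialComponentName",
        operator=Operator.EQUALS,
        value="my-exp-my-run",  # hypothetical trial component name
    )
    search_expression = SearchExpression(filters=[search_filter])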
+ """ + + experiment_name = None + display_name = None + description = None + tags = None + + _boto_create_method = "create_experiment" + _boto_load_method = "describe_experiment" + _boto_update_method = "update_experiment" + _boto_delete_method = "delete_experiment" + + _boto_update_members = ["experiment_name", "description", "display_name"] + _boto_delete_members = ["experiment_name"] + + _MAX_DELETE_ALL_ATTEMPTS = 3 + + def save(self): + """Save the state of this Experiment to SageMaker. + + Returns: + dict: Update experiment API response. + """ + return self._invoke_api(self._boto_update_method, self._boto_update_members) + + def delete(self): + """Delete this Experiment from SageMaker. + + Deleting an Experiment does not delete associated Trials and their Trial Components. + It requires that each Trial in the Experiment is first deleted. + + Returns: + dict: Delete experiment API response. + """ + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, experiment_name, sagemaker_session=None): + """Load an existing experiment and return an `_Experiment` object representing it. + + Args: + experiment_name: (str): Name of the experiment + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + return cls._construct( + cls._boto_load_method, + experiment_name=experiment_name, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def create( + cls, + experiment_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=None, + ): + """Create a new experiment in SageMaker and return an `_Experiment` object. + + Args: + experiment_name: (str): Name of the experiment. Must be unique. Required. + display_name: (str): Name of the experiment that will appear in UI, + such as SageMaker Studio (default: None). + description: (str): Description of the experiment (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment + (default: None). + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + return cls._construct( + cls._boto_create_method, + experiment_name=experiment_name, + display_name=display_name, + description=description, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def _load_or_create( + cls, + experiment_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=None, + ): + """Load an experiment by name and create a new one if it does not exist. + + Args: + experiment_name: (str): Name of the experiment. Must be unique. Required. + display_name: (str): Name of the experiment that will appear in UI, + such as SageMaker Studio (default: None). This is used only when the + given `experiment_name` does not exist and a new experiment has to be created. + description: (str): Description of the experiment (default: None). + This is used only when the given `experiment_name` does not exist and + a new experiment has to be created. 
+ sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + tags (List[Dict[str, str]]): A list of tags to associate with the experiment + (default: None). This is used only when the given `experiment_name` does not + exist and a new experiment has to be created. + + Returns: + experiments.experiment._Experiment: A SageMaker `_Experiment` object + """ + sagemaker_client = sagemaker_session.sagemaker_client + try: + experiment = _Experiment.load(experiment_name, sagemaker_session) + except sagemaker_client.exceptions.ResourceNotFound: + experiment = _Experiment.create( + experiment_name=experiment_name, + display_name=display_name, + description=description, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return experiment + + def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None): + """List trials in this experiment matching the specified criteria. + + Args: + created_before (datetime.datetime): Return trials created before this instant + (default: None). + created_after (datetime.datetime): Return trials created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). + + Returns: + collections.Iterator[experiments._api_types.TrialSummary] : + An iterator over trials matching the criteria. + """ + return _Trial.list( + experiment_name=self.experiment_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by, + sort_order=sort_order, + sagemaker_session=self.sagemaker_session, + ) + + def _delete_all(self, action): + """Force to delete the experiment and associated trials, trial components. + + Args: + action (str): The string '--force' is required to pass in to confirm recursively + delete the experiments, and all its trials and trial components. + """ + if action != "--force": + raise ValueError( + "Must confirm with string '--force' in order to delete the experiment and " + "associated trials, trial components." + ) + + delete_attempt_count = 0 + last_exception = None + while True: + if delete_attempt_count == self._MAX_DELETE_ALL_ATTEMPTS: + raise Exception("Failed to delete, please try again.") from last_exception + try: + for trial_summary in self.list_trials(): + trial = _Trial.load( + sagemaker_session=self.sagemaker_session, + trial_name=trial_summary.trial_name, + ) + for ( + trial_component_summary + ) in trial.list_trial_components(): # pylint: disable=no-member + tc = _TrialComponent.load( + sagemaker_session=self.sagemaker_session, + trial_component_name=trial_component_summary.trial_component_name, + ) + tc.delete(force_disassociate=True) + # to prevent throttling + time.sleep(1.2) + trial.delete() # pylint: disable=no-member + # to prevent throttling + time.sleep(1.2) + self.delete() + break + except Exception as ex: # pylint: disable=broad-except + last_exception = ex + finally: + delete_attempt_count = delete_attempt_count + 1 diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py new file mode 100644 index 0000000000..1492b6bafa --- /dev/null +++ b/src/sagemaker/experiments/run.py @@ -0,0 +1,882 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the SageMaker Experiment Run class.""" +from __future__ import absolute_import + +import datetime +import logging +from enum import Enum +from math import isnan, isinf +from numbers import Number +from typing import Optional, List, Dict, TYPE_CHECKING, Union + +import dateutil +from numpy import array + +from sagemaker.apiutils import _utils +from sagemaker.experiments import _api_types +from sagemaker.experiments._api_types import TrialComponentArtifact, _TrialComponentStatusType +from sagemaker.experiments._helper import ( + _ArtifactUploader, + _LineageArtifactTracker, +) +from sagemaker.experiments._environment import _RunEnvironment +from sagemaker.experiments._run_context import _RunContext +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments._metrics import _MetricsManager +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + +from sagemaker.utils import ( + get_module, + unique_name_from_base, +) + +from sagemaker.experiments._utils import ( + guess_media_type, + resolve_artifact_name, + verify_length_of_true_and_predicted, + validate_invoked_inside_run_context, + get_tc_and_exp_config_from_job_env, + verify_load_input_names, + is_run_trial_component, +) + +if TYPE_CHECKING: + from sagemaker import Session + +logger = logging.getLogger(__name__) + +RUN_NAME_BASE = "Sagemaker-Run".lower() +TRIAL_NAME_TEMPLATE = "Default-Run-Group-{}" +MAX_RUN_TC_ARTIFACTS_LEN = 30 +MAX_NAME_LEN_IN_BACKEND = 120 +EXPERIMENT_NAME = "ExperimentName" +TRIAL_NAME = "TrialName" +RUN_NAME = "RunName" +DELIMITER = "-" +RUN_TC_TAG_KEY = "sagemaker:trial-component-source" +RUN_TC_TAG_VALUE = "run" +RUN_TC_TAG = {"Key": RUN_TC_TAG_KEY, "Value": RUN_TC_TAG_VALUE} + + +class SortByType(Enum): + """The type of property by which to sort the `list_runs` results.""" + + CREATION_TIME = "CreationTime" + NAME = "Name" + + +class SortOrderType(Enum): + """The type of order to sort the list or search results.""" + + ASCENDING = "Ascending" + DESCENDING = "Descending" + + +class Run(object): + """A collection of parameters, metrics, and artifacts to create a ML model.""" + + def __init__( + self, + experiment_name: str, + run_name: Optional[str] = None, + experiment_display_name: Optional[str] = None, + run_display_name: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, + sagemaker_session: Optional["Session"] = None, + ): + """Construct a `Run` instance. + + SageMaker Experiments automatically tracks the inputs, parameters, configurations, + and results of your iterations as runs. + You can assign, group, and organize these runs into experiments. + You can also create, compare, and evaluate runs. + + The code sample below shows how to initialize a run, log parameters to the Run object + and invoke a training job under the context of this Run object, which automatically + passes the run's ``experiment_config`` (including the experiment name, run name etc.) + to the training job. + + Note: + All log methods (e.g. ``log_parameter``, ``log_metric``, etc.) 
have to be called within + the run context (i.e. the ``with`` statement). Otherwise, a ``RuntimeError`` is thrown. + + .. code:: python + + with Run(experiment_name="my-exp", run_name="my-run", ...) as run: + run.log_parameter(...) + ... + estimator.fit(job_name="my-job") # Create a training job + + In order to reuse an existing run to log extra data, ``load_run`` is recommended. + The code snippet below displays how to load the run initialized above + in a custom training job script, where no ``run_name`` or ``experiment_name`` + is presented as they are automatically retrieved from the experiment config + in the job environment. + + Note: + Instead of the ``Run`` constructor, the ``load_run`` is recommended to use + in a job script to load the existing run created before the job launch. + Otherwise, a new run may be created each time you launch a job. + + .. code:: python + + with load_run() as run: + run.log_metric(...) + ... + + Args: + experiment_name (str): The name of the experiment. The name must be unique + within an account. + run_name (str): The name of the run. If it is not specified, one is auto generated. + experiment_display_name (str): Name of the experiment that will appear in UI, + such as SageMaker Studio. (default: None). This display name is used in + a create experiment call. If an experiment with the specified name already exists, + this display name won't take effect. + run_display_name (str): The display name of the run used in UI (default: None). + This display name is used in a create run call. If a run with the + specified name already exists, this display name won't take effect. + tags (List[Dict[str, str]]): A list of tags to be used for all create calls, + e.g. to create an experiment, a run group, etc. (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + """ + # TODO: we should revert the lower casting once backend fix reaches prod + self.experiment_name = experiment_name.lower() + sagemaker_session = sagemaker_session or _utils.default_session() + self.run_name = run_name or unique_name_from_base(RUN_NAME_BASE) + + # avoid confusion due to mis-match in casing between run name and TC name + self.run_name = self.run_name.lower() + + trial_component_name = Run._generate_trial_component_name( + run_name=self.run_name, experiment_name=self.experiment_name + ) + self.run_group_name = Run._generate_trial_name(self.experiment_name) + + self._experiment = _Experiment._load_or_create( + experiment_name=self.experiment_name, + display_name=experiment_display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + self._trial = _Trial._load_or_create( + experiment_name=self.experiment_name, + trial_name=self.run_group_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + self._trial_component, is_existed = _TrialComponent._load_or_create( + trial_component_name=trial_component_name, + display_name=run_display_name, + tags=Run._append_run_tc_label_to_tags(tags), + sagemaker_session=sagemaker_session, + ) + if is_existed: + logger.info( + "The run (%s) under experiment (%s) already exists. Loading it. 
" + "Note: sagemaker.experiments.load_run is recommended to use when " + "the desired run already exists.", + self.run_name, + self.experiment_name, + ) + self._trial.add_trial_component(self._trial_component) + + self._artifact_uploader = _ArtifactUploader( + trial_component_name=self._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + self._lineage_artifact_tracker = _LineageArtifactTracker( + trial_component_arn=self._trial_component.trial_component_arn, + sagemaker_session=sagemaker_session, + ) + self._metrics_manager = _MetricsManager( + trial_component_name=self._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + self._inside_init_context = False + self._inside_load_context = False + self._in_load = False + + @property + def experiment_config(self) -> dict: + """Get experiment config from run attributes.""" + return { + EXPERIMENT_NAME: self.experiment_name, + TRIAL_NAME: self.run_group_name, + RUN_NAME: self._trial_component.trial_component_name, + } + + @validate_invoked_inside_run_context + def log_parameter(self, name: str, value: Union[str, int, float]): + """Record a single parameter value for this run. + + Overwrites any previous value recorded for the specified parameter name. + + Args: + name (str): The name of the parameter. + value (str or int or float): The value of the parameter. + """ + if self._is_input_valid("parameter", name, value): + self._trial_component.parameters[name] = value + + @validate_invoked_inside_run_context + def log_parameters(self, parameters: Dict[str, Union[str, int, float]]): + """Record a collection of parameter values for this run. + + Args: + parameters (dict[str, str or int or float]): The parameters to record. + """ + filtered_parameters = { + key: value + for (key, value) in parameters.items() + if self._is_input_valid("parameter", key, value) + } + self._trial_component.parameters.update(filtered_parameters) + + @validate_invoked_inside_run_context + def log_metric( + self, + name: str, + value: float, + timestamp: Optional[datetime.datetime] = None, + step: Optional[int] = None, + ): + """Record a custom scalar metric value for this run. + + Note: + This method is for manual custom metrics, for automatic metrics see the + ``enable_sagemaker_metrics`` parameter on the ``estimator`` class. + + Args: + name (str): The name of the metric. + value (float): The value of the metric. + timestamp (datetime.datetime): The timestamp of the metric. + If not specified, the current UTC time will be used. + step (int): The integer iteration number of the metric value (default: None). + """ + if self._is_input_valid("metric", name, value): + self._metrics_manager.log_metric( + metric_name=name, value=value, timestamp=timestamp, step=step + ) + + @validate_invoked_inside_run_context + def log_precision_recall( + self, + y_true: Union[list, array], + predicted_probabilities: Union[list, array], + positive_label: Optional[Union[str, int]] = None, + title: Optional[str] = None, + is_output: bool = True, + no_skill: Optional[int] = None, + ): + """Create and log a precision recall graph artifact for Studio UI to render. + + The artifact is stored in S3 and represented as a lineage artifact + with an association with the run. + + You can view the artifact in the UI. + If your job is created by a pipeline execution you can view the artifact + by selecting the corresponding step in the pipelines UI. + See also `SageMaker Pipelines `_ + + This method requires sklearn library. 
+
+        Args:
+            y_true (list or array): True labels. If labels are not binary
+                then positive_label should be given.
+            predicted_probabilities (list or array): Estimated/predicted probabilities.
+            positive_label (str or int): Label of the positive class (default: None).
+            title (str): Title of the graph (default: None).
+            is_output (bool): Determines direction of association to the
+                run. Defaults to True (output artifact).
+                If set to False then represented as input association.
+            no_skill (int): The precision threshold under which the classifier cannot discriminate
+                between the classes and would predict a random class or a constant class in
+                all cases (default: None).
+        """
+
+        verify_length_of_true_and_predicted(
+            true_labels=y_true,
+            predicted_attrs=predicted_probabilities,
+            predicted_attrs_name="predicted probabilities",
+        )
+
+        get_module("sklearn")
+        from sklearn.metrics import precision_recall_curve, average_precision_score
+
+        kwargs = {}
+        if positive_label is not None:
+            kwargs["pos_label"] = positive_label
+
+        precision, recall, _ = precision_recall_curve(y_true, predicted_probabilities, **kwargs)
+
+        kwargs["average"] = "micro"
+        ap = average_precision_score(y_true, predicted_probabilities, **kwargs)
+
+        data = {
+            "type": "PrecisionRecallCurve",
+            "version": 0,
+            "title": title,
+            "precision": precision.tolist(),
+            "recall": recall.tolist(),
+            "averagePrecisionScore": ap,
+            "noSkill": no_skill,
+        }
+        self._log_graph_artifact(
+            artifact_name=title, data=data, graph_type="PrecisionRecallCurve", is_output=is_output
+        )
+
+    @validate_invoked_inside_run_context
+    def log_roc_curve(
+        self,
+        y_true: Union[list, array],
+        y_score: Union[list, array],
+        title: Optional[str] = None,
+        is_output: bool = True,
+    ):
+        """Create and log a receiver operating characteristic (ROC curve) artifact.
+
+        The artifact is stored in S3 and represented as a lineage artifact
+        with an association with the run.
+
+        You can view the artifact in the UI.
+        If your job is created by a pipeline execution you can view the artifact
+        by selecting the corresponding step in the pipelines UI.
+        See also `SageMaker Pipelines `_
+
+        This method requires sklearn library.
+
+        Args:
+            y_true (list or array): True labels. If labels are not binary
+                then positive_label should be given.
+            y_score (list or array): Estimated/predicted probabilities.
+            title (str): Title of the graph (default: None).
+            is_output (bool): Determines direction of association to the
+                run. Defaults to True (output artifact).
+                If set to False then represented as input association.
+        """
+        verify_length_of_true_and_predicted(
+            true_labels=y_true, predicted_attrs=y_score, predicted_attrs_name="predicted scores"
+        )
+
+        get_module("sklearn")
+        from sklearn.metrics import roc_curve, auc
+
+        fpr, tpr, _ = roc_curve(y_true, y_score)
+
+        # avoid shadowing the imported `auc` function with its result
+        area_under_curve = auc(fpr, tpr)
+
+        data = {
+            "type": "ROCCurve",
+            "version": 0,
+            "title": title,
+            "falsePositiveRate": fpr.tolist(),
+            "truePositiveRate": tpr.tolist(),
+            "areaUnderCurve": area_under_curve,
+        }
+        self._log_graph_artifact(
+            artifact_name=title, data=data, graph_type="ROCCurve", is_output=is_output
+        )
+
+    @validate_invoked_inside_run_context
+    def log_confusion_matrix(
+        self,
+        y_true: Union[list, array],
+        y_pred: Union[list, array],
+        title: Optional[str] = None,
+        is_output: bool = True,
+    ):
+        """Create and log a confusion matrix artifact.
+
+        The artifact is stored in S3 and represented as a lineage artifact
+        with an association with the run.
+
+        You can view the artifact in the UI.
+ If your job is created by a pipeline execution you can view the + artifact by selecting the corresponding step in the pipelines UI. + See also `SageMaker Pipelines `_ + This method requires sklearn library. + + Args: + y_true (list or array): True labels. If labels are not binary + then positive_label should be given. + y_pred (list or array): Predicted labels. + title (str): Title of the graph (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + """ + verify_length_of_true_and_predicted( + true_labels=y_true, predicted_attrs=y_pred, predicted_attrs_name="predicted labels" + ) + + get_module("sklearn") + from sklearn.metrics import confusion_matrix + + matrix = confusion_matrix(y_true, y_pred) + + data = { + "type": "ConfusionMatrix", + "version": 0, + "title": title, + "confusionMatrix": matrix.tolist(), + } + self._log_graph_artifact( + artifact_name=title, data=data, graph_type="ConfusionMatrix", is_output=is_output + ) + + @validate_invoked_inside_run_context + def log_artifact( + self, name: str, value: str, media_type: Optional[str] = None, is_output: bool = True + ): + """Record a single artifact for this run. + + Overwrites any previous value recorded for the specified name. + + Args: + name (str): The name of the artifact. + value (str): The value. + media_type (str): The MediaType (MIME type) of the value (default: None). + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. + """ + self._verify_trial_component_artifacts_length(is_output=is_output) + if is_output: + self._trial_component.output_artifacts[name] = TrialComponentArtifact( + value, media_type=media_type + ) + else: + self._trial_component.input_artifacts[name] = TrialComponentArtifact( + value, media_type=media_type + ) + + @validate_invoked_inside_run_context + def log_file( + self, + file_path: str, + name: Optional[str] = None, + media_type: Optional[str] = None, + is_output: bool = True, + ): + """Upload a file to s3 and store it as an input/output artifact in this run. + + Args: + file_path (str): The path of the local file to upload. + name (str): The name of the artifact (default: None). + media_type (str): The MediaType (MIME type) of the file. + If not specified, this library will attempt to infer the media type + from the file extension of ``file_path``. + is_output (bool): Determines direction of association to the + run. Defaults to True (output artifact). + If set to False then represented as input association. 
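Putting the logging APIs above together, a hypothetical usage sketch; the experiment/run names and S3 URI are made up, and actually running this requires AWS credentials, a default S3 bucket, and scikit-learn:

.. code:: python

    from sagemaker.experiments.run import Run

    with Run(experiment_name="my-exp", run_name="my-run") as run:
        run.log_confusion_matrix(
            y_true=[0, 1, 1, 0], y_pred=[0, 1, 0, 0], title="confusion-matrix"
        )
        run.log_artifact(
            name="training-data",
            value="s3://my-bucket/train.csv",  # made-up S3 URI
            media_type="text/csv",
            is_output=False,
        )
        run.log_file("report.json", name="report")  # uploads a local file to S3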
+        """
+        self._verify_trial_component_artifacts_length(is_output)
+        media_type = media_type or guess_media_type(file_path)
+        name = name or resolve_artifact_name(file_path)
+        s3_uri, _ = self._artifact_uploader.upload_artifact(file_path)
+        if is_output:
+            self._trial_component.output_artifacts[name] = TrialComponentArtifact(
+                value=s3_uri, media_type=media_type
+            )
+        else:
+            self._trial_component.input_artifacts[name] = TrialComponentArtifact(
+                value=s3_uri, media_type=media_type
+            )
+
+    def close(self):
+        """Persist any data saved locally."""
+        try:
+            # Update the trial component with additions from the Run object
+            self._trial_component.save()
+            # Create Lineage entities for the artifacts
+            self._lineage_artifact_tracker.save()
+        finally:
+            if self._metrics_manager:
+                self._metrics_manager.close()
+
+    @staticmethod
+    def _generate_trial_name(base_name) -> str:
+        """Generate the reserved trial name based on the experiment name.
+
+        Args:
+            base_name (str): The ``experiment_name`` of this ``Run`` object.
+        """
+        available_length = MAX_NAME_LEN_IN_BACKEND - len(TRIAL_NAME_TEMPLATE)
+        return TRIAL_NAME_TEMPLATE.format(base_name[:available_length])
+
+    @staticmethod
+    def _is_input_valid(input_type, field_name, field_value) -> bool:
+        """Check if the input is valid or not.
+
+        Args:
+            input_type (str): The type of the input, one of ``parameter``, ``metric``.
+            field_name (str): The name of the field to be checked.
+            field_value (str or int or float): The value of the field to be checked.
+        """
+        if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)):
+            logger.warning(
+                "Failed to log %s %s. Received invalid value: %s.",
+                input_type,
+                field_name,
+                field_value,
+            )
+            return False
+        return True
+
+    def _log_graph_artifact(self, data, graph_type, is_output, artifact_name=None):
+        """Log an artifact.
+
+        Logs an artifact by uploading data to S3, creating an artifact, and associating that
+        artifact with the run trial component.
+
+        Args:
+            data (dict): Artifacts data that will be saved to S3.
+            graph_type (str): The type of the artifact.
+            is_output (bool): Determines direction of association to the
+                trial component. Defaults to True (output artifact).
+                If set to False then represented as input association.
+            artifact_name (str): Name of the artifact (default: None).
+        """
+        # generate an artifact name
+        if not artifact_name:
+            artifact_name = unique_name_from_base(graph_type)
+
+        # create a json file in S3
+        s3_uri, etag = self._artifact_uploader.upload_object_artifact(
+            artifact_name, data, file_extension="json"
+        )
+
+        # create an artifact and association for the table
+        if is_output:
+            self._lineage_artifact_tracker.add_output_artifact(
+                name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type
+            )
+        else:
+            self._lineage_artifact_tracker.add_input_artifact(
+                name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type
+            )
+
+    def _verify_trial_component_artifacts_length(self, is_output):
+        """Verify the length of trial component artifacts.
+
+        Args:
+            is_output (bool): Determines direction of association to the
+                trial component.
+
+        Raises:
+            ValueError: If the length of trial component artifacts exceeds the limit.
+ """ + err_msg_template = "Cannot add more than {} {}_artifacts under run" + if is_output: + if len(self._trial_component.output_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN: + raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "output")) + else: + if len(self._trial_component.input_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN: + raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "input")) + + @staticmethod + def _generate_trial_component_name(run_name: str, experiment_name: str) -> str: + """Generate the TrialComponentName based on run_name and experiment_name + + Args: + run_name (str): The run_name supplied by the user. + experiment_name (str): The experiment_name supplied by the user, + which is prepended to the run_name to generate the TrialComponentName. + + Returns: + str: The TrialComponentName used to create a trial component + which is unique in an account. + + Raises: + ValueError: If either the run_name or the experiment_name exceeds + the length limit. + """ + buffer = 1 # leave length buffers for delimiters + max_len = int(MAX_NAME_LEN_IN_BACKEND / 2) - buffer + err_msg_template = "The {} (length: {}) must have length less than or equal to {}" + if len(run_name) > max_len: + raise ValueError(err_msg_template.format("run_name", len(run_name), max_len)) + if len(experiment_name) > max_len: + raise ValueError( + err_msg_template.format("experiment_name", len(experiment_name), max_len) + ) + trial_component_name = "{}{}{}".format(experiment_name, DELIMITER, run_name) + # due to mixed-case concerns on the backend + trial_component_name = trial_component_name.lower() + return trial_component_name + + @staticmethod + def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: str) -> str: + """Extract the user supplied run name from a trial component name. + + Args: + trial_component_name (str): The name of a run trial component. + experiment_name (str): The experiment_name supplied by the user, + which was prepended to the run_name to generate the trial_component_name. + + Returns: + str: The name of the Run object supplied by a user. + """ + return trial_component_name.replace("{}{}".format(experiment_name, DELIMITER), "", 1) + + @staticmethod + def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list: + """Append the run trial component label to tags used to create a trial component. + + Args: + tags (List[Dict[str, str]]): The tags supplied by users to initialize a Run object. + + Returns: + list: The updated tags with the appended run trial component label. + """ + if not tags: + tags = [] + tags.append(RUN_TC_TAG) + return tags + + def __enter__(self): + """Updates the start time of the run. + + Returns: + object: self. + """ + nested_with_err_msg_template = ( + "It is not allowed to use nested 'with' statements on the {}." 
+ ) + if self._in_load: + if self._inside_load_context: + raise RuntimeError(nested_with_err_msg_template.format("load_run")) + self._inside_load_context = True + else: + if _RunContext.get_current_run(): + raise RuntimeError(nested_with_err_msg_template.format("Run")) + self._inside_init_context = True + _RunContext.add_run_object(self) + + if not self._trial_component.start_time: + start_time = datetime.datetime.now(dateutil.tz.tzlocal()) + self._trial_component.start_time = start_time + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, + message="Within a run context", + ) + # Save the start_time and status changes to backend + self._trial_component.save() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Updates the end time of the run. + + Args: + exc_type (str): The exception type. + exc_value (str): The exception value. + exc_traceback (str): The stack trace of the exception. + """ + if self._in_load: + self._inside_load_context = False + self._in_load = False + else: + self._inside_init_context = False + _RunContext.drop_current_run() + + end_time = datetime.datetime.now(dateutil.tz.tzlocal()) + self._trial_component.end_time = end_time + if exc_value: + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.Failed.value, message=str(exc_value) + ) + else: + self._trial_component.status = _api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.Completed.value + ) + + self.close() + + +def load_run( + run_name: Optional[str] = None, + experiment_name: Optional[str] = None, + sagemaker_session: Optional["Session"] = None, +) -> Run: + """Load an existing run. + + In order to reuse an existing run to log extra data, ``load_run`` is recommended. + It can be used in several ways: + + 1. Use ``load_run`` by explicitly passing in ``run_name`` and ``experiment_name``. + + If ``run_name`` and ``experiment_name`` are passed in, they are honored over + the default experiment config in the job environment or the run context + (i.e. within the ``with`` block). + + Note: + Both ``run_name`` and ``experiment_name`` should be supplied to make this usage work. + Otherwise, you may get a ``ValueError``. + + .. code:: python + + with load_run(experiment_name="my-exp", run_name="my-run") as run: + run.log_metric(...) + ... + + 2. Use the ``load_run`` in a job script without supplying ``run_name`` and ``experiment_name``. + + In this case, the default experiment config (specified when creating the job) is fetched + from the job environment to load the run. + + .. code:: python + + # In a job script + with load_run() as run: + run.log_metric(...) + ... + + 3. Use the ``load_run`` in a notebook within a run context (i.e. the ``with`` block) + but without supplying ``run_name`` and ``experiment_name``. + + Every time we call ``with Run(...) as run1:``, the initialized ``run1`` is tracked + in the run context. Then when we call ``load_run()`` under this with statement, the ``run1`` + in the context is loaded by default. + + .. code:: python + + # In a notebook + with Run(experiment_name="my-exp", run_name="my-run", ...) as run1: + run1.log_parameter(...) + + with load_run() as run2: # run2 is the same object as run1 + run2.log_metric(...) + ... + + Args: + run_name (str): The name of the run to be loaded (default: None). + If it is None, the ``RunName`` in the ``ExperimentConfig`` of the job will be + fetched to load the run. 
+ experiment_name (str): The name of the Experiment that the to be loaded run + is associated with (default: None). + Note: the experiment_name must be supplied along with a valid run_name. + Otherwise, it will be ignored. + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + Run: The loaded Run object. + """ + sagemaker_session = sagemaker_session or _utils.default_session() + environment = _RunEnvironment.load() + + verify_load_input_names(run_name=run_name, experiment_name=experiment_name) + + if run_name or environment: + if run_name: + logger.warning( + "run_name is explicitly supplied in load_run, " + "which will be prioritized to load the Run object. " + "In other words, the run name in the experiment config, fetched from the " + "job environment or the current run context, will be ignored." + ) + else: + exp_config = get_tc_and_exp_config_from_job_env( + environment=environment, sagemaker_session=sagemaker_session + ) + run_name = Run._extract_run_name_from_tc_name( + trial_component_name=exp_config[RUN_NAME], + experiment_name=exp_config[EXPERIMENT_NAME], + ) + experiment_name = exp_config[EXPERIMENT_NAME] + + run_instance = Run( + experiment_name=experiment_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ) + elif _RunContext.get_current_run(): + run_instance = _RunContext.get_current_run() + else: + raise RuntimeError( + "Failed to load a Run object. " + "Please make sure a Run object has been initialized already." + ) + + run_instance._in_load = True + return run_instance + + +def list_runs( + experiment_name: str, + created_before: Optional[datetime.datetime] = None, + created_after: Optional[datetime.datetime] = None, + sagemaker_session: Optional["Session"] = None, + max_results: Optional[int] = None, + next_token: Optional[str] = None, + sort_by: SortByType = SortByType.CREATION_TIME, + sort_order: SortOrderType = SortOrderType.DESCENDING, +) -> list: + """Return a list of ``Run`` objects matching the given criteria. + + Args: + experiment_name (str): Only Run objects related to the specified experiment + are returned. + created_before (datetime.datetime): Return Run objects created before this instant + (default: None). + created_after (datetime.datetime): Return Run objects created after this instant + (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + max_results (int): Maximum number of Run objects to retrieve (default: None). + next_token (str): Token for next page of results (default: None). + sort_by (SortByType): The property to sort results by. One of NAME, CREATION_TIME + (default: CREATION_TIME). + sort_order (SortOrderType): One of ASCENDING, or DESCENDING (default: DESCENDING). + + Returns: + list: A list of ``Run`` objects. 
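+
+    Example (a minimal sketch; the experiment name is illustrative):
+
+    .. code:: python
+
+        from sagemaker.experiments import list_runs
+
+        for run in list_runs(experiment_name="my-exp"):
+            print(run.run_name)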
+ """ + tc_summaries = _TrialComponent.list( + experiment_name=experiment_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by.value, + sort_order=sort_order.value, + sagemaker_session=sagemaker_session, + max_results=max_results, + next_token=next_token, + ) + run_list = [] + for tc_summary in tc_summaries: + if not is_run_trial_component( + trial_component_name=tc_summary.trial_component_name, + sagemaker_session=sagemaker_session, + ): + continue + run_instance = Run( + experiment_name=experiment_name, + run_name=Run._extract_run_name_from_tc_name( + trial_component_name=tc_summary.trial_component_name, + experiment_name=experiment_name, + ), + sagemaker_session=sagemaker_session, + ) + run_list.append(run_instance) + return run_list diff --git a/src/sagemaker/experiments/trial.py b/src/sagemaker/experiments/trial.py new file mode 100644 index 0000000000..146b24f18b --- /dev/null +++ b/src/sagemaker/experiments/trial.py @@ -0,0 +1,289 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the Trial class.""" +from __future__ import absolute_import + +from sagemaker.apiutils import _base_types +from sagemaker.experiments import _api_types +from sagemaker.experiments.trial_component import _TrialComponent + + +class _Trial(_base_types.Record): + """An execution of a data-science workflow with an experiment. + + Consists of a list of trial component objects, which document individual + activities within the workflow. + + Attributes: + trial_name (str): The name of the trial. + experiment_name (str): The name of the trial's experiment. + display_name (str): The name of the trial that will appear in UI, + such as SageMaker Studio. + tags (List[Dict[str, str]]): A list of tags to associate with the trial. + """ + + trial_name = None + experiment_name = None + display_name = None + tags = None + + _boto_create_method = "create_trial" + _boto_load_method = "describe_trial" + _boto_delete_method = "delete_trial" + _boto_update_method = "update_trial" + + _boto_update_members = ["trial_name", "display_name"] + _boto_delete_members = ["trial_name"] + + @classmethod + def _boto_ignore(cls): + """Response fields to ignore by default.""" + return super(_Trial, cls)._boto_ignore() + ["CreatedBy"] + + def save(self): + """Save the state of this Trial to SageMaker. + + Returns: + dict: Update trial response. + """ + return self._invoke_api(self._boto_update_method, self._boto_update_members) + + def delete(self): + """Delete this Trial from SageMaker. + + Does not delete associated Trial Components. + + Returns: + dict: Delete trial response. + """ + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, trial_name, sagemaker_session=None): + """Load an existing trial and return a `_Trial` object. + + Args: + trial_name: (str): Name of the Trial. 
+ sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + return super(_Trial, cls)._construct( + cls._boto_load_method, + trial_name=trial_name, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def create( + cls, experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None + ): + """Create a new trial and return a `_Trial` object. + + Args: + experiment_name: (str): Name of the experiment to create this trial in. + trial_name: (str): Name of the Trial. + display_name (str): Name of the trial that will appear in UI, + such as SageMaker Studio (default: None). + tags (List[dict]): A list of tags to associate with the trial (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + trial = super(_Trial, cls)._construct( + cls._boto_create_method, + trial_name=trial_name, + experiment_name=experiment_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return trial + + @classmethod + def list( + cls, + experiment_name=None, + trial_component_name=None, + created_before=None, + created_after=None, + sort_by=None, + sort_order=None, + sagemaker_session=None, + ): + """List all trials matching the specified criteria. + + Args: + experiment_name (str): Name of the experiment. If specified, only trials in + the experiment will be returned (default: None). + trial_component_name (str): Name of the trial component. If specified, only + trials with this trial component name will be returned (default: None). + created_before (datetime.datetime): Return trials created before this instant + (default: None). + created_after (datetime.datetime): Return trials created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + Returns: + collections.Iterator[experiments._api_types.TrialSummary]: An iterator over trials + matching the specified criteria. + """ + return super(_Trial, cls)._list( + "list_trials", + _api_types.TrialSummary.from_boto, + "TrialSummaries", + experiment_name=experiment_name, + trial_component_name=trial_component_name, + created_before=created_before, + created_after=created_after, + sort_by=sort_by, + sort_order=sort_order, + sagemaker_session=sagemaker_session, + ) + + def add_trial_component(self, trial_component): + """Add the specified trial component to this trial. + + A trial component may belong to many trials and a trial may have many trial components. + + Args: + trial_component (str or _TrialComponent): The trial component to add. + Can be one of a _TrialComponent instance, or a string containing + the name of the trial component to add. 
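+
+        Example (a minimal sketch; assumes an existing ``sagemaker_session`` and
+        illustrative names):
+
+        .. code:: python
+
+            trial = _Trial.load("my-trial", sagemaker_session=sagemaker_session)
+            trial.add_trial_component("my-trial-component")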
+        """
+        if isinstance(trial_component, _TrialComponent):
+            trial_component_name = trial_component.trial_component_name
+        elif isinstance(trial_component, str):
+            trial_component_name = trial_component
+        else:
+            raise TypeError(
+                "Unsupported type of trial component {}. "
+                "It has to be either a _TrialComponent or a str".format(trial_component)
+            )
+        self.sagemaker_session.sagemaker_client.associate_trial_component(
+            TrialName=self.trial_name, TrialComponentName=trial_component_name
+        )
+
+    def remove_trial_component(self, trial_component):
+        """Remove the specified trial component from this trial.
+
+        Args:
+            trial_component (str or _TrialComponent): The trial component to remove.
+                Can be one of a _TrialComponent instance, or a string containing
+                the name of the trial component to remove.
+        """
+        if isinstance(trial_component, _TrialComponent):
+            trial_component_name = trial_component.trial_component_name
+        elif isinstance(trial_component, str):
+            trial_component_name = trial_component
+        else:
+            raise TypeError(
+                "Unsupported type of trial component {}. "
+                "It has to be either a _TrialComponent or a str".format(trial_component)
+            )
+        self.sagemaker_session.sagemaker_client.disassociate_trial_component(
+            TrialName=self.trial_name, TrialComponentName=trial_component_name
+        )
+
+    def list_trial_components(
+        self,
+        created_before=None,
+        created_after=None,
+        sort_by=None,
+        sort_order=None,
+        max_results=None,
+        next_token=None,
+    ):
+        """List trial components in this trial matching the specified criteria.
+
+        Args:
+            created_before (datetime.datetime): Return trial components created before this
+                instant (default: None).
+            created_after (datetime.datetime): Return trial components created after this
+                instant (default: None).
+            sort_by (str): Which property to sort results by. One of 'Name',
+                'CreationTime' (default: None).
+            sort_order (str): One of 'Ascending', or 'Descending' (default: None).
+            max_results (int): maximum number of trial components to retrieve (default: None).
+            next_token (str): token for next page of results (default: None).
+
+        Returns:
+            collections.Iterator[experiments._api_types.TrialComponentSummary]: An iterator over
+                trial components matching the criteria.
+        """
+        return _TrialComponent.list(
+            trial_name=self.trial_name,
+            created_before=created_before,
+            created_after=created_after,
+            sort_by=sort_by,
+            sort_order=sort_order,
+            max_results=max_results,
+            next_token=next_token,
+            sagemaker_session=self.sagemaker_session,
+        )
+
+    @classmethod
+    def _load_or_create(
+        cls, experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None
+    ):
+        """Load a trial by name and create a new one if it does not exist.
+
+        Args:
+            experiment_name: (str): Name of the experiment to create this trial in.
+            trial_name: (str): Name of the Trial.
+            display_name (str): Name of the trial that will appear in UI,
+                such as SageMaker Studio (default: None). This is used only when the given
+                `trial_name` does not exist and a new trial has to be created.
+            tags (List[dict]): A list of tags to associate with the trial (default: None).
+                This is used only when the given `trial_name` does not exist and
+                a new trial has to be created.
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+ + Returns: + experiments.trial._Trial: A SageMaker `_Trial` object + """ + sagemaker_client = sagemaker_session.sagemaker_client + try: + trial = _Trial.load(trial_name, sagemaker_session) + if trial.experiment_name != experiment_name: # pylint: disable=no-member + raise ValueError( + "The given experiment_name {} ".format(experiment_name) + + "does not match that in the loaded trial {}".format( + trial.experiment_name # pylint: disable=no-member + ) + ) + except sagemaker_client.exceptions.ResourceNotFound: + trial = _Trial.create( + experiment_name=experiment_name, + trial_name=trial_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return trial diff --git a/src/sagemaker/experiments/trial_component.py b/src/sagemaker/experiments/trial_component.py new file mode 100644 index 0000000000..e5701b2119 --- /dev/null +++ b/src/sagemaker/experiments/trial_component.py @@ -0,0 +1,341 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Contains the TrialComponent class.""" +from __future__ import absolute_import + +import time + +from sagemaker.apiutils import _base_types +from sagemaker.experiments import _api_types +from sagemaker.experiments._api_types import TrialComponentSearchResult + + +class _TrialComponent(_base_types.Record): + """This class represents a SageMaker trial component object. + + A trial component is a stage in a trial. + Trial components are created automatically within the SageMaker runtime and + may not be created directly. To automatically associate trial components with + a trial and experiment, supply an experiment config when creating a job. + For example: https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html + + Attributes: + trial_component_name (str): The name of the trial component. Generated by SageMaker + from the name of the source job with a suffix specific to the type of source job. + trial_component_arn (str): The ARN of the trial component. + display_name (str): The name of the trial component that will appear in UI, + such as SageMaker Studio. + source (TrialComponentSource): A TrialComponentSource object with a source_arn attribute. + status (str): Status of the source job. + start_time (datetime): When the source job started. + end_time (datetime): When the source job ended. + creation_time (datetime): When the source job was created. + created_by (obj): Contextual info on which account created the trial component. + last_modified_time (datetime): When the trial component was last modified. + last_modified_by (obj): Contextual info on which account last modified the trial component. + parameters (dict): Dictionary of parameters to the source job. + input_artifacts (dict): Dictionary of input artifacts. + output_artifacts (dict): Dictionary of output artifacts. + metrics (obj): Aggregated metrics for the job. + parameters_to_remove (list): The hyperparameters to remove from the component. + input_artifacts_to_remove (list): The input artifacts to remove from the component. 
+        output_artifacts_to_remove (list): The output artifacts to remove from the component.
+        tags (List[Dict[str, str]]): A list of tags to associate with the trial component.
+    """
+
+    trial_component_name = None
+    trial_component_arn = None
+    display_name = None
+    source = None
+    status = None
+    start_time = None
+    end_time = None
+    creation_time = None
+    created_by = None
+    last_modified_time = None
+    last_modified_by = None
+    parameters = None
+    input_artifacts = None
+    output_artifacts = None
+    metrics = None
+    parameters_to_remove = None
+    input_artifacts_to_remove = None
+    output_artifacts_to_remove = None
+    tags = None
+
+    _boto_load_method = "describe_trial_component"
+    _boto_create_method = "create_trial_component"
+    _boto_update_method = "update_trial_component"
+    _boto_delete_method = "delete_trial_component"
+
+    _custom_boto_types = {
+        "source": (_api_types.TrialComponentSource, False),
+        "status": (_api_types.TrialComponentStatus, False),
+        "parameters": (_api_types.TrialComponentParameters, False),
+        "input_artifacts": (_api_types.TrialComponentArtifact, True),
+        "output_artifacts": (_api_types.TrialComponentArtifact, True),
+        "metrics": (_api_types.TrialComponentMetricSummary, True),
+    }
+
+    _boto_update_members = [
+        "trial_component_name",
+        "display_name",
+        "status",
+        "start_time",
+        "end_time",
+        "parameters",
+        "input_artifacts",
+        "output_artifacts",
+        "parameters_to_remove",
+        "input_artifacts_to_remove",
+        "output_artifacts_to_remove",
+    ]
+    _boto_delete_members = ["trial_component_name"]
+
+    def __init__(self, sagemaker_session=None, **kwargs):
+        """Init for _TrialComponent"""
+        super().__init__(sagemaker_session, **kwargs)
+        self.parameters = self.parameters or {}
+        self.input_artifacts = self.input_artifacts or {}
+        self.output_artifacts = self.output_artifacts or {}
+
+    @classmethod
+    def _boto_ignore(cls):
+        """Response fields to ignore by default."""
+        return super(_TrialComponent, cls)._boto_ignore() + ["CreatedBy"]
+
+    def save(self):
+        """Save the state of this TrialComponent to SageMaker."""
+        return self._invoke_api(self._boto_update_method, self._boto_update_members)
+
+    def delete(self, force_disassociate=False):
+        """Delete this TrialComponent from SageMaker.
+
+        Args:
+            force_disassociate (boolean): Indicates whether to force disassociate the
+                trial component from the trials before deletion (default: False).
+                If set to true, force disassociate the trial component from associated trials
+                first, then delete the trial component.
+                If it's not set or set to false, it will delete the trial component directly
+                without disassociation.
+
+        Returns:
+            dict: Delete trial component response.
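+
+        Example (a minimal sketch; the trial component name is illustrative):
+
+        .. code:: python
+
+            tc = _TrialComponent.load("my-trial-component")
+            tc.delete(force_disassociate=True)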
+ """ + if force_disassociate: + next_token = None + + while True: + if next_token: + list_trials_response = self.sagemaker_session.sagemaker_client.list_trials( + TrialComponentName=self.trial_component_name, NextToken=next_token + ) + else: + list_trials_response = self.sagemaker_session.sagemaker_client.list_trials( + TrialComponentName=self.trial_component_name + ) + + # Disassociate the trials and trial components + for per_trial in list_trials_response["TrialSummaries"]: + # to prevent DisassociateTrialComponent throttling + time.sleep(1.2) + self.sagemaker_session.sagemaker_client.disassociate_trial_component( + TrialName=per_trial["TrialName"], + TrialComponentName=self.trial_component_name, + ) + + if "NextToken" in list_trials_response: + next_token = list_trials_response["NextToken"] + else: + break + + return self._invoke_api(self._boto_delete_method, self._boto_delete_members) + + @classmethod + def load(cls, trial_component_name, sagemaker_session=None): + """Load an existing trial component and return an `_TrialComponent` object representing it. + + Args: + trial_component_name (str): Name of the trial component + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object + """ + trial_component = cls._construct( + cls._boto_load_method, + trial_component_name=trial_component_name, + sagemaker_session=sagemaker_session, + ) + return trial_component + + @classmethod + def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None): + """Create a trial component and return a `_TrialComponent` object representing it. + + Args: + trial_component_name (str): The name of the trial component. + display_name (str): Display name of the trial component used by Studio (default: None). + tags (List[Dict[str, str]]): Tags to add to the trial component (default: None). + sagemaker_session (sagemaker.session.Session): Session object which + manages interactions with Amazon SageMaker APIs and any other + AWS services needed. If not specified, one is created using the + default AWS configuration chain. + + Returns: + experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object. + """ + return super(_TrialComponent, cls)._construct( + cls._boto_create_method, + trial_component_name=trial_component_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + + @classmethod + def list( + cls, + source_arn=None, + created_before=None, + created_after=None, + sort_by=None, + sort_order=None, + sagemaker_session=None, + trial_name=None, + experiment_name=None, + max_results=None, + next_token=None, + ): + """Return a list of trial component summaries. + + Args: + source_arn (str): A SageMaker Training or Processing Job ARN (default: None). + created_before (datetime.datetime): Return trial components created before this instant + (default: None). + created_after (datetime.datetime): Return trial components created after this instant + (default: None). + sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime' + (default: None). + sort_order (str): One of 'Ascending', or 'Descending' (default: None). 
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+            trial_name (str): If provided only trial components related to the trial are returned
+                (default: None).
+            experiment_name (str): If provided only trial components related to the experiment are
+                returned (default: None).
+            max_results (int): maximum number of trial components to retrieve (default: None).
+            next_token (str): token for next page of results (default: None).
+
+        Returns:
+            collections.Iterator[experiments._api_types.TrialComponentSummary]: An iterator
+                over `TrialComponentSummary` objects.
+        """
+        return super(_TrialComponent, cls)._list(
+            "list_trial_components",
+            _api_types.TrialComponentSummary.from_boto,
+            "TrialComponentSummaries",
+            source_arn=source_arn,
+            created_before=created_before,
+            created_after=created_after,
+            sort_by=sort_by,
+            sort_order=sort_order,
+            sagemaker_session=sagemaker_session,
+            trial_name=trial_name,
+            experiment_name=experiment_name,
+            max_results=max_results,
+            next_token=next_token,
+        )
+
+    @classmethod
+    def search(
+        cls,
+        search_expression=None,
+        sort_by=None,
+        sort_order=None,
+        max_results=None,
+        sagemaker_session=None,
+    ):
+        """Search Experiment Trial Components.
+
+        Returns SearchResults in the account matching the search criteria.
+
+        Args:
+            search_expression (SearchExpression): A Boolean conditional statement (default: None).
+                Resource objects must satisfy this condition to be included in search results.
+                You must provide at least one subexpression, filter, or nested filter.
+            sort_by (str): The name of the resource property used to sort the SearchResults
+                (default: None).
+            sort_order (str): How SearchResults are ordered. Valid values are Ascending or
+                Descending (default: None).
+            max_results (int): The maximum number of results to return in a SearchResponse
+                (default: None).
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+
+        Returns:
+            collections.Iterator[SearchResult]: An iterator over search results matching the
+                search criteria.
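+
+        Example (a minimal sketch; the filter field and value are illustrative):
+
+        .. code:: python
+
+            from sagemaker.utilities.search_expression import (
+                Filter,
+                Operator,
+                SearchExpression,
+            )
+
+            search_exp = SearchExpression(
+                filters=[Filter("TrialComponentName", Operator.CONTAINS, "my-run")]
+            )
+            for result in _TrialComponent.search(search_expression=search_exp):
+                print(result.trial_component_name)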
+        """
+        return super(_TrialComponent, cls)._search(
+            search_resource="ExperimentTrialComponent",
+            search_item_factory=TrialComponentSearchResult.from_boto,
+            search_expression=None if search_expression is None else search_expression.to_boto(),
+            sort_by=sort_by,
+            sort_order=sort_order,
+            max_results=max_results,
+            sagemaker_session=sagemaker_session,
+        )
+
+    @classmethod
+    def _load_or_create(
+        cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None
+    ):
+        """Load a trial component by name and create a new one if it does not exist.
+
+        Args:
+            trial_component_name (str): The name of the trial component.
+            display_name (str): Display name of the trial component used by Studio (default: None).
+                This is used only when the given `trial_component_name` does not
+                exist and a new trial component has to be created.
+            tags (List[Dict[str, str]]): Tags to add to the trial component (default: None).
+                This is used only when the given `trial_component_name` does not
+                exist and a new trial component has to be created.
+            sagemaker_session (sagemaker.session.Session): Session object which
+                manages interactions with Amazon SageMaker APIs and any other
+                AWS services needed. If not specified, one is created using the
+                default AWS configuration chain.
+
+        Returns:
+            experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
+            bool: A boolean variable indicating whether the trial component already exists.
+        """
+        sagemaker_client = sagemaker_session.sagemaker_client
+        is_existed = False
+        try:
+            run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
+            is_existed = True
+        except sagemaker_client.exceptions.ResourceNotFound:
+            run_tc = _TrialComponent.create(
+                trial_component_name=trial_component_name,
+                display_name=display_name,
+                tags=tags,
+                sagemaker_session=sagemaker_session,
+            )
+        return run_tc, is_existed
diff --git a/src/sagemaker/lineage/_utils.py b/src/sagemaker/lineage/_utils.py
index 28732b0174..7c833a468e 100644
--- a/src/sagemaker/lineage/_utils.py
+++ b/src/sagemaker/lineage/_utils.py
@@ -12,7 +12,6 @@
 # language governing permissions and limitations under the License.
 """SageMaker lineage utility methods."""
 from __future__ import absolute_import
-from importlib import import_module
 
 from sagemaker.lineage import association
 
@@ -38,22 +37,6 @@ def _disassociate(source_arn=None, destination_arn=None, sagemaker_session=None)
     curr_association.delete()
 
 
-def get_module(module_name):
-    """Import a module.
-
-    Args:
-        module_name (str): name of the module to import.
-
-    Returns:
-        [obj]: The imported module.
-        Raises exceptions when the module name is not found
-    """
-    try:
-        return import_module(module_name)
-    except ImportError:
-        raise Exception("Cannot import module {}, please try again.".format(module_name))
-
-
 def get_resource_name_from_arn(arn):
     """Extract the resource name from an ARN string.
diff --git a/src/sagemaker/lineage/artifact.py b/src/sagemaker/lineage/artifact.py index 3921562beb..718344095a 100644 --- a/src/sagemaker/lineage/artifact.py +++ b/src/sagemaker/lineage/artifact.py @@ -29,8 +29,9 @@ LineageEntityEnum, LineageQueryDirectionEnum, ) -from sagemaker.lineage._utils import get_module, _disassociate, get_resource_name_from_arn +from sagemaker.lineage._utils import _disassociate, get_resource_name_from_arn from sagemaker.lineage.association import Association +from sagemaker.utils import get_module LOGGER = logging.getLogger("sagemaker") diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 01d4361197..af52da6288 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -33,7 +33,12 @@ from sagemaker.job import _Job from sagemaker.local import LocalSession from sagemaker.network import NetworkConfig -from sagemaker.utils import base_name_from_image, get_config_value, name_from_base +from sagemaker.utils import ( + base_name_from_image, + get_config_value, + name_from_base, + check_and_get_run_experiment_config, +) from sagemaker.session import Session from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.functions import Join @@ -203,6 +208,7 @@ def run( outputs=outputs, ) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_job = ProcessingJob.start_new( processor=self, inputs=normalized_inputs, @@ -605,6 +611,7 @@ def run( kms_key=kms_key, ) + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_job = ProcessingJob.start_new( processor=self, inputs=normalized_inputs, diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 72df570496..ce6a3b99cd 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -89,6 +89,7 @@ def __init__( sagemaker_featurestore_runtime_client=None, default_bucket=None, settings=SessionSettings(), + sagemaker_metrics_client=None, ): """Initialize a SageMaker ``Session``. @@ -116,6 +117,10 @@ def __init__( Example: "sagemaker-my-custom-bucket". settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional parameters to apply to the session. + sagemaker_metrics_client (boto3.SageMakerMetrics.Client): + Client which makes SageMaker Metrics related calls to Amazon SageMaker + (default: None). If not provided, one will be created using + this instance's ``boto_session``. """ self._default_bucket = None self._default_bucket_name_override = default_bucket @@ -130,6 +135,7 @@ def __init__( sagemaker_client=sagemaker_client, sagemaker_runtime_client=sagemaker_runtime_client, sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client, + sagemaker_metrics_client=sagemaker_metrics_client, ) def _initialize( @@ -138,6 +144,7 @@ def _initialize( sagemaker_client, sagemaker_runtime_client, sagemaker_featurestore_runtime_client, + sagemaker_metrics_client, ): """Initialize this SageMaker Session. @@ -172,6 +179,12 @@ def _initialize( "sagemaker-featurestore-runtime" ) + if sagemaker_metrics_client: + self.sagemaker_metrics_client = sagemaker_metrics_client + else: + self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics") + prepend_user_agent(self.sagemaker_metrics_client) + self.local_mode = False @property @@ -548,8 +561,8 @@ def train( # noqa: C901 checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). experiment_config (dict[str, str]): Experiment management configuration. 
- Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -558,6 +571,7 @@ def train( # noqa: C901 * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. enable_sagemaker_metrics (bool): enable SageMaker Metrics Time Series. For more information see: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries @@ -703,8 +717,8 @@ def _get_train_request( # noqa: C901 checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). experiment_config (dict[str, str]): Experiment management configuration. - Optionally, the dict can contain three keys: - 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'. + Optionally, the dict can contain four keys: + 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'. The behavior of setting these keys is as follows: * If `ExperimentName` is supplied but `TrialName` is not a Trial will be automatically created and the job's Trial Component associated with the Trial. @@ -713,6 +727,7 @@ def _get_train_request( # noqa: C901 * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + * `RunName` is used to record an experiment run. enable_sagemaker_metrics (bool): enable SageMaker Metrics Time Series. For more information see: https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 97278abdd0..40ed143ebc 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -27,7 +27,11 @@ from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.execution_variables import ExecutionVariables -from sagemaker.utils import base_name_from_image, name_from_base +from sagemaker.utils import ( + base_name_from_image, + name_from_base, + check_and_get_run_experiment_config, +) class Transformer(object): @@ -251,6 +255,7 @@ def transform( ) self._reset_output_path = True + experiment_config = check_and_get_run_experiment_config(experiment_config) self.latest_transform_job = _TransformJob.start_new( self, data, diff --git a/src/sagemaker/utilities/search_expression.py b/src/sagemaker/utilities/search_expression.py new file mode 100644 index 0000000000..5b2aaf3226 --- /dev/null +++ b/src/sagemaker/utilities/search_expression.py @@ -0,0 +1,133 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Simplify Search Expression by providing a simplified DSL"""
+from __future__ import absolute_import
+
+from enum import Enum, unique
+
+from sagemaker.apiutils._base_types import ApiObject
+
+
+# TODO: we should update the lineage to use search expressions
+# defined here in a separate change
+@unique
+class Operator(Enum):
+    """Search operators"""
+
+    EQUALS = "Equals"
+    NOT_EQUALS = "NotEquals"
+    GREATER_THAN = "GreaterThan"
+    GREATER_THAN_OR_EQUAL = "GreaterThanOrEqualTo"
+    LESS_THAN = "LessThan"
+    LESS_THAN_OR_EQUAL = "LessThanOrEqualTo"
+    CONTAINS = "Contains"
+    EXISTS = "Exists"
+    NOT_EXISTS = "NotExists"
+
+
+@unique
+class BooleanOperator(Enum):
+    """Boolean search operation enum"""
+
+    AND = "And"
+    OR = "Or"
+
+
+class SearchObject(ApiObject):
+    """Search Object"""
+
+    def to_boto(self):
+        """Convert a search object to boto"""
+        return ApiObject.to_boto(self)
+
+
+class Filter(SearchObject):
+    """A Python class representing a Search Filter object."""
+
+    name = None
+    operator = None
+    value = None
+
+    def __init__(self, name, operator=None, value=None, **kwargs):
+        """Construct a Filter object
+
+        Args:
+            name (str): filter field name
+            operator (Operator): one of Operator enum
+            value (str): value of the field
+        """
+        super().__init__(**kwargs)
+        self.name = name
+        self.operator = None if operator is None else operator.value
+        self.value = value
+
+
+class NestedFilter(SearchObject):
+    """A Python class representing a Nested Filter object."""
+
+    nested_property_name = None
+    filters = None
+
+    def __init__(self, property_name, filters, **kwargs):
+        """Construct a Nested Filter object
+
+        Args:
+            property_name (str): nested property name
+            filters (List[Filter]): list of Filter objects
+        """
+        super().__init__(**kwargs)
+        self.nested_property_name = property_name
+        self.filters = list(map(lambda x: x.to_boto(), filters))
+
+
+class SearchExpression(SearchObject):
+    """A Python class representation of a Search Expression object.
+
+    A sample search expression is defined here:
+    https://boto3.amazonaws.com/v1/documentation/api/1.12.8/reference/services/sagemaker.html#SageMaker.Client.search
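+
+    Example (a minimal sketch; the filter field and value are illustrative):
+
+    .. code:: python
+
+        search_exp = SearchExpression(
+            filters=[Filter(name="TrialComponentName", operator=Operator.CONTAINS, value="train")],
+            boolean_operator=BooleanOperator.AND,
+        )
+        boto_search_expression = search_exp.to_boto()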
+    """
+
+    filters = None
+    nested_filters = None
+    operator = None
+    sub_expressions = None
+
+    def __init__(
+        self,
+        filters=None,
+        nested_filters=None,
+        sub_expressions=None,
+        boolean_operator=BooleanOperator.AND,
+        **kwargs
+    ):
+        """Construct a Search Expression object
+
+        Args:
+            filters (List[Filter]): list of Filter objects
+            nested_filters (List[NestedFilter]): list of Nested Filters objects
+            sub_expressions (List[SearchExpression]): list of Search Expression objects
+            boolean_operator (BooleanOperator): one of the boolean operator enums
+        """
+        super().__init__(**kwargs)
+        if filters is None and nested_filters is None and sub_expressions is None:
+            raise ValueError(
+                "You must specify at least one subexpression, filter, or nested filter"
+            )
+        self.filters = None if filters is None else list(map(lambda x: x.to_boto(), filters))
+        self.nested_filters = (
+            None if nested_filters is None else list(map(lambda x: x.to_boto(), nested_filters))
+        )
+        self.sub_expressions = (
+            None if sub_expressions is None else list(map(lambda x: x.to_boto(), sub_expressions))
+        )
+        self.operator = boolean_operator.value
diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index e668b2a8ed..9d28e3bf4e 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -29,6 +29,7 @@
 from datetime import datetime
 
 from typing import Optional
+from importlib import import_module
 
 import botocore
 from six.moves.urllib import parse
@@ -590,6 +591,27 @@ def retries(
     )
 
 
+def retry_with_backoff(callable_func, num_attempts=8):
+    """Retry with backoff until maximum attempts are reached.
+
+    Args:
+        callable_func (callable): The callable function to retry.
+        num_attempts (int): The maximum number of attempts to retry.
+    """
+    if num_attempts < 1:
+        raise ValueError(
+            "The num_attempts must be >= 1, but the given value is {}.".format(num_attempts)
+        )
+    for i in range(num_attempts):
+        try:
+            return callable_func()
+        except Exception as ex:  # pylint: disable=broad-except
+            if i == num_attempts - 1:
+                raise ex
+            logger.error("Retrying in attempt %s, due to %s", (i + 1), str(ex))
+            time.sleep(2**i)
+
+
 def _botocore_resolver():
     """Get the DNS suffix for the given region.
 
@@ -874,3 +896,47 @@ def _start_waiting(waiting_time: int):
         print(progress, end="\r")
         time.sleep(interval)
     print(len(progress) * " ", end="\r")
+
+
+def get_module(module_name):
+    """Import a module.
+
+    Args:
+        module_name (str): name of the module to import.
+
+    Returns:
+        object: The imported module.
+
+    Raises:
+        Exception: when the module name is not found
+    """
+    try:
+        return import_module(module_name)
+    except ImportError:
+        raise Exception("Cannot import module {}, please try again.".format(module_name))
+
+
+def check_and_get_run_experiment_config(experiment_config: Optional[dict] = None) -> dict:
+    """Check user input experiment_config or get it from the current Run object if it exists.
+
+    Args:
+        experiment_config (dict): The experiment_config supplied by the user.
+
+    Returns:
+        dict: Return the user supplied experiment_config if it is not None.
+            Otherwise fetch the experiment_config from the current Run object if it exists.
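+
+    Example (a minimal sketch; assumes it is invoked within an active run context):
+
+    .. code:: python
+
+        from sagemaker.experiments import Run
+
+        with Run(experiment_name="my-exp", run_name="my-run"):
+            # Returns the experiment_config derived from the active Run object.
+            config = check_and_get_run_experiment_config(None)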
+ """ + from sagemaker.experiments._run_context import _RunContext + + run_obj = _RunContext.get_current_run() + if experiment_config: + if run_obj: + logger.warning( + "The function is invoked within an Experiment Run context " + "but another experiment_config (%s) was supplied, so " + "ignoring the experiment_config fetched from the Run object.", + experiment_config, + ) + return experiment_config + + return run_obj.experiment_config if run_obj else None diff --git a/tests/data/experiment/inference.py b/tests/data/experiment/inference.py new file mode 100644 index 0000000000..cdb9a7b8c6 --- /dev/null +++ b/tests/data/experiment/inference.py @@ -0,0 +1,85 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. +import logging +import os +import pickle as pkl + +import boto3 +import numpy as np +import sagemaker_xgboost_container.encoder as xgb_encoders + +sdk_name = "sagemaker-dev-1.0.tar.gz" +code_dir = "/opt/ml/code" + +sdk_file = f"{code_dir}/{sdk_name}" +os.system(f"pip install {sdk_file}") + +from sagemaker.session import Session +from sagemaker.experiments import load_run + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + + +def model_fn(model_dir): + """ + Deserialize and return fitted model. + """ + with load_run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameters({"p3": 3.0, "p4": 4.0}) + run.log_metric("test-job-load-log-metric", 0.1) + + model_file = "xgboost-model" + booster = pkl.load(open(os.path.join(model_dir, model_file), "rb")) + return booster + + +def input_fn(request_body, request_content_type): + """ + The SageMaker XGBoost model server receives the request data body and the content type, + and invokes the `input_fn`. + Return a DMatrix (an object that can be passed to predict_fn). + """ + if request_content_type == "text/libsvm": + return xgb_encoders.libsvm_to_dmatrix(request_body) + else: + raise ValueError("Content type {} is not supported.".format(request_content_type)) + + +def predict_fn(input_data, model): + """ + SageMaker XGBoost model server invokes `predict_fn` on the return value of `input_fn`. + Return a two-dimensional NumPy array where the first columns are predictions + and the remaining columns are the feature contributions (SHAP values) for that prediction. + """ + prediction = model.predict(input_data) + feature_contribs = model.predict(input_data, pred_contribs=True, validate_features=False) + output = np.hstack((prediction[:, np.newaxis], feature_contribs)) + return output + + +def output_fn(predictions, content_type): + """ + After invoking predict_fn, the model server invokes `output_fn`. 
+ """ + if content_type == "text/csv" or content_type == "application/json": + return ",".join(str(x) for x in predictions[0]) + else: + raise ValueError("Content type {} is not supported.".format(content_type)) diff --git a/tests/data/experiment/process_job_script_for_run_clz.py b/tests/data/experiment/process_job_script_for_run_clz.py new file mode 100644 index 0000000000..32fd0ab4f6 --- /dev/null +++ b/tests/data/experiment/process_job_script_for_run_clz.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This script file runs on SageMaker processing job""" +from __future__ import absolute_import + +import logging +import os +import boto3 + +sdk_file = "sagemaker-dev-1.0.tar.gz" +os.system(f"pip install {sdk_file}") + + +from sagemaker import Session +from sagemaker.experiments import load_run + + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + + +with load_run(sagemaker_session=sagemaker_session) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameters({"p3": 3.0, "p4": 4.0}) + run.log_metric("test-job-load-log-metric", 0.1) diff --git a/tests/data/experiment/train_job_script_for_run_clz.py b/tests/data/experiment/train_job_script_for_run_clz.py new file mode 100644 index 0000000000..34c86e0993 --- /dev/null +++ b/tests/data/experiment/train_job_script_for_run_clz.py @@ -0,0 +1,71 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+"""This script file runs on SageMaker training job""" +from __future__ import absolute_import + +import logging +import time +import os +import boto3 + +sdk_file = "sagemaker-dev-1.0.tar.gz" +os.system(f"pip install {sdk_file}") + +from sagemaker import Session +from sagemaker.experiments import load_run, Run + +boto_session = boto3.Session(region_name=os.environ["AWS_REGION"]) +sagemaker_session = Session(boto_session=boto_session) + +if os.environ["RUN_OPERATION"] == "init": + logging.info("Initializing a Run") + with Run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + logging.info(f"Run name: {run.run_name}") + logging.info(f"Experiment name: {run.experiment_name}") + logging.info(f"Trial component name: {run._trial_component.trial_component_name}") + run.log_parameter("p1", 1.0) + run.log_parameter("p2", 2) + + for i in range(2): + run.log_metric("A", i) + for i in range(2): + run.log_metric("B", i) + for i in range(2): + run.log_metric("C", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("D", i) + for i in range(2): + time.sleep(0.003) + run.log_metric("E", i) + time.sleep(15) + +else: + logging.info("Loading a Run") + logging.info("Invoking load_run with name arguments") + with load_run( + experiment_name=os.environ["EXPERIMENT_NAME"], + run_name=os.environ["RUN_NAME"], + sagemaker_session=sagemaker_session, + ) as run: + run.log_parameters({"p3": 3.0, "p4": 4}) + run.log_metric("test-job-load-log-metric", 0.1) + + if os.environ.get("CALL_RUN_LOAD_WITH_NO_NAME_ARGS", None) == "True": + logging.info("Invoking load_run without name arguments") + with load_run(sagemaker_session=sagemaker_session) as run: + run.log_parameters({"p5": 5.0, "p6": 6}) diff --git a/tests/data/experiment/transform_job_materials/data.csv b/tests/data/experiment/transform_job_materials/data.csv new file mode 100644 index 0000000000..9f1b6c0bb0 --- /dev/null +++ b/tests/data/experiment/transform_job_materials/data.csv @@ -0,0 +1 @@ +-99 1:3 2:0.37 3:0.29 4:0.095 5:0.249 6:0.1045 7:0.058 8:0.067 \ No newline at end of file diff --git a/tests/data/experiment/transform_job_materials/xgb_model.tar.gz b/tests/data/experiment/transform_job_materials/xgb_model.tar.gz new file mode 100644 index 0000000000000000000000000000000000000000..3969bede9e315f8f51d27f3df2de623e670459c6 GIT binary patch literal 35946 zcmV(%K;pk2iwFP!00000|Lncnj+{r6B-pdF*o%IOhOTA+W+6Pz(U-H}<=!vQXb_ZC zlGsC$`bny%dm0VwKGeM1Uap(Fd1O{G>s&nigAq$)Rvx}nei32rZf^E3zyA3C{l`y- z-{1dy`Sx$V%zsHz>b3q|^8c>>TSgtZ{+-lZavP&c|GOl))bTfem%h;PT>0UA{(t}SAO8I>|J#51 zzyA+?+i$Pm{rvXwFaPnUAOC#w_S2hpAOH5pfBkg%`oo9U|N6Io`QQJ`|M>s@!{7Yd z5C7-;cfY*(^@qRzw;$f>OYbf};Nh>A`ryq^ul{)b;q~u$;_}_=AKzZSy8M?v^!eW} z-+g-h_SHXqeE;s%NB#Zv+c*FH^`|%Q-~I6SKgbU+e)#3}o42p@wKwnnSzr43)vtg2 z`RdKj`eM=-_b&(GZ-2c0_43unH~;;T?|6p~eB8hM)B7LsOT{mre*19w`e*+3K~L}f z@2~V$_!J-gczOBr$-`HAw6EUim5HOn3wBKU5kI+0ef6Uq_rL4We0cry8$G#KKl%l= zOY!5U50{s(K7Dxo=H2mny!-Qa`{nK1%eSxa2A=vKA1;4>^V6sI`q3X=|M}^?J@x*l z%a0#lefrlw@UEXPe|i1q+fVv}ZC4NAxA!>a%YS+O6Q2F^4;&mm|LceM|LxUJ@8ACZ z`oCU&c=5-FKi^#*U)Kx&mmA)rC-P2D46oX;<6``A`O}+U-slzCoB!qY+mDwgyZg7R z+rks^`1;Ae|KroEpD*8iyu`cX+Fg6k>$iXW_L^7xr`JFIcKJ#_#fzk$`uO_yKlFb3 z@n7Fw{`2MAGta-%sqM_uCzqNR^3~68K3x99@7Hc&eT}#4FPE=B{rTbYm7dw3f4q73 zE05>$lO5Sb=CCig5nuoQ`)l4_o5}X3=l<@a zo%AQ%KlwkIm;~N}tZk@mD_I}Y_e%hbl^E$CtltGDkz##sCP{k!*{-oLxNcAo#~KV#%w{_^4S<8RpOFFfB5m%qL`{m$|G zl?!5DAuhJK0OL{?6M_ zpIJYgZ0*{J*~{v2W%~KeN4t;XR>zs(`YHP0=6|+RJg%qjfBYv5hJVJ+^eg!LkAL~| 
+        assert list(filter(lambda x: x.metric_name == "test-x-step", metrics))
+        assert list(filter(lambda x: x.metric_name == "test-x-timestamp", metrics))
+
+    # metrics -> eureka propagation
+    retry_with_backoff(verify_metrics)
diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py
new file mode 100644
index 0000000000..713a6a3792
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_run.py
@@ -0,0 +1,662 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import + +import datetime +import os + +import pytest + +from tests.integ.sagemaker.experiments.conftest import TAGS +from sagemaker.experiments._api_types import _TrialComponentStatusType +from sagemaker.experiments._utils import is_run_trial_component +from sagemaker.processing import FrameworkProcessor +from sagemaker.pytorch import PyTorch +from sagemaker.s3 import S3Uploader +from sagemaker.xgboost import XGBoostModel +from tests.integ import DATA_DIR +from sagemaker.experiments._metrics import BATCH_SIZE +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.sklearn import SKLearn +from sagemaker.utils import retry_with_backoff, unique_name_from_base +from tests.integ.sagemaker.experiments.helpers import name, cleanup_exp_resources +from sagemaker.experiments.run import ( + RUN_NAME_BASE, + DELIMITER, +) +from sagemaker.experiments import Run, load_run, list_runs +from sagemaker.experiments._helper import _DEFAULT_ARTIFACT_PREFIX + + +# when running integration tests locally modify this to your test account's execution role +EXECUTION_ROLE = "SageMakerRole" + + +@pytest.fixture +def artifact_file_path(tempdir): + file_contents = "test artifact file" + file_path = os.path.join(tempdir, "artifact_file.txt") + with open(file_path, "w") as foo_file: + foo_file.write(file_contents) + return file_path + + +artifact_name = unique_name_from_base("Test-Artifact") +file_artifact_name = f"File-Artifact-{name()}" +metric_name = "Test-Local-Init-Log-Metric" + + +def test_local_run_with_load(sagemaker_session, artifact_file_path): + exp_name = f"My-Local-Exp-{name()}" + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + # Run name is not provided, will create a new TC + with Run(experiment_name=exp_name, sagemaker_session=sagemaker_session) as run1: + run1_name = run1.run_name + assert RUN_NAME_BASE in run1_name + _local_run_log_behaviors( + artifact_file_path=artifact_file_path, + sagemaker_session=sagemaker_session, + ) + + def verify_load_run(): + with load_run( + experiment_name=exp_name, + run_name=run1_name, + sagemaker_session=sagemaker_session, + ) as run2: + assert run2.run_name == run1_name + assert ( + run2._trial_component.trial_component_name + == f"{run2.experiment_name}{DELIMITER}{run1_name}" + ) + _check_run_from_local_end_result( + sagemaker_session=sagemaker_session, tc=run2._trial_component + ) + + # Add retry to make sure metrics -> eureka propagation is consistent + retry_with_backoff(verify_load_run, 4) + + +def test_two_local_run_init_with_same_run_name_and_different_exp_names(sagemaker_session): + exp_name1 = f"my-two-local-exp1-{name()}" + exp_name2 = f"my-two-local-exp2-{name()}" + run_name = "test-run" + with cleanup_exp_resources( + exp_names=[exp_name1, exp_name2], sagemaker_session=sagemaker_session + ): + # Run name is not provided, will create a new TC + with Run( + experiment_name=exp_name1, run_name=run_name, sagemaker_session=sagemaker_session + ) as run1: + pass + with Run( + experiment_name=exp_name2, run_name=run_name, sagemaker_session=sagemaker_session + ) as run2: + pass + + assert run1.experiment_name != run2.experiment_name + assert run1.run_name == run2.run_name + assert ( + run1._trial_component.trial_component_name != run2._trial_component.trial_component_name + ) + assert run1._trial_component.trial_component_name == f"{exp_name1}{DELIMITER}{run_name}" + assert run2._trial_component.trial_component_name == f"{exp_name2}{DELIMITER}{run_name}" + + 
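+# Illustrative helper, not used by the tests in this module: a minimal sketch of
+# the naming convention the asserts above rely on. A Run is backed by a trial
+# component whose name joins the experiment and run names with DELIMITER
+# (imported at the top of this file); the helper name itself is hypothetical.
+def _expected_trial_component_name(experiment_name, run_name):
+    """Return the trial component name a Run is expected to be backed by."""
+    return f"{experiment_name}{DELIMITER}{run_name}"
+
+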
+@pytest.mark.parametrize( + "input_names", + [ + (f"my-local-exp-{name()}", "test-run", None), # both have delimiter - + ("my-test-1", "my-test-1", None), # exp_name equals run_name + ("my-test-3", "my-test-3-run", None), # is subset of run_name + ("x" * 59, "test-run", None), # long exp_name + ("test-exp", "y" * 59, None), # long run_name + ("e" * 59, "y" * 59, None), # long exp_name and run_name + ("my-test4", "test-run", "run-display-name-test"), # with supplied display name + ], +) +def test_run_name_vs_trial_component_name_edge_cases(sagemaker_session, input_names): + exp_name, run_name, run_display_name = input_names + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + sagemaker_session=sagemaker_session, + run_name=run_name, + run_display_name=run_display_name, + ) as run1: + assert not run1._experiment.tags + assert not run1._trial.tags + is_run_tc = is_run_trial_component( + trial_component_name=run1._trial_component.trial_component_name, + sagemaker_session=sagemaker_session, + ) + assert is_run_tc + + with load_run( + experiment_name=exp_name, + run_name=run_name, + sagemaker_session=sagemaker_session, + ) as run2: + assert run2.experiment_name == exp_name + assert run2.run_name == run_name + assert run2._trial_component.trial_component_name == f"{exp_name}{DELIMITER}{run_name}" + assert run2._trial_component.display_name in ( + run_display_name, + run2._trial_component.trial_component_name, + ) + + +_EXP_NAME_BASE_IN_SCRIPT = "job-exp-in-script" +_RUN_NAME_IN_SCRIPT = "job-run-in-script" + +_EXP_DIR = os.path.join(DATA_DIR, "experiment") +_ENTRY_POINT_PATH = os.path.join(_EXP_DIR, "train_job_script_for_run_clz.py") +_PYTHON_PROCESS_SCRIPT = "process_job_script_for_run_clz.py" +_TRANSFORM_MATERIALS = os.path.join(_EXP_DIR, "transform_job_materials") + +_RUN_INIT = "init" +_RUN_LOAD = "load" + + +def test_run_from_local_and_train_job_and_all_exp_cfg_match(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. The 1st Run TC created locally and its exp config was auto passed to the job + # 2. In training job, the same exp and run names are given in the Run constructor + # which will load the 1st Run TC in training job and log parameters + # and metrics there + # 3. In a different training job, load the same Run TC and log more parameters there. 
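+    # In miniature, the round trip under test (a sketch only; the in-job half
+    # lives in train_job_script_for_run_clz.py added earlier in this change):
+    #
+    #   with Run(experiment_name=exp_name, run_name=run_name, ...):
+    #       estimator.fit(...)  # _RunContext auto-passes the experiment config
+    #
+    #   # in the job: Run(experiment_name=exp_name, run_name=run_name, ...)
+    #   # resolves to the same trial component, so job parameters and metrics
+    #   # land on the locally created Run TC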
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, sagemaker_session=sagemaker_session, exp_name=exp_name + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + run_name=_RUN_NAME_IN_SCRIPT, + sagemaker_session=sagemaker_session, + ) as run: + init_start_time = _check_tc_status_when_entering(run._trial_component) + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + # experiment_config is auto passed in by _RunContext + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + sagemaker_session=sagemaker_session, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + assert run.experiment_name == exp_name + assert run.run_name == _RUN_NAME_IN_SCRIPT + _check_run_from_local_end_result( + tc=run._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + with run: + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.environment["CALL_RUN_LOAD_WITH_NO_NAME_ARGS"] = "True" + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + is_init=False, + has_extra_load=True, + ) + + +def test_run_from_local_and_train_job_and_exp_cfg_not_match(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. The 1st Run TC created locally and its exp config was auto passed to the job + # 2. In training job, different exp and run names (i.e. 2nd Run TC) are given + # in the Run constructor which will create a Run TC according to the run_name + # passed in there and ignore the exp config in the job + # 3. Both metrics and parameters are logged in the Run TC created in job + # 4. In a different training job, load the 2nd Run TC and log more parameters there. 
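+    # Sketch of the mismatch: explicit names passed to Run() inside the job
+    # script take precedence over the experiment config delivered with the job.
+    #
+    #   local:  Run(exp_name2, f"{_RUN_NAME_IN_SCRIPT}2")  -> trial component A
+    #   in job: Run(exp_name, _RUN_NAME_IN_SCRIPT)         -> new trial component B
+    #
+    # so job parameters and metrics land on B, which the asserts below load by name.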
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + exp_name2 = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, sagemaker_session=sagemaker_session, exp_name=exp_name + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources( + exp_names=[exp_name, exp_name2], sagemaker_session=sagemaker_session + ): + with Run( + experiment_name=exp_name2, + run_name=f"{_RUN_NAME_IN_SCRIPT}2", + sagemaker_session=sagemaker_session, + ) as run: + init_start_time = _check_tc_status_when_entering(run._trial_component) + # experiment_config is auto passed in by _RunContext + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_tc_status_intermediate( + trial_component=run._trial_component, + sagemaker_session=sagemaker_session, + init_start_time=init_start_time, + ) + + old_end_time = _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + sagemaker_session=sagemaker_session, + ) + assert run.experiment_name != exp_name + assert run.run_name != _RUN_NAME_IN_SCRIPT + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + with run: + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_tc_status_intermediate( + trial_component=run._trial_component, + sagemaker_session=sagemaker_session, + init_start_time=init_start_time, + old_end_time=old_end_time, + ) + + _check_tc_status_when_exiting( + trial_component_name=run._trial_component.trial_component_name, + init_start_time=init_start_time, + old_end_time=old_end_time, + sagemaker_session=sagemaker_session, + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +def test_run_from_train_job_only(sagemaker_session, dev_sdk_tar): + # Notes: + # 1. No Run TC created locally or specified in experiment config + # 2. In training job, Run is initialized + # which will create a Run TC according to the run_name passed in there + # 3. Both metrics and parameters are logged in the Run TC created in job + # 4. In a different training job, load the same Run TC and log more parameters there. 
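+    # Sketch: no local Run exists here, so the first fit() relies on the job
+    # script's Run(...) to create the trial component itself; the second fit()
+    # switches RUN_OPERATION to "load" so the script's
+    # load_run(experiment_name=..., run_name=...) appends to the same component.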
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + estimator = _generate_estimator( + sdk_tar=dev_sdk_tar, + sagemaker_session=sagemaker_session, + exp_name=exp_name, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_run_from_job_result( + tc_name=tc_name, + sagemaker_session=sagemaker_session, + ) + + estimator.environment["RUN_OPERATION"] = _RUN_LOAD + estimator.fit( + job_name=f"train-job-{name()}", + wait=True, # wait the training job to finish + logs="None", # set to "All" to display logs fetched from the training job + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +# dev_sdk_tar is required to trigger generating the dev SDK tar +def test_run_from_processing_job_and_override_default_exp_config( + sagemaker_session, dev_sdk_tar, run_obj +): + # Notes: + # 1. The 1st Run TC (run) created locally + # 2. Within the 2nd Run TC (run_obj)'s context, invoke processor.run + # but override the default experiment config in context of 2nd Run TC + # with the experiment config of the 1st Run TC + # 3. In the processing job script, load the 1st Run TC via the experiment config + # fetched from the job env + # 4. All data are logged in the Run TC either locally or in the processing job + exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT) + processor = FrameworkProcessor( + estimator_cls=PyTorch, + framework_version="1.10", + py_version="py38", + instance_count=1, + instance_type="ml.m5.xlarge", + role=EXECUTION_ROLE, + sagemaker_session=sagemaker_session, + ) + + with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session): + with Run( + experiment_name=exp_name, + run_name=_RUN_NAME_IN_SCRIPT, + sagemaker_session=sagemaker_session, + ) as run: + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + + with run_obj: + # Override the default experiment_config in _RunContext of run_obj + # with the experiment_config of run + processor.run( + code=_PYTHON_PROCESS_SCRIPT, + source_dir=_EXP_DIR, + job_name=f"process-job-{name()}", + wait=True, # wait the job to finish + logs=False, + experiment_config=run.experiment_config, + ) + + assert run_obj.experiment_name != run.experiment_name + assert run_obj.run_name != run.run_name + _check_run_from_local_end_result( + tc=run._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=run.experiment_name, run_name=run.run_name + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + with run_obj: + # Not to override the exp config and use the default one in the context + processor.run( + code=_PYTHON_PROCESS_SCRIPT, + source_dir=_EXP_DIR, + job_name=f"process-job-{name()}", + wait=True, # wait the job to finish + logs=False, + ) + + tc_name = Run._generate_trial_component_name( + experiment_name=run_obj.experiment_name, run_name=run_obj.run_name + ) + _check_run_from_job_result( + tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False + ) + + +# dev_sdk_tar is required to trigger generating the dev SDK tar +def 
test_run_from_transform_job(sagemaker_session, dev_sdk_tar, run_obj, xgboost_latest_version): + # Notes: + # 1. The 1st Run TC (run) created locally + # 2. In the inference script running in a transform job, load the 1st Run TC + # via explicitly passing the experiment_name and run_name of the 1st Run TC + # TODO: once we're able to retrieve exp config from the transform job env, + # we should expand this test and add the load_run() without explicitly supplying the names + # 3. All data are logged in the Run TC either locally or in the transform job + xgb_model_data_s3 = sagemaker_session.upload_data( + path=os.path.join(_TRANSFORM_MATERIALS, "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + xgboost_model = XGBoostModel( + sagemaker_session=sagemaker_session, + model_data=xgb_model_data_s3, + role=EXECUTION_ROLE, + entry_point="inference.py", + source_dir=_EXP_DIR, + framework_version=xgboost_latest_version, + env={ + "EXPERIMENT_NAME": run_obj.experiment_name, + "RUN_NAME": run_obj.run_name, + }, + ) + transformer = xgboost_model.transformer( + instance_count=1, + instance_type="ml.m5.4xlarge", + max_concurrent_transforms=5, + max_payload=1, + strategy="MultiRecord", + ) + uri = "s3://{}/{}/input/data/{}".format( + sagemaker_session.default_bucket(), + "transform-test", + unique_name_from_base("json-data"), + ) + input_data = S3Uploader.upload( + os.path.join(_TRANSFORM_MATERIALS, "data.csv"), uri, sagemaker_session=sagemaker_session + ) + + with run_obj: + _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session) + transformer.transform( + data=input_data, + content_type="text/libsvm", + split_type="Line", + wait=True, + job_name=f"transform-job-{name()}", + ) + + _check_run_from_local_end_result( + tc=run_obj._trial_component, + sagemaker_session=sagemaker_session, + is_complete_log=False, + ) + tc_name = Run._generate_trial_component_name( + experiment_name=run_obj.experiment_name, run_name=run_obj.run_name + ) + _check_run_from_job_result(tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False) + + +def test_list(run_obj, sagemaker_session): + tc1 = _TrialComponent.create( + trial_component_name=f"non-run-tc1-{name()}", + sagemaker_session=sagemaker_session, + ) + tc2 = _TrialComponent.create( + trial_component_name=f"non-run-tc2-{name()}", + sagemaker_session=sagemaker_session, + tags=TAGS, + ) + run_obj._trial.add_trial_component(tc1) + run_obj._trial.add_trial_component(tc2) + + run_tcs = list_runs( + experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session + ) + assert len(run_tcs) == 1 + assert run_tcs[0].run_name == run_obj.run_name + assert run_tcs[0].experiment_name == run_obj.experiment_name + assert run_tcs[0].experiment_config == run_obj.experiment_config + + +def _generate_estimator(exp_name, sdk_tar, sagemaker_session): + return SKLearn( + framework_version="0.23-1", + entry_point=_ENTRY_POINT_PATH, + dependencies=[sdk_tar], + role=EXECUTION_ROLE, + instance_type="ml.m5.large", + instance_count=1, + volume_size=10, + max_run=900, + enable_sagemaker_metrics=True, + environment={ + "EXPERIMENT_NAME": exp_name, + "RUN_NAME": _RUN_NAME_IN_SCRIPT, + "RUN_OPERATION": _RUN_INIT, + }, + sagemaker_session=sagemaker_session, + ) + + +def _local_run_log_behaviors( + sagemaker_session, + artifact_file_path=None, + is_complete_log=True, +): + with load_run(sagemaker_session=sagemaker_session) as run: + run.log_parameter("pa", 1.0) + run.log_parameter("pb", "p2-value") + run.log_parameters({"pc": 2.0, 
"pd": "p4-value"}) + + if is_complete_log: + run.log_file(file_path=artifact_file_path, name=file_artifact_name) + run.log_artifact(name=artifact_name, value="s3://Output") + run.log_artifact(name=artifact_name, value="s3://Input", is_output=False) + + for i in range(BATCH_SIZE): + run.log_metric(name=metric_name, value=i, step=i) + + +def _check_run_from_local_end_result(sagemaker_session, tc, is_complete_log=True): + assert tc.parameters == {"pa": 1.0, "pb": "p2-value", "pc": 2.0, "pd": "p4-value"} + + if not is_complete_log: + return + + s3_prefix = f"s3://{sagemaker_session.default_bucket()}/{_DEFAULT_ARTIFACT_PREFIX}" + assert s3_prefix in tc.output_artifacts[file_artifact_name].value + assert "text/plain" == tc.output_artifacts[file_artifact_name].media_type + assert "s3://Output" == tc.output_artifacts[artifact_name].value + assert not tc.output_artifacts[artifact_name].media_type + assert "s3://Input" == tc.input_artifacts[artifact_name].value + assert not tc.input_artifacts[artifact_name].media_type + + # TODO: revert to len(tc.metrics) == 1 once backend fix reaches prod + assert len(tc.metrics) > 0 + metric_summary = tc.metrics[0] + assert metric_summary.metric_name == metric_name + assert metric_summary.max == 9.0 + assert metric_summary.min == 0.0 + + +def _check_run_from_job_result(sagemaker_session, tc_name=None, is_init=True, has_extra_load=False): + def validate_tc_updated_in_init(): + assert tc.start_time + assert tc.end_time + assert tc.status.primary_status == _TrialComponentStatusType.Completed.value + assert tc.parameters["p1"] == 1.0 + assert tc.parameters["p2"] == 2.0 + # TODO: revert to assert len(tc.metrics) == 5 once + # backend fix hits prod + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + # metrics deletion is not supported at this point + # so its count would accumulate + assert metric_summary.count > 0 + assert metric_summary.min == 0.0 + assert metric_summary.max == 1.0 + + def validate_tc_updated_in_load(): + assert tc.parameters["p3"] == 3.0 + assert tc.parameters["p4"] == 4.0 + assert len(tc.metrics) > 0 + for metric_summary in tc.metrics: + if metric_summary.metric_name != "test-job-load-log-metric": + continue + assert metric_summary.last == 0.1 + assert metric_summary.max == 0.1 + assert metric_summary.min == 0.1 + if has_extra_load: + assert tc.parameters["p5"] == 5.0 + assert tc.parameters["p6"] == 6.0 + + tc = _TrialComponent.load(trial_component_name=tc_name, sagemaker_session=sagemaker_session) + if is_init: + # Add retry since the load behavior is inconsistent sometimes + retry_with_backoff(validate_tc_updated_in_init, 4) + else: + retry_with_backoff(validate_tc_updated_in_load, 4) + + +def _check_tc_status_when_entering(trial_component): + assert isinstance(trial_component.start_time, datetime.datetime) + assert not trial_component.end_time + assert trial_component.status.primary_status == _TrialComponentStatusType.InProgress.value + return trial_component.start_time + + +def _check_tc_status_when_exiting( + trial_component_name, sagemaker_session, init_start_time, old_end_time=None +): + tc = _TrialComponent.load( + trial_component_name=trial_component_name, sagemaker_session=sagemaker_session + ) + # There will be deviation (< 1s) caused by different TS precisions used in Backend and SDK + assert abs(tc.start_time.timestamp() - init_start_time.timestamp()) < 1 + assert tc.status.primary_status == _TrialComponentStatusType.Completed.value + assert isinstance(tc.end_time, datetime.datetime) + if old_end_time: + assert 
tc.end_time > old_end_time
+    return tc.end_time
+
+
+def _check_tc_status_intermediate(
+    trial_component, sagemaker_session, init_start_time, old_end_time=None
+):
+    tc_load = _TrialComponent.load(
+        trial_component_name=trial_component.trial_component_name,
+        sagemaker_session=sagemaker_session,
+    )
+    assert abs(tc_load.start_time.timestamp() - init_start_time.timestamp()) < 1
+    assert tc_load.status.primary_status == _TrialComponentStatusType.InProgress.value
+    if not old_end_time:
+        assert not trial_component.end_time
+        return
+    assert isinstance(tc_load.end_time, datetime.datetime)
+    assert tc_load.end_time == old_end_time
diff --git a/tests/integ/sagemaker/experiments/test_trial.py b/tests/integ/sagemaker/experiments/test_trial.py
new file mode 100644
index 0000000000..08f646c086
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_trial.py
@@ -0,0 +1,75 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import logging
+
+from sagemaker.experiments import trial
+from sagemaker.utils import retry_with_backoff
+
+
+def test_create_delete(trial_obj):
+    # The fixture creates and deletes the trial; just ensure it was used at least once.
+    assert trial_obj.trial_name
+
+
+def test_create_tags(trial_obj, sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    while True:
+        actual_tags = client.list_tags(ResourceArn=trial_obj.trial_arn)["Tags"]
+        if actual_tags:
+            break
+    # Drop auto-generated AWS tags; build a new list rather than removing items
+    # from the list while iterating over it, which skips elements.
+    actual_tags = [tag for tag in actual_tags if "aws:tag" not in tag.get("Key")]
+    assert actual_tags == trial_obj.tags
+
+
+def test_save_load(trial_obj, sagemaker_session):
+    trial_obj.display_name = "foo"
+    trial_obj.save()
+    assert (
+        "foo"
+        == trial._Trial.load(
+            trial_name=trial_obj.trial_name,
+            sagemaker_session=sagemaker_session,
+        ).display_name
+    )
+
+
+def test_add_remove_trial_component(trial_obj, trial_component_obj):
+    trial_obj.add_trial_component(trial_component_obj)
+    logging.info(
+        f"Added trial component {trial_component_obj.trial_component_name} to trial {trial_obj.trial_name}"
+    )
+
+    def validate_add():
+        trial_components = list(trial_obj.list_trial_components())
+        assert 1 == len(
+            trial_components
+        ), "Expected trial component to be included in trials list of TC"
+
+    retry_with_backoff(validate_add)
+
+    trial_obj.remove_trial_component(trial_component_obj)
+    logging.info(
+        f"Removed trial component {trial_component_obj.trial_component_name} from trial {trial_obj.trial_name}"
+    )
+
+    def validate_remove():
+        trial_components = list(trial_obj.list_trial_components())
+        assert 0 == len(
+            trial_components
+        ), "Expected trial component to be removed from trials list of TC"
+
+    retry_with_backoff(validate_remove)
diff --git a/tests/integ/sagemaker/experiments/test_trial_component.py b/tests/integ/sagemaker/experiments/test_trial_component.py
new file mode 100644
index 0000000000..3d79e41cc4
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_trial_component.py
@@ -0,0 +1,144 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import datetime
+import uuid
+
+from sagemaker.experiments._api_types import _TrialComponentStatusType
+from tests.integ.sagemaker.experiments.helpers import EXP_INTEG_TEST_NAME_PREFIX
+from sagemaker.experiments import _api_types, trial_component
+from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression
+
+
+def test_create_delete(trial_component_obj):
+    # The fixture creates and deletes the trial component; just ensure it was used at least once.
+    assert trial_component_obj.trial_component_name
+    assert trial_component_obj.input_artifacts == {}
+    assert trial_component_obj.parameters == {}
+    assert trial_component_obj.output_artifacts == {}
+
+
+def test_create_tags(trial_component_obj, sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    while True:
+        actual_tags = client.list_tags(ResourceArn=trial_component_obj.trial_component_arn)["Tags"]
+        if actual_tags:
+            break
+    # Drop auto-generated AWS tags; build a new list rather than removing items
+    # from the list while iterating over it, which skips elements.
+    actual_tags = [tag for tag in actual_tags if "aws:tag" not in tag.get("Key")]
+    assert actual_tags == trial_component_obj.tags
+
+
+def test_delete_with_force_disassociate(
+    trial_component_with_force_disassociation_obj, sagemaker_session
+):
+    assert trial_component_with_force_disassociation_obj.trial_component_name
+    trials = sagemaker_session.sagemaker_client.list_trials(
+        TrialComponentName=trial_component_with_force_disassociation_obj.trial_component_name
+    )["TrialSummaries"]
+    assert len(trials) == 3
+
+
+def test_save(trial_component_obj, sagemaker_session):
+    trial_component_obj.display_name = str(uuid.uuid4())
+    trial_component_obj.status = _api_types.TrialComponentStatus(
+        primary_status=_TrialComponentStatusType.InProgress.value, message="Message"
+    )
+    trial_component_obj.start_time = datetime.datetime.now(
+        datetime.timezone.utc
+    ) - datetime.timedelta(days=1)
+    trial_component_obj.end_time = datetime.datetime.now(datetime.timezone.utc)
+    trial_component_obj.parameters = {"foo": "bar", "whizz": 100.1}
+    trial_component_obj.input_artifacts = {
+        "snizz": _api_types.TrialComponentArtifact(value="s3:/foo/bar", media_type="text/plain"),
+        "snizz1": _api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2"),
+    }
+    trial_component_obj.output_artifacts = {
+        "fly": _api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow"),
+        "fly2": _api_types.TrialComponentArtifact(
+            value="s3:/sky/far2", media_type="away/tomorrow2"
+        ),
+    }
+    trial_component_obj.parameters_to_remove = ["foo"]
+    trial_component_obj.input_artifacts_to_remove = ["snizz"]
+    trial_component_obj.output_artifacts_to_remove = ["fly2"]
+
+    trial_component_obj.save()
+
+    loaded = trial_component._TrialComponent.load(
+        trial_component_name=trial_component_obj.trial_component_name,
+        sagemaker_session=sagemaker_session,
+    )
+
+    assert trial_component_obj.trial_component_name == loaded.trial_component_name
+    assert trial_component_obj.status == loaded.status
+
+    assert trial_component_obj.start_time - loaded.start_time <
datetime.timedelta(seconds=1) + assert trial_component_obj.end_time - loaded.end_time < datetime.timedelta(seconds=1) + + assert loaded.parameters == {"whizz": 100.1} + assert loaded.input_artifacts == { + "snizz1": _api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2") + } + assert loaded.output_artifacts == { + "fly": _api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow") + } + + +def test_load(trial_component_obj, sagemaker_session): + loaded = trial_component._TrialComponent.load( + trial_component_name=trial_component_obj.trial_component_name, + sagemaker_session=sagemaker_session, + ) + assert trial_component_obj.trial_component_arn == loaded.trial_component_arn + + +def test_list_sort(trial_components, sagemaker_session): + slack = datetime.timedelta(minutes=1) + now = datetime.datetime.now(datetime.timezone.utc) + trial_component_names = [tc.trial_component_name for tc in trial_components] + + for sort_order in ["Ascending", "Descending"]: + trial_component_names_listed = [ + s.trial_component_name + for s in trial_component._TrialComponent.list( + created_after=now - slack, + created_before=now + slack, + sort_by="CreationTime", + sort_order=sort_order, + sagemaker_session=sagemaker_session, + ) + if s.trial_component_name in trial_component_names + ] + + if sort_order == "Descending": + trial_component_names_listed = trial_component_names_listed[::-1] + assert trial_component_names == trial_component_names_listed + assert trial_component_names # sanity test + + +def test_search(sagemaker_session): + trial_component_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + search_expression = SearchExpression(filters=[search_filter]) + for s in trial_component._TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + trial_component_names_searched.append(s.trial_component_name) + + assert len(trial_component_names_searched) > 0 + assert trial_component_names_searched # sanity test diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 3c416ffd36..abfe6f6d0d 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -26,6 +26,7 @@ artifact, ) from sagemaker.model import ModelPackage +from sagemaker.utils import retry_with_backoff from tests.integ.sagemaker.workflow.test_workflow import ( test_end_to_end_pipeline_successful_execution, ) @@ -43,7 +44,7 @@ ) from sagemaker.lineage.lineage_trial_component import LineageTrialComponent -from tests.integ.sagemaker.lineage.helpers import name, names, retry +from tests.integ.sagemaker.lineage.helpers import name, names SLEEP_TIME_SECONDS = 1 SLEEP_TIME_TWO_SECONDS = 2 @@ -400,7 +401,7 @@ def model_obj(sagemaker_session): yield model time.sleep(SLEEP_TIME_SECONDS) - retry(lambda: model.delete(disassociate=True), num_attempts=4) + retry_with_backoff(lambda: model.delete(disassociate=True), num_attempts=4) @pytest.fixture diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index fb71d1d88c..5548c63cff 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -15,7 +15,6 @@ import uuid from datetime import datetime -import time def name(): @@ -33,19 +32,6 @@ def names(): ] -def retry(callable, num_attempts=8): - assert num_attempts >= 1 - for i in 
range(num_attempts): - try: - return callable() - except Exception as ex: - if i == num_attempts - 1: - raise ex - print("Retrying", ex) - time.sleep(2**i) - assert False, "logic error in retry" - - def traverse_graph_back(start_arn, sagemaker_session): def visit(arn, visited: set): visited.add(arn) diff --git a/tests/integ/sagemaker/lineage/test_artifact.py b/tests/integ/sagemaker/lineage/test_artifact.py index c629fcdc30..1980b51da2 100644 --- a/tests/integ/sagemaker/lineage/test_artifact.py +++ b/tests/integ/sagemaker/lineage/test_artifact.py @@ -20,7 +20,7 @@ import pytest from sagemaker.lineage import artifact -from tests.integ.sagemaker.lineage.helpers import retry +from sagemaker.utils import retry_with_backoff def test_create_delete(artifact_obj): @@ -125,7 +125,7 @@ def validate(): assert len(trials) == 1 assert trial_obj.trial_name in trials - retry(validate, num_attempts=3) + retry_with_backoff(validate, num_attempts=3) def test_downstream_trials_v2(trial_associated_artifact, trial_obj, sagemaker_session): diff --git a/tests/integ/sagemaker/utilities/__init__.py b/tests/integ/sagemaker/utilities/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integ/sagemaker/utilities/test_search_expression.py b/tests/integ/sagemaker/utilities/test_search_expression.py new file mode 100644 index 0000000000..ea7f4476bf --- /dev/null +++ b/tests/integ/sagemaker/utilities/test_search_expression.py @@ -0,0 +1,67 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
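+# Integration tests for sagemaker.utilities.search_expression: plain filters,
+# nested filters (currently skipped), and sub-expressions, exercised against
+# the trial-component search API.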
+from __future__ import absolute_import + +import pytest + +from tests.integ.sagemaker.experiments.helpers import EXP_INTEG_TEST_NAME_PREFIX +from sagemaker.experiments.trial_component import _TrialComponent +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression, NestedFilter + + +def test_search(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + search_expression = SearchExpression(filters=[search_filter]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched + + +@pytest.mark.skip(reason="failed validation, need to wait for NestedFilter bug to be fixed") +def test_nested_search(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + nested_filter = NestedFilter(property_name="TrialComponentName", filters=[search_filter]) + search_expression = SearchExpression(nested_filters=[nested_filter]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched + + +def test_sub_expression(sagemaker_session): + tc_names_searched = [] + search_filter = Filter( + name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX + ) + sub_expression = SearchExpression(filters=[search_filter]) + search_expression = SearchExpression(sub_expressions=[sub_expression]) + for tc in _TrialComponent.search( + search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session + ): + tc_names_searched.append(tc.trial_component_name) + + assert len(tc_names_searched) > 0 + assert tc_names_searched diff --git a/tests/integ/test_marketplace.py b/tests/integ/test_marketplace.py index b9ff13c50e..28b537c1ea 100644 --- a/tests/integ/test_marketplace.py +++ b/tests/integ/test_marketplace.py @@ -23,6 +23,7 @@ import sagemaker import tests.integ +from tests.integ.utils import create_repository from sagemaker import AlgorithmEstimator, ModelPackage, Model from sagemaker.serializers import CSVSerializer from sagemaker.tuner import IntegerParameter, HyperparameterTuner @@ -33,7 +34,6 @@ from tests.integ.test_multidatamodel import ( _ecr_image_uri, _ecr_login, - _create_repository, _delete_repository, ) from tests.integ.retry import retries @@ -214,7 +214,7 @@ def iris_image(sagemaker_session): rm=True, ) image.tag(ecr_image, tag="latest") - _create_repository(ecr_client, algorithm_name) + create_repository(ecr_client, algorithm_name) # Retry docker image push for _ in retries(3, "Upload docker image to ECR repo", seconds_to_sleep=10): diff --git a/tests/integ/test_multidatamodel.py b/tests/integ/test_multidatamodel.py index 78ba62c3db..d6c14037a7 100644 --- a/tests/integ/test_multidatamodel.py +++ b/tests/integ/test_multidatamodel.py @@ -19,8 +19,8 @@ import docker import numpy import pytest -from botocore.exceptions import ClientError +from tests.integ.utils import create_repository from sagemaker import utils from sagemaker.amazon.randomcutforest import RandomCutForest from sagemaker.deserializers import StringDeserializer @@ -59,7 +59,7 @@ def 
container_image(sagemaker_session): image.tag(ecr_image, tag="latest") # Create AWS ECR and push the local docker image to it - _create_repository(ecr_client, algorithm_name) + create_repository(ecr_client, algorithm_name) # Retry docker image push for _ in retries(3, "Upload docker image to ECR repo", seconds_to_sleep=10): @@ -90,23 +90,6 @@ def _ecr_image_uri(sagemaker_session, algorithm_name): return "{}.dkr.{}/{}:latest".format(account_id, endpoint_data["hostname"], algorithm_name) -def _create_repository(ecr_client, repository_name): - """ - Creates an ECS Repository (ECR). When a new transform is being registered, - we'll need a repository to push the image (and composed model images) to - """ - try: - response = ecr_client.create_repository(repositoryName=repository_name) - return response["repository"]["repositoryUri"] - except ClientError as e: - # Handle when the repository already exists - if "RepositoryAlreadyExistsException" == e.response.get("Error", {}).get("Code"): - response = ecr_client.describe_repositories(repositoryNames=[repository_name]) - return response["repositories"][0]["repositoryUri"] - else: - raise - - def _delete_repository(ecr_client, repository_name): """ Deletes an ECS Repository (ECR). After the integration test completes diff --git a/tests/integ/utils.py b/tests/integ/utils.py index 53440f96f5..d7891321f2 100644 --- a/tests/integ/utils.py +++ b/tests/integ/utils.py @@ -14,6 +14,8 @@ import logging from functools import wraps +from botocore.exceptions import ClientError + from tests.conftest import NO_P3_REGIONS, NO_M4_REGIONS from sagemaker.exceptions import CapacityError @@ -69,3 +71,21 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def create_repository(ecr_client, repository_name): + """Creates an ECS Repository (ECR). + + When a new transform is being registered, + we'll need a repository to push the image (and composed model images) to + """ + try: + response = ecr_client.create_repository(repositoryName=repository_name) + return response["repository"]["repositoryUri"] + except ClientError as e: + # Handle when the repository already exists + if "RepositoryAlreadyExistsException" == e.response.get("Error", {}).get("Code"): + response = ecr_client.describe_repositories(repositoryNames=[repository_name]) + return response["repositories"][0]["repositoryUri"] + else: + raise diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 0000000000..21fe49cc97 --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,66 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import sagemaker + +from mock import Mock, PropertyMock + +_ROLE = "DummyRole" +_REGION = "us-west-2" +_DEFAULT_BUCKET = "my-bucket" + + +@pytest.fixture(scope="session") +def client(): + """Mock client. 
+ + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture(scope="session") +def boto_session(client): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value=_ROLE) + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name=_REGION) + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client + + return session_mock + + +@pytest.fixture(scope="session") +def sagemaker_session(boto_session, client): + # ideally this would mock Session instead of instantiating it + # most unit tests do mock the session correctly + return sagemaker.session.Session( + boto_session=boto_session, + sagemaker_client=client, + sagemaker_runtime_client=client, + default_bucket=_DEFAULT_BUCKET, + sagemaker_metrics_client=client, + ) diff --git a/tests/unit/sagemaker/experiments/__init__.py b/tests/unit/sagemaker/experiments/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/sagemaker/experiments/conftest.py b/tests/unit/sagemaker/experiments/conftest.py new file mode 100644 index 0000000000..4d33ad759d --- /dev/null +++ b/tests/unit/sagemaker/experiments/conftest.py @@ -0,0 +1,86 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import unittest +from unittest.mock import patch, MagicMock, Mock + +import pytest + +from sagemaker import Session +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.run import RUN_NAME_BASE +from sagemaker.experiments import Run +from tests.unit.sagemaker.experiments.helpers import ( + mock_tc_load_or_create_func, + mock_trial_load_or_create_func, + TEST_EXP_NAME, +) + + +@pytest.fixture +def client(): + """Mock client. 
+ + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = unittest.mock.Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture +def sagemaker_session(client): + return Session( + sagemaker_client=client, + ) + + +@pytest.fixture +def run_obj(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.update_trial_component.return_value = {} + client.associate_trial_component.return_value = {} + with patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock( + return_value=_Experiment( + experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session + ) + ), + ): + with patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), + ): + with patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), + ): + run = Run( + experiment_name=TEST_EXP_NAME, + sagemaker_session=sagemaker_session, + ) + run._artifact_uploader = Mock() + run._lineage_artifact_tracker = Mock() + run._metrics_manager = Mock() + + assert run.run_name.startswith(RUN_NAME_BASE) + assert run.run_group_name == Run._generate_trial_name(TEST_EXP_NAME) + + return run diff --git a/tests/unit/sagemaker/experiments/helpers.py b/tests/unit/sagemaker/experiments/helpers.py new file mode 100644 index 0000000000..b7914010e5 --- /dev/null +++ b/tests/unit/sagemaker/experiments/helpers.py @@ -0,0 +1,44 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +TEST_EXP_NAME = "my-experiment" +TEST_RUN_NAME = "my-run" + + +def mock_tc_load_or_create_func( + trial_component_name, display_name=None, tags=None, sagemaker_session=None +): + tc = _TrialComponent( + trial_component_name=trial_component_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) + return tc, True + + +def mock_trial_load_or_create_func( + experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None +): + return _Trial( + trial_name=trial_name, + experiment_name=experiment_name, + display_name=display_name, + tags=tags, + sagemaker_session=sagemaker_session, + ) diff --git a/tests/unit/sagemaker/experiments/test_environment.py b/tests/unit/sagemaker/experiments/test_environment.py new file mode 100644 index 0000000000..8bb23db7b6 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_environment.py @@ -0,0 +1,107 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os +import shutil +import tempfile +import unittest.mock + +import pytest + +from sagemaker.experiments import _environment +from sagemaker.utils import retry_with_backoff + + +@pytest.fixture +def tempdir(): + dir = tempfile.mkdtemp() + yield dir + shutil.rmtree(dir) + + +@pytest.fixture +def training_job_env(): + old_value = os.environ.get("TRAINING_JOB_ARN") + os.environ["TRAINING_JOB_ARN"] = "arn:1234aBcDe" + yield os.environ + del os.environ["TRAINING_JOB_ARN"] + if old_value: + os.environ["TRAINING_JOB_ARN"] = old_value + + +@pytest.fixture +def transform_job_env(): + old_value = os.environ.get("SAGEMAKER_BATCH") + os.environ["SAGEMAKER_BATCH"] = "true" + yield os.environ + del os.environ["SAGEMAKER_BATCH"] + if old_value: + os.environ["SAGEMAKER_BATCH"] = old_value + + +def test_processing_job_environment(tempdir): + config_path = os.path.join(tempdir, "config.json") + with open(config_path, "w") as f: + f.write(json.dumps({"ProcessingJobArn": "arn:1234aBcDe"})) + environment = _environment._RunEnvironment.load(processing_job_config_path=config_path) + + assert _environment._EnvironmentType.SageMakerProcessingJob == environment.environment_type + assert "arn:1234aBcDe" == environment.source_arn + + +def test_training_job_environment(training_job_env): + environment = _environment._RunEnvironment.load() + assert _environment._EnvironmentType.SageMakerTrainingJob == environment.environment_type + assert "arn:1234aBcDe" == environment.source_arn + + +def test_transform_job_environment(transform_job_env): + environment = _environment._RunEnvironment.load() + assert _environment._EnvironmentType.SageMakerTransformJob == environment.environment_type + # TODO: update if we figure out how to get source_arn from the transform job + assert not environment.source_arn + + +def test_no_environment(): + assert _environment._RunEnvironment.load() is None + + +def test_resolve_trial_component(training_job_env, sagemaker_session): + trial_component_name = "foo-bar" + client = sagemaker_session.sagemaker_client + client.list_trial_components.return_value = { + "TrialComponentSummaries": [{"TrialComponentName": trial_component_name}] + } + client.describe_trial_component.return_value = {"TrialComponentName": trial_component_name} + environment = _environment._RunEnvironment.load() + tc = environment.get_trial_component(sagemaker_session) + + assert trial_component_name == tc.trial_component_name + client.describe_trial_component.assert_called_with(TrialComponentName=trial_component_name) + client.list_trial_components.assert_called_with(SourceArn="arn:1234abcde") + + +@unittest.mock.patch("sagemaker.experiments._environment.retry_with_backoff") +def test_resolve_trial_component_fails(mock_retry, sagemaker_session, training_job_env): + mock_retry.side_effect = lambda func: retry_with_backoff(func, 2) + client = sagemaker_session.sagemaker_client + client.list_trial_components.side_effect = Exception("Failed test") + environment = _environment._RunEnvironment.load() + assert environment.get_trial_component(sagemaker_session) is None + + +def 
test_resolve_transform_job_trial_component_fail(transform_job_env, sagemaker_session): + environment = _environment._RunEnvironment.load() + assert environment.get_trial_component(sagemaker_session) is None diff --git a/tests/unit/sagemaker/experiments/test_experiment.py b/tests/unit/sagemaker/experiments/test_experiment.py new file mode 100644 index 0000000000..b0ad55c27f --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_experiment.py @@ -0,0 +1,306 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest +import unittest.mock +import datetime + +from unittest.mock import patch + +from sagemaker import Session +from sagemaker.experiments import experiment +from sagemaker.experiments._api_types import TrialSummary + + +@pytest.fixture +def datetime_obj(): + return datetime.datetime(2017, 6, 16, 15, 55, 0) + + +def test_load(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.describe_experiment.return_value = {"Description": "description-value"} + experiment_obj = experiment._Experiment.load( + experiment_name="name-value", sagemaker_session=sagemaker_session + ) + assert experiment_obj.experiment_name == "name-value" + assert experiment_obj.description == "description-value" + + client.describe_experiment.assert_called_with(ExperimentName="name-value") + + +def test_create(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_experiment.return_value = {"Arn": "arn:aws:1234"} + experiment_obj = experiment._Experiment.create( + experiment_name="name-value", sagemaker_session=sagemaker_session + ) + assert experiment_obj.experiment_name == "name-value" + client.create_experiment.assert_called_with(ExperimentName="name-value") + + +def test_create_with_tags(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_experiment.return_value = {"Arn": "arn:aws:1234"} + tags = [{"Key": "foo", "Value": "bar"}] + experiment_obj = experiment._Experiment.create( + experiment_name="name-value", sagemaker_session=sagemaker_session, tags=tags + ) + assert experiment_obj.experiment_name == "name-value" + client.create_experiment.assert_called_with(ExperimentName="name-value", Tags=tags) + + +def test_save(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + client.update_experiment.return_value = {} + obj.save() + client.update_experiment.assert_called_with(ExperimentName="foo", Description="bar") + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + client.delete_experiment.return_value = {} + obj.delete() + client.delete_experiment.assert_called_with(ExperimentName="foo") + + +@patch("sagemaker.experiments.experiment._Experiment.load") +def test_load_or_create_when_exist(mock_load, sagemaker_session): + exp_name = "exp_name" + 
experiment._Experiment._load_or_create( + experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + mock_load.assert_called_once_with(exp_name, sagemaker_session) + + +@patch("sagemaker.experiments.experiment._Experiment.load") +@patch("sagemaker.experiments.experiment._Experiment.create") +def test_load_or_create_when_not_exist(mock_create, mock_load): + sagemaker_session = Session() + client = sagemaker_session.sagemaker_client + exp_name = "exp_name" + not_found_err = client.exceptions.ResourceNotFound( + error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, + operation_name="foo", + ) + mock_load.side_effect = not_found_err + + experiment._Experiment._load_or_create( + experiment_name=exp_name, sagemaker_session=sagemaker_session + ) + + mock_create.assert_called_once_with( + experiment_name=exp_name, + display_name=None, + description=None, + tags=None, + sagemaker_session=sagemaker_session, + ) + + +def test_list_trials_empty(sagemaker_session): + sagemaker_session.sagemaker_client.list_trials.return_value = {"TrialSummaries": []} + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + assert list(experiment_obj.list_trials()) == [] + + +def test_list_trials_single(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + sagemaker_session.sagemaker_client.list_trials.return_value = { + "TrialSummaries": [ + {"Name": "trial-foo", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj} + ] + } + + assert list(experiment_obj.list_trials()) == [ + TrialSummary(name="trial-foo", creation_time=datetime_obj, last_modified_time=datetime_obj) + ] + + +def test_list_trials_two_values(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + sagemaker_session.sagemaker_client.list_trials.return_value = { + "TrialSummaries": [ + {"Name": "trial-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}, + {"Name": "trial-foo-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}, + ] + } + + assert list(experiment_obj.list_trials()) == [ + TrialSummary( + name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + + +def test_next_token(sagemaker_session, datetime_obj): + experiment_obj = experiment._Experiment(sagemaker_session) + client = sagemaker_session.sagemaker_client + client.list_trials.side_effect = [ + { + "TrialSummaries": [ + { + "Name": "trial-foo-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "Name": "trial-foo-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ], + "NextToken": "foo", + }, + { + "TrialSummaries": [ + { + "Name": "trial-foo-3", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + } + ] + }, + ] + + assert list(experiment_obj.list_trials()) == [ + TrialSummary( + name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + name="trial-foo-3", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + + client.list_trials.assert_any_call(**{}) + client.list_trials.assert_any_call(NextToken="foo") + + +def test_list_trials_call_args(sagemaker_session): + client = 
sagemaker_session.sagemaker_client
+    created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
+    created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
+    experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
+    client.list_trials.return_value = {}
+    assert [] == list(
+        experiment_obj.list_trials(created_after=created_after, created_before=created_before)
+    )
+    client.list_trials.assert_called_with(CreatedBefore=created_before, CreatedAfter=created_after)
+
+
+def test_delete_all_with_incorrect_action_name(sagemaker_session):
+    obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+    with pytest.raises(ValueError) as err:
+        obj._delete_all(action="abc")
+
+    assert "Must confirm with string '--force'" in str(err)
+
+
+def test_delete_all(sagemaker_session, datetime_obj):
+    obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+    client = sagemaker_session.sagemaker_client
+    client.list_trials.return_value = {
+        "TrialSummaries": [
+            {
+                "TrialName": "trial-1",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+            {
+                "TrialName": "trial-2",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+        ]
+    }
+    client.describe_trial.side_effect = [
+        {"TrialName": "trial-1", "ExperimentName": "experiment-name-value"},
+        {"TrialName": "trial-2", "ExperimentName": "experiment-name-value"},
+    ]
+    client.list_trial_components.side_effect = [
+        {
+            "TrialComponentSummaries": [
+                {
+                    "TrialComponentName": "trial-component-1",
+                    "CreationTime": datetime_obj,
+                    "LastModifiedTime": datetime_obj,
+                },
+                {
+                    "TrialComponentName": "trial-component-2",
+                    "CreationTime": datetime_obj,
+                    "LastModifiedTime": datetime_obj,
+                },
+            ]
+        },
+        {
+            "TrialComponentSummaries": [
+                {
+                    "TrialComponentName": "trial-component-3",
+                    "CreationTime": datetime_obj,
+                    "LastModifiedTime": datetime_obj,
+                },
+                {
+                    "TrialComponentName": "trial-component-4",
+                    "CreationTime": datetime_obj,
+                    "LastModifiedTime": datetime_obj,
+                },
+            ]
+        },
+    ]
+
+    client.describe_trial_component.side_effect = [
+        {"TrialComponentName": "trial-component-1"},
+        {"TrialComponentName": "trial-component-2"},
+        {"TrialComponentName": "trial-component-3"},
+        {"TrialComponentName": "trial-component-4"},
+    ]
+
+    client.delete_trial_component.return_value = {}
+    client.delete_trial.return_value = {}
+    client.delete_experiment.return_value = {}
+
+    obj._delete_all(action="--force")
+
+    client.delete_experiment.assert_called_with(ExperimentName="foo")
+
+    delete_trial_expected_calls = [
+        unittest.mock.call(TrialName="trial-1"),
+        unittest.mock.call(TrialName="trial-2"),
+    ]
+    assert delete_trial_expected_calls == client.delete_trial.mock_calls
+
+    delete_trial_component_expected_calls = [
+        unittest.mock.call(TrialComponentName="trial-component-1"),
+        unittest.mock.call(TrialComponentName="trial-component-2"),
+        unittest.mock.call(TrialComponentName="trial-component-3"),
+        unittest.mock.call(TrialComponentName="trial-component-4"),
+    ]
+    assert delete_trial_component_expected_calls == client.delete_trial_component.mock_calls
+
+
+def test_delete_all_fail(sagemaker_session):
+    obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+    sagemaker_session.sagemaker_client.list_trials.side_effect = Exception
+    with pytest.raises(Exception) as e:
+        obj._delete_all(action="--force")
+
+    assert str(e.value) == "Failed to delete, please try again."
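A note on the get-or-create contract pinned down by test_load_or_create_when_exist and test_load_or_create_when_not_exist above: _Experiment._load_or_create attempts a load first and falls back to create only when the load raises the service's ResourceNotFound error. A minimal sketch of that pattern, assuming a boto3 SageMaker client (the helper name below is illustrative, not the SDK's implementation):

    def load_or_create_experiment(client, experiment_name):
        try:
            # Attempt to load the existing experiment first.
            return client.describe_experiment(ExperimentName=experiment_name)
        except client.exceptions.ResourceNotFound:
            # Create only when the load fails because no such experiment exists.
            client.create_experiment(ExperimentName=experiment_name)
            return client.describe_experiment(ExperimentName=experiment_name)

This is also why test_load_or_create_when_not_exist instantiates a real Session and builds the error from client.exceptions.ResourceNotFound: the exception class raised by the mocked load must match the one the fallback path catches.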
diff --git a/tests/unit/sagemaker/experiments/test_helper.py b/tests/unit/sagemaker/experiments/test_helper.py
new file mode 100644
index 0000000000..a11f67389b
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_helper.py
@@ -0,0 +1,195 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import json
+import os
+import shutil
+import tempfile
+
+from mock import Mock, PropertyMock, call
+import pytest
+
+from sagemaker.experiments._helper import (
+    _LineageArtifactTracker,
+    _ArtifactUploader,
+)
+from sagemaker.experiments._utils import resolve_artifact_name
+from sagemaker.session import Session
+
+
+@pytest.fixture
+def client():
+    """Mock client.
+
+    Considerations when appropriate:
+
+        * utilize botocore.stub.Stubber
+        * separate runtime client from client
+    """
+    client_mock = Mock()
+    client_mock._client_config.user_agent = (
+        "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
+    )
+    return client_mock
+
+
+@pytest.fixture
+def boto_session(client):
+    role_mock = Mock()
+    type(role_mock).arn = PropertyMock(return_value="DummyRole")
+
+    resource_mock = Mock()
+    resource_mock.Role.return_value = role_mock
+
+    session_mock = Mock(region_name="us-west-2")
+    session_mock.resource.return_value = resource_mock
+    session_mock.client.return_value = client
+
+    return session_mock
+
+
+@pytest.fixture
+def sagemaker_session(client, boto_session):
+    return Session(
+        sagemaker_client=client,
+        boto_session=boto_session,
+    )
+
+
+@pytest.fixture
+def lineage_artifact_tracker(sagemaker_session):
+    return _LineageArtifactTracker("test_trial_component_arn", sagemaker_session)
+
+
+def test_lineage_artifact_tracker(lineage_artifact_tracker, sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    lineage_artifact_tracker.add_input_artifact(
+        "input_name", "input_source_uri", "input_etag", "text/plain"
+    )
+    lineage_artifact_tracker.add_output_artifact(
+        "output_name", "output_source_uri", "output_etag", "text/plain"
+    )
+    client.create_artifact.side_effect = [
+        {"ArtifactArn": "created_arn_1"},
+        {"ArtifactArn": "created_arn_2"},
+    ]
+
+    lineage_artifact_tracker.save()
+
+    expected_calls = [
+        call(
+            ArtifactName="input_name",
+            ArtifactType="text/plain",
+            Source={
+                "SourceUri": "input_source_uri",
+                "SourceTypes": [{"SourceIdType": "S3ETag", "Value": "input_etag"}],
+            },
+        ),
+        call(
+            ArtifactName="output_name",
+            ArtifactType="text/plain",
+            Source={
+                "SourceUri": "output_source_uri",
+                "SourceTypes": [{"SourceIdType": "S3ETag", "Value": "output_etag"}],
+            },
+        ),
+    ]
+    assert expected_calls == client.create_artifact.mock_calls
+
+    expected_calls = [
+        call(
+            SourceArn="created_arn_1",
+            DestinationArn="test_trial_component_arn",
+            AssociationType="ContributedTo",
+        ),
+        call(
+            SourceArn="test_trial_component_arn",
+            DestinationArn="created_arn_2",
+            AssociationType="Produced",
+        ),
+    ]
+    assert expected_calls == client.add_association.mock_calls
+
+
+@pytest.fixture
+def
artifact_uploader(sagemaker_session): + return _ArtifactUploader( + trial_component_name="trial_component_name", + artifact_bucket="artifact_bucket", + artifact_prefix="artifact_prefix", + sagemaker_session=sagemaker_session, + ) + + +@pytest.fixture +def tempdir(): + tmp_dir = tempfile.mkdtemp() + yield tmp_dir + shutil.rmtree(tmp_dir) + + +def test_artifact_uploader_init(artifact_uploader): + assert "trial_component_name" == artifact_uploader.trial_component_name + assert "artifact_bucket" == artifact_uploader.artifact_bucket + assert "artifact_prefix" == artifact_uploader.artifact_prefix + + +def test_artifact_uploader_upload_artifact_file_not_exists(tempdir, artifact_uploader): + not_exist_file = os.path.join(tempdir, "not.exists") + with pytest.raises(ValueError) as error: + artifact_uploader.upload_artifact(not_exist_file) + assert "does not exist or is not a file" in str(error) + + +def test_artifact_uploader_upload_artifact(tempdir, artifact_uploader): + path = os.path.join(tempdir, "exists") + with open(path, "a") as f: + f.write("boo") + + name = resolve_artifact_name(path) + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + + s3_uri, etag = artifact_uploader.upload_artifact(path) + expected_key = "{}/{}/{}".format( + artifact_uploader.artifact_prefix, artifact_uploader.trial_component_name, name + ) + + artifact_uploader._s3_client.upload_file.assert_called_with( + path, artifact_uploader.artifact_bucket, expected_key + ) + + expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key) + assert expected_uri == s3_uri + + +def test_artifact_uploader_upload_object_artifact(tempdir, artifact_uploader): + artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"} + + artifact_name = "my-artifact" + artifact_object = {"key": "value"} + file_extension = ".csv" + s3_uri, etag = artifact_uploader.upload_object_artifact( + artifact_name, artifact_object, file_extension + ) + name = artifact_name + file_extension + expected_key = "{}/{}/{}".format( + artifact_uploader.artifact_prefix, artifact_uploader.trial_component_name, name + ) + + artifact_uploader._s3_client.put_object.assert_called_with( + Body=json.dumps(artifact_object), Bucket=artifact_uploader.artifact_bucket, Key=expected_key + ) + + expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key) + assert expected_uri == s3_uri diff --git a/tests/unit/sagemaker/experiments/test_metrics.py b/tests/unit/sagemaker/experiments/test_metrics.py new file mode 100644 index 0000000000..21556f70fd --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_metrics.py @@ -0,0 +1,178 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
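+# Unit tests for sagemaker.experiments._metrics: _RawMetricData timestamp
+# normalization (UTC, tz-aware, naive, and numeric inputs) and the JSON-lines
+# file format written by _SageMakerFileMetricsWriter.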
+from __future__ import absolute_import
+
+import os
+import pytest
+import tempfile
+import shutil
+import datetime
+import dateutil.tz
+import json
+import time
+
+from sagemaker.experiments._metrics import (
+    _RawMetricData,
+    _SageMakerFileMetricsWriter,
+    SageMakerMetricsWriterException,
+)
+
+
+@pytest.fixture
+def tempdir():
+    dir = tempfile.mkdtemp()
+    yield dir
+    shutil.rmtree(dir)
+
+
+@pytest.fixture
+def filepath(tempdir):
+    return os.path.join(tempdir, "foo.json")
+
+
+@pytest.fixture
+def timestamp():
+    return datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)
+
+
+def test_raw_metric_data_utc_timestamp():
+    utcnow = datetime.datetime.now(datetime.timezone.utc)
+    assert utcnow.tzinfo
+    metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=utcnow)
+    assert utcnow.timestamp() == metric.Timestamp
+
+
+def test_raw_metric_data_aware_timestamp():
+    aware_datetime = datetime.datetime.now(dateutil.tz.gettz("America/Chicago"))
+    assert aware_datetime.tzinfo
+    metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=aware_datetime)
+    assert (aware_datetime - aware_datetime.utcoffset()).replace(
+        tzinfo=datetime.timezone.utc
+    ).timestamp() == metric.Timestamp
+
+
+def test_raw_metric_data_naive_timestamp():
+    naive_datetime = datetime.datetime.now()
+    assert naive_datetime.tzinfo is None
+    metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=naive_datetime)
+    local_datetime = naive_datetime.replace(tzinfo=dateutil.tz.tzlocal())
+    assert (local_datetime - local_datetime.utcoffset()).replace(
+        tzinfo=datetime.timezone.utc
+    ).timestamp() == metric.Timestamp
+
+
+def test_raw_metric_data_number_timestamp():
+    time_now = time.time()
+    metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=time_now)
+    assert time_now == metric.Timestamp
+
+
+def test_raw_metric_data_request_item():
+    time_now = time.time()
+    metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=time_now, step=10)
+    expected = {
+        "MetricName": "foo",
+        "Value": 1.0,
+        "Timestamp": str(int(time_now)),
+        "Step": 10,
+    }
+    assert expected == metric.to_raw_metric_data()
+
+
+def test_raw_metric_data_invalid_timestamp():
+    with pytest.raises(ValueError) as error1:
+        _RawMetricData(metric_name="IFail", value=100, timestamp=time.time() - 2000000)
+    assert "Timestamps must be between two weeks before and two hours from now" in str(error1)
+
+    with pytest.raises(ValueError) as error2:
+        _RawMetricData(metric_name="IFail", value=100, timestamp=time.time() + 10000)
+    assert "Timestamps must be between two weeks before and two hours from now" in str(error2)
+
+
+def test_file_metrics_writer_log_metric(timestamp, filepath):
+    now = datetime.datetime.now(datetime.timezone.utc)
+    writer = _SageMakerFileMetricsWriter(filepath)
+    writer.log_metric(metric_name="foo", value=1.0)
+    writer.log_metric(metric_name="foo", value=2.0, step=1)
+    writer.log_metric(metric_name="foo", value=3.0, timestamp=timestamp)
+    writer.log_metric(metric_name="foo", value=4.0, timestamp=timestamp, step=2)
+    writer.close()
+
+    lines = [x for x in open(filepath).read().split("\n") if x]
+    [entry_one, entry_two, entry_three, entry_four] = [json.loads(line) for line in lines]
+
+    assert "foo" == entry_one["MetricName"]
+    assert 1.0 == entry_one["Value"]
+    assert
(now.timestamp() - entry_one["Timestamp"]) < 1 + assert "Step" not in entry_one + + assert 1 == entry_two["Step"] + assert timestamp.timestamp() == entry_three["Timestamp"] + assert 2 == entry_four["Step"] + + +def test_file_metrics_writer_flushes_buffer_every_line_log_metric(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + + writer.log_metric(metric_name="foo", value=1.0) + + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one] = [json.loads(line) for line in lines] + assert "foo" == entry_one["MetricName"] + assert 1.0 == entry_one["Value"] + + writer.log_metric(metric_name="bar", value=2.0) + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one, entry_two] = [json.loads(line) for line in lines] + assert "bar" == entry_two["MetricName"] + assert 2.0 == entry_two["Value"] + + writer.log_metric(metric_name="biz", value=3.0) + lines = [x for x in open(filepath).read().split("\n") if x] + [entry_one, entry_two, entry_three] = [json.loads(line) for line in lines] + assert "biz" == entry_three["MetricName"] + assert 3.0 == entry_three["Value"] + + writer.close() + + +def test_file_metrics_writer_context_manager(timestamp, filepath): + with _SageMakerFileMetricsWriter(filepath) as writer: + writer.log_metric("foo", value=1.0, timestamp=timestamp) + entry = json.loads(open(filepath, "r").read().strip()) + assert { + "MetricName": "foo", + "Value": 1.0, + "Timestamp": timestamp.timestamp(), + }.items() <= entry.items() + + +def test_file_metrics_writer_fail_write_on_close(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + writer.log_metric(metric_name="foo", value=1.0) + writer.close() + with pytest.raises(SageMakerMetricsWriterException): + writer.log_metric(metric_name="foo", value=1.0) + + +def test_file_metrics_writer_no_write(filepath): + writer = _SageMakerFileMetricsWriter(filepath) + writer.close() + assert not os.path.exists(filepath) diff --git a/tests/unit/sagemaker/experiments/test_run.py b/tests/unit/sagemaker/experiments/test_run.py new file mode 100644 index 0000000000..0e4ebee181 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_run.py @@ -0,0 +1,941 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
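+# Unit tests for sagemaker.experiments.Run and load_run: context-manager
+# semantics, name-length validation, loading inside SageMaker training,
+# processing, and transform jobs, and the run logging APIs.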
+from __future__ import absolute_import
+
+import datetime
+import unittest
+from math import inf, nan
+from unittest.mock import patch, Mock, MagicMock
+
+import dateutil
+import pytest
+
+from sagemaker.experiments import _environment, SortOrderType
+from sagemaker.experiments._api_types import (
+    TrialComponentArtifact,
+    TrialComponentSummary,
+    TrialComponentStatus,
+    _TrialComponentStatusType,
+    TrialComponentSearchResult,
+)
+from sagemaker.experiments.experiment import _Experiment
+from sagemaker.experiments.run import (
+    TRIAL_NAME_TEMPLATE,
+    MAX_RUN_TC_ARTIFACTS_LEN,
+    MAX_NAME_LEN_IN_BACKEND,
+    EXPERIMENT_NAME,
+    RUN_NAME,
+    TRIAL_NAME,
+    DELIMITER,
+    RUN_TC_TAG,
+    SortByType,
+)
+from sagemaker.experiments import Run, load_run, list_runs
+from sagemaker.experiments.trial import _Trial
+from sagemaker.experiments.trial_component import _TrialComponent
+from tests.unit.sagemaker.experiments.helpers import (
+    mock_trial_load_or_create_func,
+    mock_tc_load_or_create_func,
+    TEST_EXP_NAME,
+    TEST_RUN_NAME,
+)
+
+
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    "sagemaker.experiments.run._Trial._load_or_create",
+    MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._TrialComponent._load_or_create",
+    MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch.object(_TrialComponent, "save")
+def test_run_init(mock_tc_save, sagemaker_session):
+    with Run(
+        experiment_name=TEST_EXP_NAME, run_name=TEST_RUN_NAME, sagemaker_session=sagemaker_session
+    ) as run_obj:
+        assert not run_obj._in_load
+        assert not run_obj._inside_load_context
+        assert run_obj._inside_init_context
+        assert not run_obj._trial_component.parameters
+
+        expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+        assert run_obj.experiment_name == TEST_EXP_NAME
+        assert run_obj.run_name == TEST_RUN_NAME
+        assert run_obj.run_group_name == TRIAL_NAME_TEMPLATE.format(TEST_EXP_NAME)
+        assert run_obj._trial_component.trial_component_name == expected_tc_name
+        assert run_obj._trial.trial_name == TRIAL_NAME_TEMPLATE.format(TEST_EXP_NAME)
+        assert run_obj._experiment.experiment_name == TEST_EXP_NAME
+        assert run_obj.experiment_config == {
+            EXPERIMENT_NAME: TEST_EXP_NAME,
+            TRIAL_NAME: run_obj.run_group_name,
+            RUN_NAME: expected_tc_name,
+        }
+
+    # _TrialComponent.save is called when entering/exiting the with block
+    mock_tc_save.assert_called()
+
+
+def test_run_init_name_length_exceed_limit(sagemaker_session):
+    invalid_name = "x" * MAX_NAME_LEN_IN_BACKEND
+
+    # experiment_name exceeds
+    with pytest.raises(ValueError) as err:
+        Run(
+            experiment_name=invalid_name,
+            run_name=TEST_RUN_NAME,
+            sagemaker_session=sagemaker_session,
+        )
+
+    assert (
+        f"The experiment_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than"
+        in str(err)
+    )
+
+    # run_name exceeds
+    with pytest.raises(ValueError) as err:
+        Run(
+            experiment_name=TEST_EXP_NAME,
+            run_name=invalid_name,
+            sagemaker_session=sagemaker_session,
+        )
+
+    assert f"The run_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than" in str(
+        err
+    )
+
+
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    
"sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), +) +@patch("sagemaker.experiments.run._RunEnvironment") +def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session): + client = sagemaker_session.sagemaker_client + job_name = "my-train-job" + rv = Mock() + rv.source_arn = f"arn:1234/{job_name}" + rv.environment_type = _environment._EnvironmentType.SageMakerTrainingJob + mock_run_env.load.return_value = rv + + expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}" + exp_config = { + EXPERIMENT_NAME: TEST_EXP_NAME, + TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME), + RUN_NAME: expected_tc_name, + } + client.describe_training_job.return_value = { + "TrainingJobName": "train-job-experiments", + # The Run object has been created else where + "ExperimentConfig": exp_config, + } + with load_run(sagemaker_session=sagemaker_session) as run_obj: + assert run_obj._in_load + assert not run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj.run_name == TEST_RUN_NAME + assert run_obj._trial_component.trial_component_name == expected_tc_name + assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME) + assert run_obj._trial + assert run_obj.experiment_name == TEST_EXP_NAME + assert run_obj._experiment + assert run_obj.experiment_config == exp_config + + client.describe_training_job.assert_called_once_with(TrainingJobName=job_name) + + +@patch("sagemaker.experiments.run._RunEnvironment") +def test_run_load_no_run_name_and_in_train_job_but_fail_to_get_exp_cfg( + mock_run_env, sagemaker_session +): + rv = Mock() + rv.source_arn = "arn:1234/my-train-job" + rv.environment_type = _environment._EnvironmentType.SageMakerTrainingJob + mock_run_env.load.return_value = rv + + # No Run object is created else where + sagemaker_session.sagemaker_client.describe_training_job.return_value = { + "TrainingJobName": "train-job-experiments", + } + + with pytest.raises(RuntimeError) as err: + with load_run(sagemaker_session=sagemaker_session): + pass + + assert "Not able to fetch RunName in ExperimentConfig of the sagemaker job" in str(err) + + +def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session): + with run_obj: + with load_run(sagemaker_session=sagemaker_session) as run: + assert run_obj == run + + +def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemaker_session): + with pytest.raises(RuntimeError) as err: + with load_run(sagemaker_session=sagemaker_session): + pass + + assert "Failed to load a Run object" in str(err) + + # experiment_name is given but is not supplied along with the run_name so it's ignored. 
+    with pytest.raises(RuntimeError) as err:
+        with load_run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session):
+            pass
+
+    assert "Failed to load a Run object" in str(err)
+
+
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    "sagemaker.experiments.run._Trial._load_or_create",
+    MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._TrialComponent._load_or_create",
+    MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+def test_run_load_with_run_name_and_exp_name(sagemaker_session):
+    with load_run(
+        run_name=TEST_RUN_NAME,
+        experiment_name=TEST_EXP_NAME,
+        sagemaker_session=sagemaker_session,
+    ) as run_obj:
+        expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+        expected_exp_config = {
+            EXPERIMENT_NAME: TEST_EXP_NAME,
+            TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+            RUN_NAME: expected_tc_name,
+        }
+
+        assert run_obj.run_name == TEST_RUN_NAME
+        assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
+        assert run_obj.experiment_name == TEST_EXP_NAME
+        assert run_obj._trial_component.trial_component_name == expected_tc_name
+        assert run_obj._trial
+        assert run_obj._experiment
+        assert run_obj.experiment_config == expected_exp_config
+
+
+def test_run_load_with_run_name_but_no_exp_name(sagemaker_session):
+    with pytest.raises(ValueError) as err:
+        with load_run(
+            run_name=TEST_RUN_NAME,
+            sagemaker_session=sagemaker_session,
+        ):
+            pass
+
+    assert "Invalid input: experiment_name is missing" in str(err)
+
+
+@patch(
+    "sagemaker.experiments.run._Experiment._load_or_create",
+    MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+    "sagemaker.experiments.run._Trial._load_or_create",
+    MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+    "sagemaker.experiments.run._TrialComponent._load_or_create",
+    MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    job_name = "my-process-job"
+    rv = unittest.mock.Mock()
+    rv.source_arn = f"arn:1234/{job_name}"
+    rv.environment_type = _environment._EnvironmentType.SageMakerProcessingJob
+    mock_run_env.load.return_value = rv
+
+    expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+    exp_config = {
+        EXPERIMENT_NAME: TEST_EXP_NAME,
+        TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+        RUN_NAME: expected_tc_name,
+    }
+    client.describe_processing_job.return_value = {
+        "ProcessingJobName": "process-job-experiments",
+        # The Run object has been created elsewhere
+        "ExperimentConfig": exp_config,
+    }
+
+    with load_run(sagemaker_session=sagemaker_session):
+        pass
+
+    client.describe_processing_job.assert_called_once_with(ProcessingJobName=job_name)
+
+
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
+    # TODO: update this test once we figure out how to get source_arn from the transform job
+    rv = unittest.mock.Mock()
+    rv.environment_type =
_environment._EnvironmentType.SageMakerTransformJob + rv.source_arn = "" + mock_run_env.load.return_value = rv + + with pytest.raises(RuntimeError) as err: + with load_run(sagemaker_session=sagemaker_session): + pass + + assert ( + "loading experiment config from transform job environment is not currently supported" + ) in str(err) + + +def test_log_parameter_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_parameter("foo", "bar") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_parameter(run_obj): + with run_obj: + run_obj.log_parameter("foo", "bar") + assert run_obj._trial_component.parameters["foo"] == "bar" + run_obj.log_parameter("whizz", 1) + assert run_obj._trial_component.parameters["whizz"] == 1 + + +def test_log_parameter_skip_invalid_value(run_obj): + with run_obj: + run_obj.log_parameter("key", nan) + assert "key" not in run_obj._trial_component.parameters + + +def test_log_parameters_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5}) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_parameters(run_obj): + with run_obj: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5}) + assert run_obj._trial_component.parameters == {"a": "b", "c": "d", "e": 5} + + +def test_log_parameters_skip_invalid_values(run_obj): + with run_obj: + run_obj.log_parameters({"a": "b", "c": "d", "e": 5, "f": nan}) + assert run_obj._trial_component.parameters == {"a": "b", "c": "d", "e": 5} + + +def test_log_input_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_artifact("foo", "baz", "text/text", False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_input(run_obj): + with run_obj: + run_obj.log_artifact("foo", "baz", "text/text", False) + assert run_obj._trial_component.input_artifacts == { + "foo": TrialComponentArtifact(value="baz", media_type="text/text") + } + + +def test_log_output_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_artifact("foo", "baz", "text/text") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_output(run_obj): + with run_obj: + run_obj.log_artifact("foo", "baz", "text/text") + assert run_obj._trial_component.output_artifacts == { + "foo": TrialComponentArtifact(value="baz", media_type="text/text") + } + + +def test_log_metric_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_metric(name="foo", value=1.0, step=1) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_metric(run_obj): + now = datetime.datetime.now() + with run_obj: + run_obj.log_metric(name="foo", value=1.0, step=1, timestamp=now) + run_obj._metrics_manager.log_metric.assert_called_with( + metric_name="foo", value=1.0, step=1, timestamp=now + ) + + +def test_log_metric_skip_invalid_value(run_obj): + with run_obj: + run_obj.log_metric(None, nan, None, None) + assert not run_obj._metrics_manager.log_metric.called + + +def test_log_metric_attribute_error(run_obj): + now = datetime.datetime.now() + with run_obj: + run_obj._metrics_manager.log_metric.side_effect = AttributeError + + with pytest.raises(AttributeError): + run_obj.log_metric("foo", 1.0, 1, now) + + +def test_log_output_artifact_outside_run_context(run_obj): + with 
pytest.raises(RuntimeError) as err: + run_obj.log_file("foo.txt", "name", "whizz/bang") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_output_artifact(run_obj): + run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value") + with run_obj: + run_obj.log_file("foo.txt", "name", "whizz/bang") + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "whizz/bang" == run_obj._trial_component.output_artifacts["name"].media_type + + run_obj.log_file("foo.txt") + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "foo.txt" in run_obj._trial_component.output_artifacts + assert "text/plain" == run_obj._trial_component.output_artifacts["foo.txt"].media_type + + +def test_log_input_artifact_outside_run_context(run_obj): + with pytest.raises(RuntimeError) as err: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_input_artifact(run_obj): + run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value") + with run_obj: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "whizz/bang" == run_obj._trial_component.input_artifacts["name"].media_type + + run_obj.log_file("foo.txt", is_output=False) + run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt") + assert "foo.txt" in run_obj._trial_component.input_artifacts + assert "text/plain" == run_obj._trial_component.input_artifacts["foo.txt"].media_type + + +def test_log_multiple_inputs(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._trial_component.input_artifacts[file_path] = { + "foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text") + } + with pytest.raises(ValueError) as error: + run_obj.log_artifact("foo.txt", "name", "whizz/bang", False) + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error) + + +def test_log_multiple_outputs(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._trial_component.output_artifacts[file_path] = { + "foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text") + } + with pytest.raises(ValueError) as error: + run_obj.log_artifact("foo.txt", "name", "whizz/bang") + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error) + + +def test_log_multiple_input_artifacts(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value" + str(index), + "etag_value" + str(index), + ) + run_obj.log_file( + file_path, "name" + str(index), "whizz/bang" + str(index), is_output=False + ) + run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path) + + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + # log an output artifact, should be fine + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=True) + + # log an extra input artifact, should raise exception + with pytest.raises(ValueError) as error: + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + assert 
f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error) + + +def test_log_multiple_output_artifacts(run_obj): + with run_obj: + for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN): + file_path = "foo" + str(index) + ".txt" + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value" + str(index), + "etag_value" + str(index), + ) + run_obj.log_file(file_path, "name" + str(index), "whizz/bang" + str(index)) + run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path) + + run_obj._artifact_uploader.upload_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + # log an input artifact, should be fine + run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False) + + # log an extra output artifact, should raise exception + with pytest.raises(ValueError) as error: + run_obj.log_file("foo.txt", "name", "whizz/bang") + assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error) + + +def test_log_precision_recall_outside_run_context(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + no_skill = 0.1 + title = "TestPrecisionRecall" + + with pytest.raises(RuntimeError) as err: + run_obj.log_precision_recall( + y_true, y_scores, 0, title=title, no_skill=no_skill, is_output=False + ) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_precision_recall(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + no_skill = 0.1 + title = "TestPrecisionRecall" + + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + with run_obj: + run_obj.log_precision_recall( + y_true, y_scores, 0, title=title, no_skill=no_skill, is_output=False + ) + + expected_data = { + "type": "PrecisionRecallCurve", + "version": 0, + "title": title, + "precision": [0.5, 0.3333333333333333, 0.5, 0.0, 1.0], + "recall": [1.0, 0.5, 0.5, 0.0, 0.0], + "averagePrecisionScore": 0.5, + "noSkill": 0.1, + } + run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + title, expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_input_artifact.assert_called_with( + name=title, + source_uri="s3uri_value", + etag="etag_value", + artifact_type="PrecisionRecallCurve", + ) + + +def test_log_precision_recall_invalid_input(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35] + no_skill = 0.1 + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_precision_recall( + y_true, y_scores, 0, title="TestPrecisionRecall", no_skill=no_skill, is_output=False + ) + assert "Lengths mismatch between true labels and predicted probabilities" in str(error) + + +def test_log_confusion_matrix_outside_run_context(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0, 2] + + with pytest.raises(RuntimeError) as err: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_confusion_matrix(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0, 2] + + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + with run_obj: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + + expected_data = { + "type": "ConfusionMatrix", + "version": 0, + "title": "TestConfusionMatrix", + "confusionMatrix": [[2, 0, 0], [0, 0, 1], [1, 0, 2]], + } + + 
run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + "TestConfusionMatrix", expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_output_artifact.assert_called_with( + name="TestConfusionMatrix", + source_uri="s3uri_value", + etag="etag_value", + artifact_type="ConfusionMatrix", + ) + + +def test_log_confusion_matrix_invalid_input(run_obj): + y_true = [2, 0, 2, 2, 0, 1] + y_pred = [0, 0, 2, 2, 0] + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix") + assert "Lengths mismatch between true labels and predicted labels" in str(error) + + +def test_log_roc_curve_outside_run_context(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + + with pytest.raises(RuntimeError) as err: + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + assert "This method should be called inside context of 'with' statement" in str(err) + + +def test_log_roc_curve(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35, 0.8] + with run_obj: + run_obj._artifact_uploader.upload_object_artifact.return_value = ( + "s3uri_value", + "etag_value", + ) + + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + + expected_data = { + "type": "ROCCurve", + "version": 0, + "title": "TestROCCurve", + "falsePositiveRate": [0.0, 0.0, 0.5, 0.5, 1.0], + "truePositiveRate": [0.0, 0.5, 0.5, 1.0, 1.0], + "areaUnderCurve": 0.75, + } + run_obj._artifact_uploader.upload_object_artifact.assert_called_with( + "TestROCCurve", expected_data, file_extension="json" + ) + + run_obj._lineage_artifact_tracker.add_input_artifact.assert_called_with( + name="TestROCCurve", + source_uri="s3uri_value", + etag="etag_value", + artifact_type="ROCCurve", + ) + + +def test_log_roc_curve_invalid_input(run_obj): + y_true = [0, 0, 1, 1] + y_scores = [0.1, 0.4, 0.35] + + with run_obj: + with pytest.raises(ValueError) as error: + run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False) + assert "Lengths mismatch between true labels and predicted scores" in str(error) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch("sagemaker.experiments.run._TrialComponent._load_or_create") +@patch("sagemaker.experiments.run._TrialComponent.list") +@patch("sagemaker.experiments.run._TrialComponent.search") +def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_session): + start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) + end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2) + creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3) + last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4) + tc_list_len = 20 + tc_list_len_half = int(tc_list_len / 2) + mock_tc_search.side_effect = [ + [ + TrialComponentSearchResult( + trial_component_name=Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + 
datetime.timedelta(hours=i), + last_modified_by={}, + tags=[RUN_TC_TAG] if i < tc_list_len_half else None, + ) + ] + for i in range(tc_list_len) + ] + mock_tc_list.return_value = [ + TrialComponentSummary( + trial_component_name=Run._generate_trial_component_name("A" + str(i), TEST_EXP_NAME), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + datetime.timedelta(hours=i), + last_modified_by={}, + ) + for i in range(tc_list_len) + ] + mock_tc_load.side_effect = [ + ( + _TrialComponent( + trial_component_name=Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ), + trial_component_arn="b" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + datetime.timedelta(hours=i), + last_modified_by={}, + ), + True, + ) + for i in range(tc_list_len_half) + ] + + run_list = list_runs( + experiment_name=TEST_EXP_NAME, + sort_by=SortByType.CREATION_TIME, + sort_order=SortOrderType.ASCENDING, + sagemaker_session=sagemaker_session, + ) + + mock_tc_list.assert_called_once_with( + experiment_name=TEST_EXP_NAME, + created_before=None, + created_after=None, + sort_by="CreationTime", + sort_order="Ascending", + sagemaker_session=sagemaker_session, + max_results=None, + next_token=None, + ) + assert len(run_list) == tc_list_len_half + for i in range(tc_list_len_half): + run = run_list[i] + assert run.experiment_name == TEST_EXP_NAME + assert run.run_name == "a" + str(i) + assert run._experiment + assert run._trial + assert isinstance(run._trial_component, _TrialComponent) + assert run._trial_component.trial_component_name == Run._generate_trial_component_name( + "a" + str(i), TEST_EXP_NAME + ) + assert run._in_load is False + assert run._inside_load_context is False + assert run._inside_init_context is False + assert run._artifact_uploader + assert run._lineage_artifact_tracker + assert run._metrics_manager + + +@patch("sagemaker.experiments.run._TrialComponent.list") +def test_list_empty(mock_tc_list, sagemaker_session): + mock_tc_list.return_value = [] + assert [] == list_runs(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch("sagemaker.experiments.run._TrialComponent._load_or_create") +def test_enter_exit_locally(mock_load_tc, sagemaker_session, run_obj): + mock_load_tc.return_value = run_obj._trial_component, False + sagemaker_session.sagemaker_client.update_trial_component.return_value = {} + _verify_tc_status_before_enter_init(run_obj._trial_component) + + with run_obj: + _verify_tc_status_when_entering(run_obj._trial_component) + 
init_start_time = run_obj._trial_component.start_time + + with load_run(sagemaker_session=sagemaker_session): + _verify_tc_status_when_entering( + trial_component=run_obj._trial_component, + init_start_time=init_start_time, + ) + + old_end_time = _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, + ) + + old_end_time = _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, + old_end_time=old_end_time, + ) + + # Re-load to verify: + # 1. if it works when load_run and with are not in one line + # 2. if re-entering the load will change the "Completed" TC status + # to "InProgress" + # 3. when exiting the load, the end_time and status will be overridden again + run_load = load_run( + experiment_name=run_obj.experiment_name, + run_name=run_obj.run_name, + sagemaker_session=sagemaker_session, + ) + with run_load: + _verify_tc_status_when_entering( + trial_component=run_obj._trial_component, + init_start_time=init_start_time, + has_completed=True, + ) + _verify_tc_status_when_successfully_exit( + trial_component=run_obj._trial_component, old_end_time=old_end_time + ) + + +def test_exit_fail(sagemaker_session, run_obj): + sagemaker_session.sagemaker_client.update_trial_component.return_value = {} + try: + with run_obj: + raise ValueError("Foo") + except ValueError: + pass + + assert run_obj._trial_component.status.primary_status == _TrialComponentStatusType.Failed.value + assert run_obj._trial_component.status.message + assert isinstance(run_obj._trial_component.end_time, datetime.datetime) + + +@pytest.mark.parametrize( + "metric_value", + [1.3, "nan", "inf", "-inf", None], +) +def test_is_input_valid(run_obj, metric_value): + assert run_obj._is_input_valid("metric", "Name", metric_value) + + +@pytest.mark.parametrize( + "metric_value", + [nan, inf, -inf], +) +def test_is_input_valid_false(run_obj, metric_value): + assert not run_obj._is_input_valid("parameter", "Name", metric_value) + + +def test_generate_trial_name(): + base_name = "x" * MAX_NAME_LEN_IN_BACKEND + trial_name = Run._generate_trial_name(base_name=base_name) + assert len(trial_name) <= MAX_NAME_LEN_IN_BACKEND + + +def test_append_run_tc_label_to_tags(): + expected_tc_tag = RUN_TC_TAG + + tags = None + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 1 + assert expected_tc_tag in ret + + tags = [] + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 1 + assert expected_tc_tag in ret + + tags = [{"Key": "foo", "Value": "bar"}] + ret = Run._append_run_tc_label_to_tags(tags) + assert len(ret) == 2 + assert expected_tc_tag in ret + + +def _verify_tc_status_before_enter_init(trial_component): + assert not trial_component.start_time + assert not trial_component.end_time + assert not trial_component.status + + +def _verify_tc_status_when_entering(trial_component, init_start_time=None, has_completed=False): + if not init_start_time: + assert isinstance(trial_component.start_time, datetime.datetime) + now = datetime.datetime.now(dateutil.tz.tzlocal()) + assert (now.timestamp() - trial_component.start_time.timestamp()) < 1 + else: + assert trial_component.start_time == init_start_time + + if not has_completed: + assert not trial_component.end_time + assert trial_component.status.primary_status == _TrialComponentStatusType.InProgress.value + + +def _verify_tc_status_when_successfully_exit(trial_component, old_end_time=None): + assert trial_component.status.primary_status == _TrialComponentStatusType.Completed.value + assert 
isinstance(trial_component.start_time, datetime.datetime) + assert isinstance(trial_component.end_time, datetime.datetime) + if old_end_time: + assert trial_component.end_time > old_end_time + return trial_component.end_time diff --git a/tests/unit/sagemaker/experiments/test_run_context.py b/tests/unit/sagemaker/experiments/test_run_context.py new file mode 100644 index 0000000000..7e068136a1 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_run_context.py @@ -0,0 +1,191 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest.mock import patch, MagicMock + +import pytest + +from sagemaker.estimator import Estimator, _TrainingJob +from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.run import _RunContext +from sagemaker.experiments import load_run, Run +from sagemaker.experiments.trial import _Trial +from tests.unit.sagemaker.experiments.helpers import ( + TEST_EXP_NAME, + mock_trial_load_or_create_func, + mock_tc_load_or_create_func, +) + +_bucket = "my-bucket" +_train_input_path = f"s3://{_bucket}/data.csv" +_train_output_path = f"s3://{_bucket}" + + +@patch.object(_TrainingJob, "start_new") +def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_session): + mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job") + with run_obj: + estimator = Estimator( + role="arn:my-role", + image_uri="my-image", + sagemaker_session=sagemaker_session, + output_path=_train_output_path, + ) + estimator.fit( + inputs=_train_input_path, + wait=False, + ) + + assert _RunContext.get_current_run() == run_obj + + expected_exp_config = run_obj.experiment_config + mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config) + + # _RunContext is cleaned up after exiting the with statement + assert not _RunContext.get_current_run() + + +@patch.object(_TrainingJob, "start_new") +def test_user_supply_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_session): + mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job") + supplied_exp_cfg = { + "ExperimentName": "my-supplied-exp-name", + "TrialName": "my-supplied-run-group-name", + "RunName": "my-supplied-run-name", + } + with run_obj: + estimator = Estimator( + role="arn:my-role", + image_uri="my-image", + sagemaker_session=sagemaker_session, + output_path=_train_output_path, + ) + estimator.fit( + experiment_config=supplied_exp_cfg, + inputs=_train_input_path, + wait=False, + ) + + assert _RunContext.get_current_run() == run_obj + + mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg) + + # _RunContext is cleaned up after exiting the with statement + assert not _RunContext.get_current_run() + + +def test_auto_fetch_created_run_obj_from_context(run_obj, sagemaker_session): + assert not run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert not _RunContext.get_current_run() + + def 
train(): + with load_run(sagemaker_session=sagemaker_session) as run_load: + assert run_load == run_obj + assert run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj._in_load + + run_load.log_parameter("foo", "bar") + run_load.log_parameter("whizz", 1) + + with run_obj: + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + train() + + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + run_obj.log_parameters({"a": "b", "c": 2}) + + assert run_obj._trial_component.parameters["foo"] == "bar" + assert run_obj._trial_component.parameters["whizz"] == 1 + assert run_obj._trial_component.parameters["a"] == "b" + assert run_obj._trial_component.parameters["c"] == 2 + + # Verify separating load_run and with statement in different lines still work + run_load2 = load_run(sagemaker_session=sagemaker_session) + with run_load2: + assert run_load2 == run_obj + assert run_obj._inside_init_context + assert run_obj._inside_load_context + assert run_obj._in_load + + assert run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert _RunContext.get_current_run() + + assert not run_obj._inside_init_context + assert not run_obj._inside_load_context + assert not run_obj._in_load + assert not _RunContext.get_current_run() + + +def test_nested_run_init_context_on_same_run_object(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with run_obj: + assert _RunContext.get_current_run() + + with run_obj: + pass + assert "It is not allowed to use nested 'with' statements on the Run" in str(err) + + +@patch( + "sagemaker.experiments.run._Experiment._load_or_create", + MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), +) +@patch( + "sagemaker.experiments.run._Trial._load_or_create", + MagicMock(side_effect=mock_trial_load_or_create_func), +) +@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) +@patch( + "sagemaker.experiments.run._TrialComponent._load_or_create", + MagicMock(side_effect=mock_tc_load_or_create_func), +) +def test_nested_run_init_context_on_different_run_object(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with Run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session): + assert _RunContext.get_current_run() + + with run_obj: + pass + assert "It is not allowed to use nested 'with' statements on the Run" in str(err) + + +def test_nested_run_load_context(run_obj, sagemaker_session): + assert not _RunContext.get_current_run() + + with pytest.raises(RuntimeError) as err: + with run_obj: + assert _RunContext.get_current_run() + + with load_run(): + run_load = load_run() + with run_load: + pass + assert "It is not allowed to use nested 'with' statements on the load_run" in str(err) diff --git a/tests/unit/sagemaker/experiments/test_trial.py b/tests/unit/sagemaker/experiments/test_trial.py new file mode 100644 index 0000000000..f6996fefc3 --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_trial.py @@ -0,0 +1,276 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +import datetime + +from unittest.mock import patch + +from sagemaker import Session +from sagemaker.experiments._api_types import TrialSummary +from sagemaker.experiments.trial import _Trial +from sagemaker.experiments.trial_component import _TrialComponent + + +@pytest.fixture +def datetime_obj(): + return datetime.datetime(2017, 6, 16, 15, 55, 0) + + +def test_load(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.describe_trial.return_value = {"ExperimentName": "experiment-name-value"} + trial_obj = _Trial.load(trial_name="name-value", sagemaker_session=sagemaker_session) + assert trial_obj.trial_name == "name-value" + assert trial_obj.experiment_name == "experiment-name-value" + client.describe_trial.assert_called_with(TrialName="name-value") + + +def test_create(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial.return_value = { + "Arn": "arn:aws:1234", + "TrialName": "name-value", + } + trial_obj = _Trial.create( + trial_name="name-value", + experiment_name="experiment-name-value", + sagemaker_session=sagemaker_session, + ) + assert trial_obj.trial_name == "name-value" + client.create_trial.assert_called_with( + TrialName="name-value", ExperimentName="experiment-name-value" + ) + + +def test_create_with_tags(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.create_trial.return_value = { + "Arn": "arn:aws:1234", + "TrialName": "name-value", + } + tags = [{"Key": "foo", "Value": "bar"}] + trial_obj = _Trial.create( + trial_name="name-value", + experiment_name="experiment-name-value", + sagemaker_session=sagemaker_session, + tags=tags, + ) + assert trial_obj.trial_name == "name-value" + client.create_trial.assert_called_with( + TrialName="name-value", + ExperimentName="experiment-name-value", + Tags=[{"Key": "foo", "Value": "bar"}], + ) + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _Trial(sagemaker_session, trial_name="foo") + client.delete_trial.return_value = {} + obj.delete() + client.delete_trial.assert_called_with(TrialName="foo") + + +def test_save(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _Trial( + sagemaker_session, + trial_name="foo", + experiment_name="whizz", + display_name="bar", + tags=[{"Key": "foo", "Value": "bar"}], + ) + client.update_trial.return_value = {} + obj.save() + + client.update_trial.assert_called_with( + TrialName="foo", + DisplayName="bar", + ) + + +def test_add_trial_component(sagemaker_session): + client = sagemaker_session.sagemaker_client + trial = _Trial(sagemaker_session=sagemaker_session) + trial.trial_name = "bar" + trial.add_trial_component("foo") + client.associate_trial_component.assert_called_with(TrialName="bar", TrialComponentName="foo") + + tc = _TrialComponent(trial_component_name="tc-foo", sagemaker_session=sagemaker_session) + trial.add_trial_component(tc) + client.associate_trial_component.assert_called_with( + TrialName="bar", TrialComponentName=tc.trial_component_name + ) + + +def test_remove_trial_component(sagemaker_session): + client = 
sagemaker_session.sagemaker_client
+    trial = _Trial(sagemaker_session=sagemaker_session)
+    trial.trial_name = "bar"
+    trial.remove_trial_component("foo")
+    client.disassociate_trial_component.assert_called_with(
+        TrialName="bar", TrialComponentName="foo"
+    )
+
+    tc = _TrialComponent(trial_component_name="tc-foo", sagemaker_session=sagemaker_session)
+    trial.remove_trial_component(tc)
+    client.disassociate_trial_component.assert_called_with(
+        TrialName="bar", TrialComponentName=tc.trial_component_name
+    )
+
+
+@patch("sagemaker.experiments.trial._Trial.load")
+def test_load_or_create_when_exist(mock_load):
+    sagemaker_session = Session()
+    trial_name = "trial_name"
+    exp_name = "exp_name"
+
+    # The trial exists and experiment matches
+    mock_load.return_value = _Trial(
+        trial_name=trial_name,
+        experiment_name=exp_name,
+        sagemaker_session=sagemaker_session,
+    )
+    _Trial._load_or_create(
+        trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session
+    )
+    mock_load.assert_called_once_with(trial_name, sagemaker_session)
+
+    # The trial exists but experiment does not match
+    mock_load.return_value = _Trial(
+        trial_name=trial_name,
+        experiment_name="another_exp_name",
+        sagemaker_session=sagemaker_session,
+    )
+    with pytest.raises(ValueError) as err:
+        _Trial._load_or_create(
+            trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session
+        )
+    assert "The given experiment_name {} does not match that in the loaded trial".format(
+        exp_name
+    ) in str(err)
+
+
+@patch("sagemaker.experiments.trial._Trial.load")
+@patch("sagemaker.experiments.trial._Trial.create")
+def test_load_or_create_when_not_exist(mock_create, mock_load):
+    sagemaker_session = Session()
+    client = sagemaker_session.sagemaker_client
+    trial_name = "trial_name"
+    exp_name = "exp_name"
+    not_found_err = client.exceptions.ResourceNotFound(
+        error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}},
+        operation_name="foo",
+    )
+    mock_load.side_effect = not_found_err
+
+    _Trial._load_or_create(
+        trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session
+    )
+
+    mock_create.assert_called_once_with(
+        trial_name=trial_name,
+        experiment_name=exp_name,
+        display_name=None,
+        tags=None,
+        sagemaker_session=sagemaker_session,
+    )
+
+
+def test_list_trials_without_experiment_name(sagemaker_session, datetime_obj):
+    client = sagemaker_session.sagemaker_client
+    client.list_trials.return_value = {
+        "TrialSummaries": [
+            {
+                "TrialName": "trial-1",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+            {
+                "TrialName": "trial-2",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+        ]
+    }
+    expected = [
+        TrialSummary(
+            trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj
+        ),
+        TrialSummary(
+            trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj
+        ),
+    ]
+    assert expected == list(_Trial.list(sagemaker_session=sagemaker_session))
+    client.list_trials.assert_called_with(**{})
+
+
+def test_list_trials_with_experiment_name(sagemaker_session, datetime_obj):
+    client = sagemaker_session.sagemaker_client
+    client.list_trials.return_value = {
+        "TrialSummaries": [
+            {
+                "TrialName": "trial-1",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+            {
+                "TrialName": "trial-2",
+                "CreationTime": datetime_obj,
+                "LastModifiedTime": datetime_obj,
+            },
+        ]
+    }
+    expected = [
+        TrialSummary(
+            trial_name="trial-1",
creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + assert expected == list(_Trial.list(experiment_name="foo", sagemaker_session=sagemaker_session)) + client.list_trials.assert_called_with(ExperimentName="foo") + + +def test_list_trials_with_trial_component_name(sagemaker_session, datetime_obj): + client = sagemaker_session.sagemaker_client + client.list_trials.return_value = { + "TrialSummaries": [ + { + "TrialName": "trial-1", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + { + "TrialName": "trial-2", + "CreationTime": datetime_obj, + "LastModifiedTime": datetime_obj, + }, + ] + } + expected = [ + TrialSummary( + trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + TrialSummary( + trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj + ), + ] + assert expected == list( + _Trial.list(trial_component_name="tc-foo", sagemaker_session=sagemaker_session) + ) + client.list_trials.assert_called_with(TrialComponentName="tc-foo") diff --git a/tests/unit/sagemaker/experiments/test_trial_component.py b/tests/unit/sagemaker/experiments/test_trial_component.py new file mode 100644 index 0000000000..c14663893e --- /dev/null +++ b/tests/unit/sagemaker/experiments/test_trial_component.py @@ -0,0 +1,384 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
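+
+# Unit tests for the internal _TrialComponent wrapper. The sagemaker_session
+# fixture used throughout stubs sagemaker_client, so each test only verifies
+# the request shape sent to, and the response parsed from, the mocked client.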
+from __future__ import absolute_import
+
+import datetime
+import unittest.mock
+
+from unittest.mock import patch
+
+from sagemaker import Session
+from sagemaker.experiments import _api_types
+from sagemaker.experiments._api_types import (
+    TrialComponentSearchResult,
+    Parent,
+    _TrialComponentStatusType,
+)
+from sagemaker.experiments.trial_component import _TrialComponent
+
+
+def test_create(sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    client.create_trial_component.return_value = {
+        "TrialComponentArn": "bazz",
+    }
+    obj = _TrialComponent.create(
+        trial_component_name="foo", display_name="bar", sagemaker_session=sagemaker_session
+    )
+    client.create_trial_component.assert_called_with(TrialComponentName="foo", DisplayName="bar")
+    assert "foo" == obj.trial_component_name
+    assert "bar" == obj.display_name
+    assert "bazz" == obj.trial_component_arn
+
+
+def test_create_with_tags(sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    client.create_trial_component.return_value = {
+        "TrialComponentArn": "bazz",
+    }
+    tags = [{"Key": "foo", "Value": "bar"}]
+    _TrialComponent.create(
+        trial_component_name="foo",
+        display_name="bar",
+        sagemaker_session=sagemaker_session,
+        tags=tags,
+    )
+    client.create_trial_component.assert_called_with(
+        TrialComponentName="foo", DisplayName="bar", Tags=tags
+    )
+
+
+def test_load(sagemaker_session):
+    now = datetime.datetime.now(datetime.timezone.utc)
+    client = sagemaker_session.sagemaker_client
+    client.describe_trial_component.return_value = {
+        "TrialComponentArn": "A",
+        "TrialComponentName": "B",
+        "DisplayName": "C",
+        "Status": {"PrimaryStatus": _TrialComponentStatusType.InProgress.value, "Message": "D"},
+        "Parameters": {"E": {"NumberValue": 1.0}, "F": {"StringValue": "G"}},
+        "InputArtifacts": {"H": {"Value": "s3://foo/bar", "MediaType": "text/plain"}},
+        "OutputArtifacts": {"I": {"Value": "s3://whizz/bang", "MediaType": "text/plain"}},
+        "Metrics": [
+            {
+                "MetricName": "J",
+                "Count": 1,
+                "Min": 1.0,
+                "Max": 2.0,
+                "Avg": 3.0,
+                "StdDev": 4.0,
+                "SourceArn": "K",
+                "Timestamp": now,
+            }
+        ],
+    }
+    obj = _TrialComponent.load(trial_component_name="foo", sagemaker_session=sagemaker_session)
+    client.describe_trial_component.assert_called_with(TrialComponentName="foo")
+    assert "A" == obj.trial_component_arn
+    assert "B" == obj.trial_component_name
+    assert "C" == obj.display_name
+    assert (
+        _api_types.TrialComponentStatus(
+            primary_status=_TrialComponentStatusType.InProgress.value, message="D"
+        )
+        == obj.status
+    )
+    assert {"E": 1.0, "F": "G"} == obj.parameters
+    assert {
+        "H": _api_types.TrialComponentArtifact(value="s3://foo/bar", media_type="text/plain")
+    } == obj.input_artifacts
+    assert {
+        "I": _api_types.TrialComponentArtifact(value="s3://whizz/bang", media_type="text/plain")
+    } == obj.output_artifacts
+    assert [
+        _api_types.TrialComponentMetricSummary(
+            metric_name="J",
+            count=1,
+            min=1.0,
+            max=2.0,
+            avg=3.0,
+            std_dev=4.0,
+            source_arn="K",
+            timestamp=now,
+        )
+    ] == obj.metrics
+
+
+def test_save(sagemaker_session):
+    client = sagemaker_session.sagemaker_client
+    obj = _TrialComponent(
+        sagemaker_session,
+        trial_component_name="foo",
+        display_name="bar",
+        parameters_to_remove=["E"],
+        input_artifacts_to_remove=["F"],
+        output_artifacts_to_remove=["G"],
+    )
+    client.update_trial_component.return_value = {}
+    obj.save()
+
+    client.update_trial_component.assert_called_with(
+        TrialComponentName="foo",
+        DisplayName="bar",
+        Parameters={},
+        ParametersToRemove=["E"],
+        InputArtifacts={},
+        InputArtifactsToRemove=["F"],
+        OutputArtifacts={},
+ OutputArtifactsToRemove=["G"], + ) + + +def test_delete(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _TrialComponent(sagemaker_session, trial_component_name="foo", display_name="bar") + client.delete_trial_component.return_value = {} + obj.delete() + client.delete_trial_component.assert_called_with(TrialComponentName="foo") + + +def test_delete_with_force_disassociate(sagemaker_session): + client = sagemaker_session.sagemaker_client + obj = _TrialComponent(sagemaker_session, trial_component_name="foo", display_name="bar") + client.delete_trial_component.return_value = {} + + client.list_trials.side_effect = [ + {"TrialSummaries": [{"TrialName": "trial-1"}, {"TrialName": "trial-2"}], "NextToken": "a"}, + {"TrialSummaries": [{"TrialName": "trial-3"}, {"TrialName": "trial-4"}]}, + ] + + obj.delete(force_disassociate=True) + expected_calls = [ + unittest.mock.call(TrialName="trial-1", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-2", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-3", TrialComponentName="foo"), + unittest.mock.call(TrialName="trial-4", TrialComponentName="foo"), + ] + assert expected_calls == client.disassociate_trial_component.mock_calls + client.delete_trial_component.assert_called_with(TrialComponentName="foo") + + +def test_list(sagemaker_session): + start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1) + end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2) + creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3) + last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4) + + client = sagemaker_session.sagemaker_client + client.list_trial_components.side_effect = [ + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "A" + str(i), + "TrialComponentArn": "B" + str(i), + "DisplayName": "C" + str(i), + "SourceArn": "D" + str(i), + "Status": { + "PrimaryStatus": _TrialComponentStatusType.InProgress.value, + "Message": "E" + str(i), + }, + "StartTime": start_time + datetime.timedelta(hours=i), + "EndTime": end_time + datetime.timedelta(hours=i), + "CreationTime": creation_time + datetime.timedelta(hours=i), + "LastModifiedTime": last_modified_time + datetime.timedelta(hours=i), + "LastModifiedBy": {}, + } + for i in range(10) + ], + "NextToken": "100", + }, + { + "TrialComponentSummaries": [ + { + "TrialComponentName": "A" + str(i), + "TrialComponentArn": "B" + str(i), + "DisplayName": "C" + str(i), + "SourceArn": "D" + str(i), + "Status": { + "PrimaryStatus": _TrialComponentStatusType.InProgress.value, + "Message": "E" + str(i), + }, + "StartTime": start_time + datetime.timedelta(hours=i), + "EndTime": end_time + datetime.timedelta(hours=i), + "CreationTime": creation_time + datetime.timedelta(hours=i), + "LastModifiedTime": last_modified_time + datetime.timedelta(hours=i), + "LastModifiedBy": {}, + } + for i in range(10, 20) + ] + }, + ] + + expected = [ + _api_types.TrialComponentSummary( + trial_component_name="A" + str(i), + trial_component_arn="B" + str(i), + display_name="C" + str(i), + source_arn="D" + str(i), + status=_api_types.TrialComponentStatus( + primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i) + ), + start_time=start_time + datetime.timedelta(hours=i), + end_time=end_time + datetime.timedelta(hours=i), + creation_time=creation_time + datetime.timedelta(hours=i), + last_modified_time=last_modified_time + 
datetime.timedelta(hours=i), + last_modified_by={}, + ) + for i in range(20) + ] + result = list( + _TrialComponent.list( + sagemaker_session=sagemaker_session, + source_arn="foo", + sort_by="CreationTime", + sort_order="Ascending", + ) + ) + + assert expected == result + expected_calls = [ + unittest.mock.call(SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"), + unittest.mock.call( + NextToken="100", SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo" + ), + ] + assert expected_calls == client.list_trial_components.mock_calls + + +def test_list_empty(sagemaker_session): + sagemaker_session.sagemaker_client.list_trial_components.return_value = { + "TrialComponentSummaries": [] + } + assert [] == list(_TrialComponent.list(sagemaker_session=sagemaker_session)) + + +def test_list_trial_components_call_args(sagemaker_session): + created_before = datetime.datetime(1999, 10, 12, 0, 0, 0) + created_after = datetime.datetime(1990, 10, 12, 0, 0, 0) + trial_name = "foo-trial" + experiment_name = "foo-experiment" + next_token = "thetoken" + max_results = 99 + + client = sagemaker_session.sagemaker_client + client.list_trial_components.return_value = {} + assert [] == list( + _TrialComponent.list( + sagemaker_session=sagemaker_session, + trial_name=trial_name, + experiment_name=experiment_name, + created_before=created_before, + created_after=created_after, + next_token=next_token, + max_results=max_results, + sort_by="CreationTime", + sort_order="Ascending", + ) + ) + + expected_calls = [ + unittest.mock.call( + TrialName="foo-trial", + ExperimentName="foo-experiment", + CreatedBefore=created_before, + CreatedAfter=created_after, + SortBy="CreationTime", + SortOrder="Ascending", + NextToken="thetoken", + MaxResults=99, + ) + ] + assert expected_calls == client.list_trial_components.mock_calls + + +@patch("sagemaker.experiments.trial_component._TrialComponent.load") +def test_load_or_create_when_exist(mock_load, sagemaker_session): + tc_name = "tc_name" + _, is_existed = _TrialComponent._load_or_create( + trial_component_name=tc_name, sagemaker_session=sagemaker_session + ) + assert is_existed + mock_load.assert_called_once_with( + tc_name, + sagemaker_session, + ) + + +@patch("sagemaker.experiments.trial_component._TrialComponent.load") +@patch("sagemaker.experiments.trial_component._TrialComponent.create") +def test_load_or_create_when_not_exist(mock_create, mock_load): + sagemaker_session = Session() + client = sagemaker_session.sagemaker_client + tc_name = "tc_name" + not_found_err = client.exceptions.ResourceNotFound( + error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}}, + operation_name="foo", + ) + mock_load.side_effect = not_found_err + + _, is_existed = _TrialComponent._load_or_create( + trial_component_name=tc_name, sagemaker_session=sagemaker_session + ) + + assert not is_existed + mock_create.assert_called_once_with( + trial_component_name=tc_name, + display_name=None, + tags=None, + sagemaker_session=sagemaker_session, + ) + + +def test_search(sagemaker_session): + client = sagemaker_session.sagemaker_client + client.search.return_value = { + "Results": [ + { + "TrialComponent": { + "TrialComponentName": "tc-1", + "TrialComponentArn": "arn::tc-1", + "DisplayName": "TC1", + "Parents": [ + { + "ExperimentName": "e-1", + "TrialName": "t-1", + }, + { + "ExperimentName": "e-2", + "TrialName": "t-2", + }, + ], + } + }, + { + "TrialComponent": { + "TrialComponentName": "tc-2", + "TrialComponentArn": "arn::tc-2", + "DisplayName": "TC2", + } + }, 
+        ]
+    }
+    expected = [
+        TrialComponentSearchResult(
+            trial_component_name="tc-1",
+            trial_component_arn="arn::tc-1",
+            display_name="TC1",
+            parents=[
+                Parent(experiment_name="e-1", trial_name="t-1"),
+                Parent(experiment_name="e-2", trial_name="t-2"),
+            ],
+        ),
+        TrialComponentSearchResult(
+            trial_component_name="tc-2", trial_component_arn="arn::tc-2", display_name="TC2"
+        ),
+    ]
+    assert expected == list(_TrialComponent.search(sagemaker_session=sagemaker_session))
diff --git a/tests/unit/sagemaker/experiments/test_utils.py b/tests/unit/sagemaker/experiments/test_utils.py
new file mode 100644
index 0000000000..a63c96c0fe
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_utils.py
@@ -0,0 +1,36 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from sagemaker.experiments._utils import resolve_artifact_name, guess_media_type
+
+
+def test_resolve_artifact_name():
+    file_names = {
+        "a": "a",
+        "a.txt": "a.txt",
+        "b.": "b.",
+        ".c": ".c",
+        "/x/a/a.txt": "a.txt",
+        "/a/b/c.": "c.",
+        "./.a": ".a",
+        "../b.txt": "b.txt",
+        "~/a.txt": "a.txt",
+        "c/d.txt": "d.txt",
+    }
+    for file_name, artifact_name in file_names.items():
+        assert artifact_name == resolve_artifact_name(file_name)
+
+
+def test_guess_media_type():
+    assert "text/plain" == guess_media_type("foo.txt")
diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py
index c391d45382..0088e34c58 100644
--- a/tests/unit/sagemaker/huggingface/test_estimator.py
+++ b/tests/unit/sagemaker/huggingface/test_estimator.py
@@ -48,6 +48,7 @@
     "ExperimentName": "exp",
     "TrialName": "trial",
     "TrialComponentDisplayName": "tc",
+    "RunName": "rn",
 }
diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py
index 2e7576421f..fea80b7ea9 100644
--- a/tests/unit/sagemaker/tensorflow/test_estimator.py
+++ b/tests/unit/sagemaker/tensorflow/test_estimator.py
@@ -56,6 +56,7 @@
     "ExperimentName": "exp",
     "TrialName": "trial",
     "TrialComponentDisplayName": "tc",
+    "RunName": "rn",
 }
diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py
index af46cf4360..d35c0a51dd 100644
--- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py
+++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py
@@ -52,6 +52,7 @@
     "ExperimentName": "exp",
     "TrialName": "trial",
     "TrialComponentDisplayName": "tc",
+    "RunName": "rn",
 }
diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py
index 5aef9316da..7645c4fe23 100644
--- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py
+++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py
@@ -50,6 +50,7 @@
     "ExperimentName": "exp",
     "TrialName":
"trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 7517f3a641..1ce58a19b4 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -50,6 +50,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/sagemaker/utilities/test_search_expression.py b/tests/unit/sagemaker/utilities/test_search_expression.py new file mode 100644 index 0000000000..98a52a992a --- /dev/null +++ b/tests/unit/sagemaker/utilities/test_search_expression.py @@ -0,0 +1,80 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import pytest + +from sagemaker.utilities.search_expression import ( + Filter, + Operator, + NestedFilter, + SearchExpression, + BooleanOperator, +) + + +def test_filters(): + search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1") + + assert { + "Name": "learning_rate", + "Operator": "Equals", + "Value": "0.1", + } == search_filter.to_boto() + + +def test_partial_filters(): + search_filter = Filter(name="learning_rate") + + assert {"Name": "learning_rate"} == search_filter.to_boto() + + +def test_nested_filters(): + search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1") + filters = [search_filter] + nested_filters = NestedFilter(property_name="hyper_param", filters=filters) + + assert { + "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}], + "NestedPropertyName": "hyper_param", + } == nested_filters.to_boto() + + +def test_search_expression(): + search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1") + nested_filter = NestedFilter(property_name="hyper_param", filters=[search_filter]) + search_expression = SearchExpression( + filters=[search_filter], + nested_filters=[nested_filter], + sub_expressions=[], + boolean_operator=BooleanOperator.AND, + ) + + assert { + "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}], + "NestedFilters": [ + { + "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}], + "NestedPropertyName": "hyper_param", + } + ], + "SubExpressions": [], + "Operator": "And", + } == search_expression.to_boto() + + +def test_illegal_search_expression(): + with pytest.raises( + ValueError, match="You must specify at least one subexpression, filter, or nested filter" + ): + SearchExpression() diff --git a/tests/unit/sagemaker/workflow/test_clarify_check_step.py b/tests/unit/sagemaker/workflow/test_clarify_check_step.py index feadaa03dc..54b354b71e 100644 --- a/tests/unit/sagemaker/workflow/test_clarify_check_step.py +++ b/tests/unit/sagemaker/workflow/test_clarify_check_step.py @@ -16,10 +16,6 @@ import re import pytest -import sagemaker - -from mock 
import Mock, PropertyMock - from sagemaker.clarify import ( DataConfig, BiasConfig, @@ -50,46 +46,6 @@ _S3_ANALYSIS_CONFIG_OUTPUT_PATH = "s3://my_bucket/analysis_cfg_output" -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=_ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=_REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. - - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=_DEFAULT_BUCKET, - ) - - _expected_data_bias_dsl = { "Name": "DataBiasCheckStep", "Type": "ClarifyCheck", diff --git a/tests/unit/sagemaker/workflow/test_entities.py b/tests/unit/sagemaker/workflow/test_entities.py index 6f0be2ccca..a36207b241 100644 --- a/tests/unit/sagemaker/workflow/test_entities.py +++ b/tests/unit/sagemaker/workflow/test_entities.py @@ -19,9 +19,6 @@ from enum import Enum -from mock.mock import Mock, PropertyMock - -import sagemaker from sagemaker.workflow.condition_step import ConditionStep from sagemaker.workflow.conditions import ConditionGreaterThan from sagemaker.workflow.entities import ( @@ -58,46 +55,6 @@ def custom_entity_list(): return [CustomEntity(1), CustomEntity(2)] -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value="role") - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name="us-west-2") - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. 
- - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket="my-bucket", - ) - - def test_entity(custom_entity): request_struct = {"foo": 1} assert custom_entity.to_request() == request_struct diff --git a/tests/unit/sagemaker/workflow/test_quality_check_step.py b/tests/unit/sagemaker/workflow/test_quality_check_step.py index b60e2de8fa..dc104d71df 100644 --- a/tests/unit/sagemaker/workflow/test_quality_check_step.py +++ b/tests/unit/sagemaker/workflow/test_quality_check_step.py @@ -15,10 +15,6 @@ import json import pytest -import sagemaker - -from mock import Mock, PropertyMock - from sagemaker.model_monitor import DatasetFormat from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.pipeline import Pipeline @@ -31,49 +27,7 @@ from sagemaker.workflow.steps import CacheConfig from sagemaker.workflow.check_job_config import CheckJobConfig -_REGION = "us-west-2" _ROLE = "DummyRole" -_BUCKET = "my-bucket" - - -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=_ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=_REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. - - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=_BUCKET, - ) _expected_data_quality_dsl = { diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 9887d43078..ba712d11d7 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -16,15 +16,10 @@ import json import pytest -import sagemaker import os import warnings -from mock import ( - Mock, - PropertyMock, - patch, -) +from mock import patch from sagemaker.debugger import ProfilerConfig from sagemaker.estimator import Estimator @@ -94,46 +89,6 @@ def create_predictor(self, endpoint_name): return Predictor(endpoint_name, self.sagemaker_session) -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. 
- - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=BUCKET, - ) - - @pytest.fixture def script_processor(sagemaker_session): return ScriptProcessor( diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 82b154317d..44b5818fc8 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -225,6 +225,9 @@ def test_fit_ndarray(time, sagemaker_session): assert mock_object.put.call_count == 4 + called_args = sagemaker_session.train.call_args + assert not called_args[1]["experiment_config"] + def test_fit_pass_experiment_config(sagemaker_session): kwargs = dict(COMMON_ARGS) @@ -239,12 +242,18 @@ def test_fit_pass_experiment_config(sagemaker_session): labels = [99, 85, 87, 2] pca.fit( pca.record_set(np.array(train), np.array(labels)), - experiment_config={"ExperimentName": "exp"}, + experiment_config={ + "ExperimentName": "exp", + "RunName": "rn", + }, ) called_args = sagemaker_session.train.call_args - assert called_args[1]["experiment_config"] == {"ExperimentName": "exp"} + assert called_args[1]["experiment_config"] == { + "ExperimentName": "exp", + "RunName": "rn", + } def test_build_shards(): diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 34e6a43fcf..868da88d78 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2489,7 +2489,12 @@ def test_start_new(sagemaker_session): hyperparameters=hyperparameters, ) - exp_config = {"ExperimentName": "exp", "TrialName": "t", "TrialComponentDisplayName": "tc"} + exp_config = { + "ExperimentName": "exp", + "TrialName": "t", + "TrialComponentDisplayName": "tc", + "RunName": "rn", + } started_training_job = training_job.start_new(estimator, inputs, experiment_config=exp_config) called_args = sagemaker_session.train.call_args @@ -2680,6 +2685,7 @@ def test_unsupported_type_in_dict(): "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } } ) @@ -2884,6 +2890,7 @@ def test_generic_to_fit_with_experiment_config(time, sagemaker_session): "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", }, ) diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 99b0e839b7..9ba3e17ff3 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -62,6 +62,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } MODEL_PKG_RESPONSE = {"ModelPackageArn": "arn:model-pkg-arn"} diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 082f699d63..c8aad13774 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -54,6 +54,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}} diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 4efc2e5bf8..2035636e76 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -49,6 +49,7 @@ "ExperimentName": "exp", "TrialName": "trial", 
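    # "RunName" is the new SageMaker Experiments field threaded through every
    # framework test's EXPERIMENT_CONFIG; fit() forwards it unchanged to
    # sagemaker_session.train(), as test_fit_pass_experiment_config asserts above.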
"TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d7c94470f5..ec4a21cbc9 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -588,11 +588,16 @@ def test_user_agent_injected(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" + not in sess.sagemaker_metrics_client._client_config.user_agent + ) def test_user_agent_injected_with_nbi(boto_session): @@ -607,10 +612,14 @@ def test_user_agent_injected_with_nbi(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_metrics_client._client_config.user_agent + ) def test_user_agent_injected_with_nbi_ioerror(boto_session): @@ -625,11 +634,16 @@ def test_user_agent_injected_with_nbi_ioerror(boto_session): assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent + assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent assert ( "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_runtime_client._client_config.user_agent ) + assert ( + "AWS-SageMaker-Notebook-Instance" + not in sess.sagemaker_metrics_client._client_config.user_agent + ) def test_training_input_all_defaults(): @@ -700,6 +714,7 @@ def test_training_input_all_arguments(): "ExperimentName": "dummyExp", "TrialName": "dummyT", "TrialComponentDisplayName": "dummyTC", + "RunName": "dummyRN", } MODEL_CLIENT_CONFIG = {"InvocationsMaxRetries": 2, "InvocationsTimeoutInSeconds": 60} diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 13cc755336..c3e984e0b7 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -51,6 +51,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 0eb81be584..8bcbed41c2 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -25,10 +25,12 @@ from boto3 import exceptions import botocore import pytest -from mock import call, patch, Mock, MagicMock +from mock import call, patch, Mock, MagicMock, PropertyMock import sagemaker +from sagemaker.experiments._run_context import _RunContext from sagemaker.session_settings import SessionSettings +from sagemaker.utils import retry_with_backoff, check_and_get_run_experiment_config from tests.unit.sagemaker.workflow.helpers import 
CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -795,3 +797,63 @@ def test_start_waiting(capfd): out, _ = capfd.readouterr() assert "." * sagemaker.utils.WAITING_DOT_NUMBER in out + + +def test_retry_with_backoff(): + callable_func = Mock() + + # Invalid input + with pytest.raises(ValueError) as value_err: + retry_with_backoff(callable_func, 0) + assert "The num_attempts must be >= 1" in str(value_err) + callable_func.assert_not_called() + + # All retries fail + run_err_msg = "Test Retry Error" + callable_func.side_effect = RuntimeError(run_err_msg) + with pytest.raises(RuntimeError) as run_err: + retry_with_backoff(callable_func, 2) + assert run_err_msg in str(run_err) + + # One retry passes + func_return_val = "Test Return" + callable_func.side_effect = [RuntimeError(run_err_msg), func_return_val] + assert retry_with_backoff(callable_func, 2) == func_return_val + + # No retry + callable_func.side_effect = None + callable_func.return_value = func_return_val + assert retry_with_backoff(callable_func, 2) == func_return_val + + +def test_check_and_get_run_experiment_config(): + supplied_exp_cfg = {"ExperimentName": "my-supplied-exp-name", "RunName": "my-supplied-run-name"} + run_exp_cfg = {"ExperimentName": "my-run-exp-name", "RunName": "my-run-run-name"} + + # No user supplied exp config and no current Run + assert not _RunContext.get_current_run() + exp_cfg1 = check_and_get_run_experiment_config(None) + assert exp_cfg1 is None + + # With user supplied exp config and no current Run + assert not _RunContext.get_current_run() + exp_cfg2 = check_and_get_run_experiment_config(supplied_exp_cfg) + assert exp_cfg2 == supplied_exp_cfg + + run = Mock() + type(run).experiment_config = PropertyMock(return_value=run_exp_cfg) + _RunContext.add_run_object(run) + + try: + # No user supplied exp config and with current Run + assert _RunContext.get_current_run().experiment_config == run_exp_cfg + exp_cfg3 = check_and_get_run_experiment_config(None) + assert exp_cfg3 == run_exp_cfg + + # With user supplied exp config and current Run + assert _RunContext.get_current_run().experiment_config == run_exp_cfg + exp_cfg4 = check_and_get_run_experiment_config(supplied_exp_cfg) + assert exp_cfg4 == supplied_exp_cfg + finally: + # Clean up the global static variable in case it affects other tests + _RunContext.drop_current_run() diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index 82f27c19ae..d58c4992cd 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -54,6 +54,7 @@ "ExperimentName": "exp", "TrialName": "trial", "TrialComponentDisplayName": "tc", + "RunName": "rn", } From 1cbfc8389f75323d779e560d12b15f163a23c7af Mon Sep 17 00:00:00 2001 From: Tejas Chumbalkar <34728580+tejaschumbalkar@users.noreply.github.com> Date: Wed, 14 Dec 2022 10:22:22 -0800 Subject: [PATCH 42/58] feature: Add support for TF2.9.2 training images (#3178) --- src/sagemaker/fw_utils.py | 1 + src/sagemaker/image_uri_config/tensorflow.json | 4 ++-- tests/unit/test_fw_utils.py | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 5efe530396..3ba918ea2c 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -103,6 +103,7 @@ "2.8.0", "2.9", "2.9.1", + "2.9.2", "2.10", "2.10.0", ], diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index a900aa4fe5..6bb36057fa 100644 --- 
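# Usage sketch for the two helpers exercised above, matching the asserted
# behavior: num_attempts must be >= 1, the last failure is re-raised, and a
# user-supplied experiment config always wins over the ambient Run context.
from sagemaker.utils import retry_with_backoff, check_and_get_run_experiment_config

calls = {"n": 0}


def flaky():
    # Fails once, then succeeds; retry_with_backoff absorbs the first failure.
    calls["n"] += 1
    if calls["n"] < 2:
        raise RuntimeError("transient error")
    return "ok"


assert retry_with_backoff(flaky, 2) == "ok"

# Outside any Run context, only an explicitly supplied config is returned.
assert check_and_get_run_experiment_config(None) is None
supplied = {"ExperimentName": "exp", "RunName": "rn"}
assert check_and_get_run_experiment_config(supplied) == supplied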
a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -1820,7 +1820,7 @@ "2.6": "2.6.3", "2.7": "2.7.1", "2.8": "2.8.0", - "2.9": "2.9.1", + "2.9": "2.9.2", "2.10": "2.10.0" }, "versions": { @@ -3273,7 +3273,7 @@ }, "repository": "tensorflow-training" }, - "2.9.1": { + "2.9.2": { "py_versions": [ "py39" ], diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 667d115d58..4654abb928 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -883,6 +883,7 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "tensorflow", "2.7", "py38", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.8", "py39", smdataparallel_enabled), + ("ml.p3.16xlarge", "tensorflow", "2.9.2", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.9.1", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.9", "py39", smdataparallel_enabled), ("ml.p3.16xlarge", "tensorflow", "2.10.0", "py39", smdataparallel_enabled), @@ -915,6 +916,7 @@ def test_validate_smdataparallel_args_not_raises(): ("ml.p3.16xlarge", "tensorflow", "2.7.1", "py38", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.9.1", "py39", smdataparallel_enabled_custom_mpi), + ("ml.p3.16xlarge", "tensorflow", "2.9.2", "py39", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "tensorflow", "2.10.0", "py39", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi), ("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled_custom_mpi), From 881caecd9a45d0facde0913baa74895938c5e788 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 15 Dec 2022 01:19:26 +0000 Subject: [PATCH 43/58] prepare release v2.123.0 --- CHANGELOG.md | 7 +++++++ VERSION | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index de20a8a0df..a05b64c96f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## v2.123.0 (2022-12-15) + +### Features + + * Add support for TF2.9.2 training images + * Add SageMaker Experiment + ## v2.122.0 (2022-12-14) ### Features diff --git a/VERSION b/VERSION index 6d7f044fa2..bef06dbf6d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.122.1.dev0 +2.123.0 From d543604609f3d0b1f0856d8346c8ecf271203432 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 15 Dec 2022 01:19:27 +0000 Subject: [PATCH 44/58] update development version to v2.123.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index bef06dbf6d..ea5085760e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.123.0 +2.123.1.dev0 From eef679cfc167827b14e47d0e1991e274a16e1ed4 Mon Sep 17 00:00:00 2001 From: Md Mizanur Rahman <105268921+mizanfiu@users.noreply.github.com> Date: Wed, 14 Dec 2022 19:55:56 -0800 Subject: [PATCH 45/58] feature: Added doc update for dataset builder (#3539) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add list_feature_groups API (#647) * feat: Feature/get record api (#650) Co-authored-by: Eric Zou * Add delete_record API (#664) * feat: Add DatasetBuilder class (#667) Co-authored-by: Eric Zou * feat: Add to_csv method in DatasetBuilder (#699) * feat: Add pandas.Dataframe as base case (#708) * feat: Add with_feature_group method in 
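# Illustrative lookup against the repinned tensorflow.json mapping above:
# requesting "2.9" for TensorFlow training now resolves to the 2.9.2 image.
# Region and instance type below are placeholders.
from sagemaker import image_uris

uri = image_uris.retrieve(
    framework="tensorflow",
    region="us-west-2",
    version="2.9",  # aliased to 2.9.2 by tensorflow.json after this change
    py_version="py39",
    instance_type="ml.p3.2xlarge",
    image_scope="training",
)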
DatasetBuilder (#726) * feat: Handle merge and timestamp filters (#727) * feat: Add to_dataframe method in DatasetBuilder (#729) * Address TODOs (#731) * Unit test for DatasetBuilder (#734) * fix: Fix list_feature_groups max_results (#744) * Add integration tests for create_dataset (#743) * feature: Aggregate commits * fix: as_of, event_range, join, default behavior and duplicates… (#764) * Bug fixed - as_of, event_range, join, default behavior and duplicates and tests Bugs: 1. as_of was not working properly on deleted events 2. Same event_time_range 3. Join was not working when including feature names 4. Default sql was returning only most recent, whereas it should all excluding duplicates 5. Include duplicates was not return all non-deleted data 6. instanceof(dataframe) case was also applied to non-df cases while join 7. Include column was returning unnecessary columns. * Fix on pylint error * Fix on include_duplicated_records for panda data frames * Fix format issue for black * Bug fixed related to line break * Bug fix related to dataframe and inclde_deleted_record and include_duplicated_record * Addressed comments and code refactored * changed to_csv to to_csv_file and added error messages for query limit and recent record limit * Revert a change which was not intended * Resolved the leak of feature group deletion in integration test * Added doc update for dataset builder * Fix the issue in doc Co-authored-by: Yiming Zou Co-authored-by: Brandon Chatham Co-authored-by: Eric Zou Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com> --- doc/api/prep_data/feature_store.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/doc/api/prep_data/feature_store.rst b/doc/api/prep_data/feature_store.rst index 1980a0b069..0e9bf25586 100644 --- a/doc/api/prep_data/feature_store.rst +++ b/doc/api/prep_data/feature_store.rst @@ -72,3 +72,11 @@ Inputs .. autoclass:: sagemaker.feature_store.inputs.FeatureValue :members: :show-inheritance: + + +Dataset Builder +*************** + +.. 
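# A hedged end-to-end sketch of the DatasetBuilder flow the new docs cover.
# Method names follow the commit message (create_dataset, with_feature_group,
# to_csv_file, to_dataframe); the session, feature group, base DataFrame and
# output path are illustrative assumptions, not taken from this patch.
import pandas as pd

from sagemaker.session import Session
from sagemaker.feature_store.feature_group import FeatureGroup
from sagemaker.feature_store.feature_store import FeatureStore

session = Session()  # assumes AWS credentials and region are configured
feature_store = FeatureStore(sagemaker_session=session)
orders_fg = FeatureGroup(name="orders-feature-group", sagemaker_session=session)  # assumed existing group

base_df = pd.DataFrame({"record_id": [1, 2], "event_time": [1.0, 2.0]})
builder = feature_store.create_dataset(
    base=base_df,
    record_identifier_feature_name="record_id",
    event_time_identifier_feature_name="event_time",
    output_path="s3://my-bucket/dataset-output",  # placeholder
).with_feature_group(orders_fg)

df, query = builder.to_dataframe()       # materialize the joined dataset as pandas
csv_file, query = builder.to_csv_file()  # or write a CSV to the output path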
autoclass:: sagemaker.feature_store.dataset_builder.DatasetBuilder + :members: + :show-inheritance: From 019d5a4b232cd4d287dff35c6a8ba9681ed4c0ca Mon Sep 17 00:00:00 2001 From: mariumof <99500633+mariumof@users.noreply.github.com> Date: Thu, 15 Dec 2022 12:38:55 -0800 Subject: [PATCH 46/58] feature: Add disable_profiler field in config and propagate changes (#3523) Co-authored-by: Marius Moisescu --- src/sagemaker/debugger/profiler_config.py | 4 + src/sagemaker/estimator.py | 23 +- .../integ/sagemaker/workflow/test_workflow.py | 4 - tests/integ/test_profiler.py | 40 --- .../sagemaker/huggingface/test_estimator.py | 8 +- .../sagemaker/tensorflow/test_estimator.py | 8 +- .../test_huggingface_pytorch_compiler.py | 8 +- .../test_huggingface_tensorflow_compiler.py | 8 +- .../test_pytorch_compiler.py | 12 +- .../test_tensorflow_compiler.py | 8 +- .../workflow/test_step_collections.py | 4 + tests/unit/sagemaker/workflow/test_steps.py | 3 +- .../sagemaker/workflow/test_training_step.py | 25 -- tests/unit/sagemaker/workflow/test_utils.py | 2 + tests/unit/test_chainer.py | 8 +- tests/unit/test_estimator.py | 249 ++++++++++++------ tests/unit/test_mxnet.py | 8 +- tests/unit/test_pytorch.py | 8 +- tests/unit/test_rl.py | 8 +- tests/unit/test_sklearn.py | 8 +- tests/unit/test_xgboost.py | 8 +- 21 files changed, 211 insertions(+), 243 deletions(-) diff --git a/src/sagemaker/debugger/profiler_config.py b/src/sagemaker/debugger/profiler_config.py index 3d4a24e8d1..561de38b9f 100644 --- a/src/sagemaker/debugger/profiler_config.py +++ b/src/sagemaker/debugger/profiler_config.py @@ -32,6 +32,7 @@ def __init__( s3_output_path: Optional[Union[str, PipelineVariable]] = None, system_monitor_interval_millis: Optional[Union[int, PipelineVariable]] = None, framework_profile_params: Optional[FrameworkProfile] = None, + disable_profiler: Optional[Union[str, PipelineVariable]] = False, ): """Initialize a ``ProfilerConfig`` instance. @@ -78,6 +79,7 @@ class and SageMaker Framework estimators. self.s3_output_path = s3_output_path self.system_monitor_interval_millis = system_monitor_interval_millis self.framework_profile_params = framework_profile_params + self.disable_profiler = disable_profiler def _to_request_dict(self): """Generate a request dictionary using the parameters provided when initializing the object. @@ -91,6 +93,8 @@ def _to_request_dict(self): if self.s3_output_path is not None: profiler_config_request["S3OutputPath"] = self.s3_output_path + profiler_config_request["DisableProfiler"] = self.disable_profiler + if self.system_monitor_interval_millis is not None: profiler_config_request[ "ProfilingIntervalInMilliseconds" diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index e3b06950aa..8ed9b724a5 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -938,26 +938,29 @@ def _prepare_collection_configs(self): def _prepare_profiler_for_training(self): """Set necessary values and do basic validations in profiler config and profiler rules. - When user explicitly set rules to an empty list, default profiler rule won't be enabled. - Default profiler rule will be enabled in supported regions when either: - 1. user doesn't specify any rules, i.e., rules=None; or - 2. user only specify debugger rules, i.e., rules=[Rule.sagemaker(...)] + No default profiler rule will be used. 
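# Minimal sketch of the new flag: _to_request_dict() now always emits
# DisableProfiler (default False) alongside the existing fields. The bucket
# below is a placeholder.
from sagemaker.debugger import ProfilerConfig

config = ProfilerConfig(
    s3_output_path="s3://my-bucket/profiler-output",
    system_monitor_interval_millis=500,
    disable_profiler=False,
)
assert config._to_request_dict() == {
    "S3OutputPath": "s3://my-bucket/profiler-output",
    "DisableProfiler": False,
    "ProfilingIntervalInMilliseconds": 500,
}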
The user needs to specify rules explicitly """ if self.disable_profiler: - if self.profiler_config: - raise RuntimeError("profiler_config cannot be set when disable_profiler is True.") + if self.profiler_config and not self.profiler_config.disable_profiler: + raise RuntimeError( + "profiler_config.disable_profiler cannot be False" + + " when disable_profiler is True." + ) if self.profiler_rules: raise RuntimeError("ProfilerRule cannot be set when disable_profiler is True.") elif _region_supports_profiler(self.sagemaker_session.boto_region_name): if self.profiler_config is None: self.profiler_config = ProfilerConfig(s3_output_path=self.output_path) if self.rules is None or (self.rules and not self.profiler_rules): - self.profiler_rules = [get_default_profiler_rule()] + self.profiler_rules = [] if self.profiler_config and not self.profiler_config.s3_output_path: self.profiler_config.s3_output_path = self.output_path self.profiler_rule_configs = self._prepare_profiler_rules() + # if profiler_config is still None, it means the job has profiler disabled + if self.profiler_config is None: + self.profiler_config = ProfilerConfig(disable_profiler=True) def _prepare_profiler_rules(self): """Set any necessary values in profiler rules, if they are provided.""" @@ -1048,7 +1051,7 @@ def latest_job_profiler_artifacts_path(self): error_message="""Cannot get the profiling output artifacts path. The Estimator is not associated with a training job.""" ) - if self.profiler_config is not None: + if self.profiler_config is not None and not self.profiler_config.disable_profiler: return os.path.join( self.profiler_config.s3_output_path, self.latest_training_job.name, @@ -1895,8 +1898,8 @@ def enable_default_profiling(self): else: self.profiler_config = ProfilerConfig(s3_output_path=self.output_path) - self.profiler_rules = [get_default_profiler_rule()] - self.profiler_rule_configs = self._prepare_profiler_rules() + self.profiler_rules = [] + self.profiler_rule_configs = [] _TrainingJob.update( self, self.profiler_rule_configs, self.profiler_config._to_request_dict() diff --git a/tests/integ/sagemaker/workflow/test_workflow.py b/tests/integ/sagemaker/workflow/test_workflow.py index 44f4e2d26e..bd24b653ae 100644 --- a/tests/integ/sagemaker/workflow/test_workflow.py +++ b/tests/integ/sagemaker/workflow/test_workflow.py @@ -1269,8 +1269,6 @@ def test_caching_behavior( # create pipeline pipeline.create(role) definition = json.loads(pipeline.definition()) - # delete profiler config for assertions as it will contain a timestamp - del definition["Steps"][1]["Arguments"]["ProfilerRuleConfigurations"] # verify input path expected_abalone_input_path = f"{pipeline_name}/{step_process.name}" f"/input/abalone_data" @@ -1295,7 +1293,6 @@ def test_caching_behavior( # verify no changes definition2 = json.loads(pipeline.definition()) - del definition2["Steps"][1]["Arguments"]["ProfilerRuleConfigurations"] assert definition == definition2 # add dummy file to source_dir @@ -1306,7 +1303,6 @@ def test_caching_behavior( # verify changes definition3 = json.loads(pipeline.definition()) - del definition3["Steps"][1]["Arguments"]["ProfilerRuleConfigurations"] assert definition != definition3 finally: diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index bddd53e20c..7d3fdb2d7b 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -13,7 +13,6 @@ from __future__ import absolute_import import os -import re import time import uuid @@ -22,7 +21,6 @@ from sagemaker.debugger import ( 
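    # Net effect of the branch above: disable_profiler=True now coexists with a
    # profiler_config as long as that config's own disable_profiler flag agrees,
    # and a fully disabled job still ships ProfilerConfig(disable_profiler=True)
    # in its request rather than omitting the config entirely.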
DebuggerHookConfig, FrameworkProfile, - get_rule_container_image_uri, ProfilerConfig, ProfilerRule, Rule, @@ -93,8 +91,6 @@ def test_mxnet_with_default_profiler_config_and_profiler_rule( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert ( job_description["ProfilerConfig"] == ProfilerConfig( @@ -103,13 +99,6 @@ def test_mxnet_with_default_profiler_config_and_profiler_rule( ) assert job_description.get("ProfilingStatus") == "Enabled" - profiler_rule_configuration = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_rule_configuration["RuleConfigurationName"]) - assert profiler_rule_configuration["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - with pytest.raises(ValueError) as error: mx.enable_default_profiling() assert "Debugger monitoring is already enabled." in str(error) @@ -155,18 +144,9 @@ def test_mxnet_with_custom_profiler_config_then_update_rule_and_config( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert job_description.get("ProfilerConfig") == profiler_config._to_request_dict() assert job_description.get("ProfilingStatus") == "Enabled" - profiler_rule_configuration = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_rule_configuration["RuleConfigurationName"]) - assert profiler_rule_configuration["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) mx.update_profiler( @@ -178,13 +158,6 @@ def test_mxnet_with_custom_profiler_config_then_update_rule_and_config( assert job_description["ProfilerConfig"]["S3OutputPath"] == profiler_config.s3_output_path assert job_description["ProfilerConfig"]["ProfilingIntervalInMilliseconds"] == 500 - profiler_report_rule_config = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_report_rule_config["RuleConfigurationName"]) - assert profiler_report_rule_config["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_report_rule_config["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - def test_mxnet_with_built_in_profiler_rule_with_custom_parameters( sagemaker_session, @@ -225,8 +198,6 @@ def test_mxnet_with_built_in_profiler_rule_with_custom_parameters( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert job_description.get("ProfilingStatus") == "Enabled" assert ( job_description.get("ProfilerConfig") @@ -298,8 +269,6 @@ def test_mxnet_with_profiler_and_debugger_then_disable_framework_metrics( ) job_description = mx.latest_training_job.describe() - if "DisableProfiler" in job_description["ProfilerConfig"]: - job_description["ProfilerConfig"].pop("DisableProfiler") assert job_description["ProfilerConfig"] == profiler_config._to_request_dict() assert 
job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict() assert job_description.get("ProfilingStatus") == "Enabled" @@ -387,13 +356,6 @@ def test_mxnet_with_enable_framework_metrics_then_update_framework_metrics( == updated_framework_profile.profiling_parameters ) - profiler_rule_configuration = job_description.get("ProfilerRuleConfigurations")[0] - assert re.match(r"ProfilerReport-\d*", profiler_rule_configuration["RuleConfigurationName"]) - assert profiler_rule_configuration["RuleEvaluatorImage"] == get_rule_container_image_uri( - mx.sagemaker_session.boto_region_name - ) - assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"} - def test_mxnet_with_disable_profiler_then_enable_default_profiling( sagemaker_session, @@ -431,12 +393,10 @@ def test_mxnet_with_disable_profiler_then_enable_default_profiling( ) job_description = mx.latest_training_job.describe() - assert job_description.get("ProfilerConfig") is None assert job_description.get("ProfilerRuleConfigurations") is None assert job_description.get("ProfilingStatus") == "Disabled" _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) - mx.enable_default_profiling() job_description = mx.latest_training_job.describe() diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 0088e34c58..072eefeb83 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -143,14 +143,8 @@ def _create_train_job(version, base_framework_version): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index fea80b7ea9..771b18b35a 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -136,14 +136,8 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd "metric_definitions": None, "environment": None, "experiment_config": None, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index d35c0a51dd..656730a47c 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -145,14 +145,8 @@ def _create_train_job( "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": 
"503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index 7645c4fe23..c3684ac649 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -143,14 +143,8 @@ def _create_train_job( "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 0fe2402695..068bb4e4b9 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -137,14 +137,10 @@ def _create_train_job(version, instance_type, training_compiler_config, instance "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], - "profiler_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + "profiler_config": { + "DisableProfiler": False, + "S3OutputPath": "s3://{}/".format(BUCKET_NAME), + }, } diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 1ce58a19b4..a5c14b1626 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -145,14 +145,8 @@ def _create_train_job(framework_version, instance_type, training_compiler_config "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 2bf47a79d0..95738c99ca 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -796,6 +796,7 @@ def test_register_model_with_model_repack_with_estimator( "CollectionConfigurations": [], "S3OutputPath": f"s3://{BUCKET}/", }, + "ProfilerConfig": {"DisableProfiler": True}, "HyperParameters": { "inference_script": '"dummy_script.py"', "dependencies": f'"{dummy_requirements}"', @@ -923,6 +924,7 @@ def 
test_register_model_with_model_repack_with_model(model, model_metrics, drift "CollectionConfigurations": [], "S3OutputPath": f"s3://{BUCKET}/", }, + "ProfilerConfig": {"DisableProfiler": True}, "HyperParameters": { "inference_script": '"dummy_script.py"', "model_archive": '"s3://my-bucket/model.tar.gz"', @@ -1052,6 +1054,7 @@ def test_register_model_with_model_repack_with_pipeline_model( "CollectionConfigurations": [], "S3OutputPath": f"s3://{BUCKET}/", }, + "ProfilerConfig": {"DisableProfiler": True}, "HyperParameters": { "dependencies": "null", "inference_script": '"dummy_script.py"', @@ -1243,6 +1246,7 @@ def test_estimator_transformer_with_model_repack_with_estimator(estimator): "TrainingImage": "246618743249.dkr.ecr.us-west-2.amazonaws.com/" + "sagemaker-scikit-learn:0.23-1-cpu-py3", }, + "ProfilerConfig": {"DisableProfiler": True}, "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"}, "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, "ResourceConfig": { diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index ba712d11d7..f2046cc00f 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -329,6 +329,7 @@ def test_training_step_base_estimator(sagemaker_session): "CollectionConfigurations": [], }, "ProfilerConfig": { + "DisableProfiler": False, "ProfilingIntervalInMilliseconds": 500, "S3OutputPath": {"Std:Join": {"On": "/", "Values": ["s3:/", "a", "b"]}}, }, @@ -438,7 +439,7 @@ def test_training_step_tensorflow(sagemaker_session): "sagemaker_instance_type": {"Get": "Parameters.InstanceType"}, "sagemaker_distributed_dataparallel_custom_mpi_options": '""', }, - "ProfilerConfig": {"S3OutputPath": "s3://my-bucket/"}, + "ProfilerConfig": {"DisableProfiler": False, "S3OutputPath": "s3://my-bucket/"}, }, "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"}, } diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index 3e8b57b069..7f8e6b0c62 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -401,10 +401,6 @@ def test_training_step_with_estimator( } step_definition = json.loads(pipeline.definition())["Steps"][0] - # delete profiler rule configurations because of timestamp collision - del step_definition["Arguments"]["ProfilerRuleConfigurations"] - del step_args["ProfilerRuleConfigurations"] - assert step_definition == { "Name": "MyTrainingStep", "Description": "TrainingStep description", @@ -428,7 +424,6 @@ def test_training_step_with_estimator( # test idempotency step_def2 = json.loads(pipeline.definition())["Steps"][0] - del step_def2["Arguments"]["ProfilerRuleConfigurations"] assert step_definition == step_def2 @@ -537,10 +532,6 @@ def test_training_step_with_framework_estimator( del expected_step_args["OutputDataConfig"]["S3OutputPath"] del step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"] - # delete profiler rule configurations because of timestamp collision - del step_def["Arguments"]["ProfilerRuleConfigurations"] - del expected_step_args["ProfilerRuleConfigurations"] - if "sagemaker_s3_output" in step_args["HyperParameters"]: del expected_step_args["HyperParameters"]["sagemaker_s3_output"] del step_def["Arguments"]["HyperParameters"]["sagemaker_s3_output"] @@ -555,7 +546,6 @@ def test_training_step_with_framework_estimator( step_def2 = json.loads(pipeline.definition())["Steps"][0] del 
step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] del step_def2["Arguments"]["OutputDataConfig"]["S3OutputPath"] - del step_def2["Arguments"]["ProfilerRuleConfigurations"] if "sagemaker_s3_output" in step_def2["Arguments"]["HyperParameters"]: del step_def2["Arguments"]["HyperParameters"]["sagemaker_s3_output"] assert step_def == step_def2 @@ -608,10 +598,6 @@ def test_training_step_with_framework_estimator_local_code( del expected_step_args["OutputDataConfig"]["S3OutputPath"] del step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"] - # delete profiler rule configurations because of timestamp collision - del step_def["Arguments"]["ProfilerRuleConfigurations"] - del expected_step_args["ProfilerRuleConfigurations"] - if "sagemaker_s3_output" in step_args["HyperParameters"]: del expected_step_args["HyperParameters"]["sagemaker_s3_output"] del step_def["Arguments"]["HyperParameters"]["sagemaker_s3_output"] @@ -626,7 +612,6 @@ def test_training_step_with_framework_estimator_local_code( step_def2 = json.loads(pipeline.definition())["Steps"][0] del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] del step_def2["Arguments"]["OutputDataConfig"]["S3OutputPath"] - del step_def2["Arguments"]["ProfilerRuleConfigurations"] if "sagemaker_s3_output" in step_def2["Arguments"]["HyperParameters"]: del step_def2["Arguments"]["HyperParameters"]["sagemaker_s3_output"] assert step_def == step_def2 @@ -701,10 +686,6 @@ def test_training_step_with_algorithm_base(algo_estimator, training_input, pipel del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] - # delete profiler rule configurations because of timestamp collision - del step_def["Arguments"]["ProfilerRuleConfigurations"] - del step_args["ProfilerRuleConfigurations"] - assert step_def == { "Name": "MyTrainingStep", "Type": "Training", @@ -714,7 +695,6 @@ def test_training_step_with_algorithm_base(algo_estimator, training_input, pipel # test idempotency step_def2 = json.loads(pipeline.definition())["Steps"][0] del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] - del step_def2["Arguments"]["ProfilerRuleConfigurations"] assert step_def == step_def2 @@ -789,10 +769,6 @@ def test_training_step_with_algorithm_base_local_code( del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] - # delete profiler rule configurations because of timestamp collision - del step_def["Arguments"]["ProfilerRuleConfigurations"] - del step_args["ProfilerRuleConfigurations"] - assert step_def == { "Name": "MyTrainingStep", "Type": "Training", @@ -802,7 +778,6 @@ def test_training_step_with_algorithm_base_local_code( # test idempotency step_def2 = json.loads(pipeline.definition())["Steps"][0] del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] - del step_def2["Arguments"]["ProfilerRuleConfigurations"] assert step_def == step_def2 diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index c8d86c5866..d1b81f3148 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -107,6 +107,7 @@ def test_repack_model_step(estimator): } ], "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"}, + "ProfilerConfig": {"DisableProfiler": True}, 
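        # Repack-model jobs are short-lived packaging steps, so their expected
        # request now pins {"DisableProfiler": True} instead of carrying a
        # timestamped ProfilerReport rule configuration.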
"ResourceConfig": { "InstanceCount": 1, "InstanceType": "ml.m5.large", @@ -188,6 +189,7 @@ def test_repack_model_step_with_source_dir(estimator, source_dir): } ], "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"}, + "ProfilerConfig": {"DisableProfiler": True}, "ResourceConfig": { "InstanceCount": 1, "InstanceType": "ml.m5.large", diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 7cc973440f..eca4a9bf80 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -150,14 +150,8 @@ def _create_train_job(version, py_version): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 868da88d78..8b771f9184 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -25,7 +25,10 @@ from botocore.exceptions import ClientError from mock import ANY, MagicMock, Mock, patch, PropertyMock from sagemaker.huggingface.estimator import HuggingFace -from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME +from sagemaker.jumpstart.constants import ( + JUMPSTART_BUCKET_NAME_SET, + JUMPSTART_RESOURCE_BASE_NAME, +) from sagemaker.jumpstart.enums import JumpStartTag import sagemaker.local @@ -106,7 +109,11 @@ "training_steps": "100", }, "RoleArn": "arn:aws:iam::366:role/SageMakerRole", - "ResourceConfig": {"VolumeSizeInGB": 30, "InstanceCount": 1, "InstanceType": "ml.c4.xlarge"}, + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.c4.xlarge", + }, "EnableNetworkIsolation": False, "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, "TrainingJobName": "neo", @@ -143,7 +150,10 @@ } MOCKED_S3_URI = "s3://mocked_s3_uri_from_source_dir" MOCKED_PIPELINE_CONFIG = _PipelineConfig( - "test-pipeline", "test-training-step", "code-hash-0123456789", "config-hash-0123456789" + "test-pipeline", + "test-training-step", + "code-hash-0123456789", + "config-hash-0123456789", ) @@ -247,7 +257,9 @@ def pipeline_session(): session_mock.resource.return_value = resource_mock session_mock.client.return_value = client_mock return PipelineSession( - boto_session=session_mock, sagemaker_client=client_mock, default_bucket=BUCKET_NAME + boto_session=session_mock, + sagemaker_client=client_mock, + default_bucket=BUCKET_NAME, ) @@ -322,7 +334,11 @@ def test_framework_all_init_args(sagemaker_session): }, "metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}], "encrypt_inter_container_traffic": True, - "environment": {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}, + "environment": { + "env_key1": "env_val1", + "env_key2": "env_val2", + "env_key3": "env_val3", + }, "experiment_config": None, "checkpoint_s3_uri": "s3://bucket/checkpoint", "checkpoint_local_path": "file://local/checkpoint", @@ -379,7 +395,8 @@ def test_framework_with_debugger_and_built_in_rule(sagemaker_session): rule_parameters={"threshold": "120", "stop_training_on_fire": "True"}, collections_to_save=[ CollectionConfig( - name="losses", parameters={"train.save_interval": "50", "eval.save_interval": "10"} + 
name="losses", + parameters={"train.save_interval": "50", "eval.save_interval": "10"}, ) ], ) @@ -405,18 +422,23 @@ def test_framework_with_debugger_and_built_in_rule(sagemaker_session): "CollectionConfigurations": [ { "CollectionName": "losses", - "CollectionParameters": {"train.save_interval": "50", "eval.save_interval": "10"}, + "CollectionParameters": { + "train.save_interval": "50", + "eval.save_interval": "10", + }, } ], } assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } def test_framework_with_debugger_and_custom_rule(sagemaker_session): hook_config = DebuggerHookConfig( - s3_output_path="s3://output", collection_configs=[CollectionConfig(name="weights")] + s3_output_path="s3://output", + collection_configs=[CollectionConfig(name="weights")], ) debugger_custom_rule = Rule.custom( name="CustomRule", @@ -536,7 +558,8 @@ def test_framework_with_debugger_rule_and_multiple_actions(sagemaker_session): def test_framework_with_only_debugger_hook_config(sagemaker_session): hook_config = DebuggerHookConfig( - s3_output_path="s3://output", collection_configs=[CollectionConfig(name="weights")] + s3_output_path="s3://output", + collection_configs=[CollectionConfig(name="weights")], ) f = DummyFramework( entry_point=SCRIPT_PATH, @@ -574,15 +597,9 @@ def test_framework_without_debugger_and_profiler(time, sagemaker_session): } assert "debugger_rule_configs" not in args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] def test_framework_with_debugger_and_profiler_rules(sagemaker_session): @@ -591,7 +608,8 @@ def test_framework_with_debugger_and_profiler_rules(sagemaker_session): rule_parameters={"threshold": "120", "stop_training_on_fire": "True"}, collections_to_save=[ CollectionConfig( - name="losses", parameters={"train.save_interval": "50", "eval.save_interval": "10"} + name="losses", + parameters={"train.save_interval": "50", "eval.save_interval": "10"}, ) ], ) @@ -639,18 +657,25 @@ def test_framework_with_debugger_and_profiler_rules(sagemaker_session): "CollectionConfigurations": [ { "CollectionName": "losses", - "CollectionParameters": {"train.save_interval": "50", "eval.save_interval": "10"}, + "CollectionParameters": { + "train.save_interval": "50", + "eval.save_interval": "10", + }, } ], } assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } assert args["profiler_rule_configs"] == [ { "RuleConfigurationName": "CustomProfilerReportRule", "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport", "CPUBottleneck_threshold": "90"}, + "RuleParameters": { + "rule_to_invoke": "ProfilerReport", + "CPUBottleneck_threshold": "90", + }, }, { "InstanceType": "c4.4xlarge", @@ -679,6 +704,7 @@ def test_framework_with_only_profiler_rule_specified(sagemaker_session): sagemaker_session.train.assert_called_once() _, args = sagemaker_session.train.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } assert args["profiler_rule_configs"] == [ @@ -711,16 +737,10 @@ def 
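# Running pattern in the assertions above: enabled jobs now always expect
# "DisableProfiler": False inside profiler_config, and the implicit
# ProfilerReport rule (with its timestamped RuleConfigurationName) is gone
# unless the caller passes rules explicitly.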
test_framework_with_profiler_config_without_s3_output_path(time, sagemaker_s sagemaker_session.train.assert_called_once() _, args = sagemaker_session.train.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), "ProfilingIntervalInMilliseconds": 1000, } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] @pytest.mark.parametrize("region", PROFILER_UNSUPPORTED_REGIONS) @@ -745,7 +765,9 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region): f.fit("s3://mydata") sms.train.assert_called_once() _, args = sms.train.call_args - assert args.get("profiler_config") is None + # assert args.get("profiler_config") == {"DisableProfiler": True} + # temporarily check if "DisableProfiler" flag is true until s3_output is changed to optional in service + assert args.get("profiler_config")["DisableProfiler"] is True assert args.get("profiler_rule_configs") is None @@ -865,7 +887,10 @@ def test_framework_with_profiler_config_and_profiler_disabled(sagemaker_session) disable_profiler=True, ) f.fit("s3://mydata") - assert "profiler_config cannot be set when disable_profiler is True." in str(error) + # assert "profiler_config cannot be set when disable_profiler is True." in str(error) + assert "profiler_config.disable_profiler cannot be False when disable_profiler is True." in str( + error + ) def test_framework_with_profiler_rule_and_profiler_disabled(sagemaker_session): @@ -927,15 +952,9 @@ def test_framework_with_enabling_default_profiling( sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] @patch("time.time", return_value=TIME) @@ -960,15 +979,9 @@ def test_framework_with_enabling_default_profiling_with_existed_s3_output_path( sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "S3OutputPath": "s3://custom/", } - assert args["profiler_rule_configs"] == [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ] def test_framework_with_disabling_profiling_when_profiler_is_already_disabled( @@ -1001,7 +1014,9 @@ def test_framework_with_disabling_profiling(sagemaker_session, training_job_desc f.disable_profiling() sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args - assert args["profiler_config"] == {"DisableProfiler": True} + # assert args["profiler_config"] == {"DisableProfiler": True} + # temporarily check if "DisableProfiler" flag is true until s3_output is changed to optional in service + assert args.get("profiler_config")["DisableProfiler"] is True def 
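# Lifecycle sketch matching the enable/disable tests above and the
# update_profiler tests that follow; "estimator" reuses the placeholder
# construction pattern from the earlier sketches.
estimator.disable_profiling()
# -> update_training_job receives a profiler_config with DisableProfiler: True
estimator.enable_default_profiling()
# -> the output path is restored; no rule configs are sent anymore
estimator.update_profiler(system_monitor_interval_millis=1000)
# -> profiler_config == {"DisableProfiler": False, "ProfilingIntervalInMilliseconds": 1000}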
test_framework_with_update_profiler_when_no_training_job(sagemaker_session): @@ -1058,6 +1073,7 @@ def test_framework_with_update_profiler_config(sagemaker_session): sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "ProfilingIntervalInMilliseconds": 1000, } assert "profiler_rule_configs" not in args @@ -1086,7 +1102,7 @@ def test_framework_with_update_profiler_report_rule(sagemaker_session): "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, } ] - assert "profiler_config" not in args + assert args["profiler_config"]["DisableProfiler"] is False def test_framework_with_disable_framework_metrics(sagemaker_session): @@ -1101,11 +1117,16 @@ def test_framework_with_disable_framework_metrics(sagemaker_session): f.update_profiler(disable_framework_metrics=True) sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args - assert args["profiler_config"] == {"ProfilingParameters": {}} + assert args["profiler_config"] == { + "DisableProfiler": False, + "ProfilingParameters": {}, + } assert "profiler_rule_configs" not in args -def test_framework_with_disable_framework_metrics_and_update_system_metrics(sagemaker_session): +def test_framework_with_disable_framework_metrics_and_update_system_metrics( + sagemaker_session, +): f = DummyFramework( entry_point=SCRIPT_PATH, role=ROLE, @@ -1118,13 +1139,16 @@ def test_framework_with_disable_framework_metrics_and_update_system_metrics(sage sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args assert args["profiler_config"] == { + "DisableProfiler": False, "ProfilingIntervalInMilliseconds": 1000, "ProfilingParameters": {}, } assert "profiler_rule_configs" not in args -def test_framework_with_disable_framework_metrics_and_update_framework_params(sagemaker_session): +def test_framework_with_disable_framework_metrics_and_update_framework_params( + sagemaker_session, +): with pytest.raises(ValueError) as error: f = DummyFramework( entry_point=SCRIPT_PATH, @@ -1160,7 +1184,10 @@ def test_framework_with_update_profiler_config_and_profiler_rule(sagemaker_sessi f.update_profiler(rules=[profiler_custom_rule], system_monitor_interval_millis=1000) sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args - assert args["profiler_config"] == {"ProfilingIntervalInMilliseconds": 1000} + assert args["profiler_config"] == { + "DisableProfiler": False, + "ProfilingIntervalInMilliseconds": 1000, + } assert args["profiler_rule_configs"] == [ { "InstanceType": "c4.4xlarge", @@ -1659,7 +1686,10 @@ def test_start_new_wait_called(strftime, sagemaker_session): def test_attach_framework(sagemaker_session, training_job_description): - training_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} + training_job_description["VpcConfig"] = { + "Subnets": ["foo"], + "SecurityGroupIds": ["bar"], + } training_job_description["EnableNetworkIsolation"] = True framework_estimator = DummyFramework.attach( @@ -1753,7 +1783,8 @@ def test_attach_framework_with_inter_container_traffic_encryption_flag( def test_attach_framework_base_from_generated_name(sagemaker_session, training_job_description): base_job_name = "neo" framework_estimator = DummyFramework.attach( - training_job_name=utils.name_from_base("neo"), sagemaker_session=sagemaker_session + 
training_job_name=utils.name_from_base("neo"), + sagemaker_session=sagemaker_session, ) assert framework_estimator.base_job_name == base_job_name @@ -1948,7 +1979,8 @@ def test_git_support_bad_repo_url_format(sagemaker_session): @patch( "sagemaker.git_utils.git_clone_repo", side_effect=subprocess.CalledProcessError( - returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir" + returncode=1, + cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir", ), ) def test_git_support_git_clone_fail(git_clone_repo, sagemaker_session): @@ -1973,7 +2005,11 @@ def test_git_support_git_clone_fail(git_clone_repo, sagemaker_session): ), ) def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session): - git_config = {"repo": GIT_REPO, "branch": "branch-that-does-not-exist", "commit": COMMIT} + git_config = { + "repo": GIT_REPO, + "branch": "branch-that-does-not-exist", + "commit": COMMIT, + } fw = DummyFramework( entry_point="entry_point", git_config=git_config, @@ -1994,7 +2030,11 @@ def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session): ), ) def test_git_support_commit_not_exist(git_clone_repo, sagemaker_session): - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": "commit-sha-that-does-not-exist"} + git_config = { + "repo": GIT_REPO, + "branch": BRANCH, + "commit": "commit-sha-that-does-not-exist", + } fw = DummyFramework( entry_point="entry_point", git_config=git_config, @@ -2137,7 +2177,11 @@ def test_git_support_with_token_2fa(git_clone_repo, sagemaker_session): }, ) def test_git_support_ssh_no_passphrase_needed(git_clone_repo, sagemaker_session): - git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + git_config = { + "repo": PRIVATE_GIT_REPO_SSH, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + } entry_point = "entry_point" fw = DummyFramework( entry_point=entry_point, @@ -2159,7 +2203,11 @@ def test_git_support_ssh_no_passphrase_needed(git_clone_repo, sagemaker_session) ), ) def test_git_support_ssh_passphrase_required(git_clone_repo, sagemaker_session): - git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + git_config = { + "repo": PRIVATE_GIT_REPO_SSH, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + } entry_point = "entry_point" fw = DummyFramework( entry_point=entry_point, @@ -2457,7 +2505,9 @@ def test_estimator_transformer_creation_with_optional_params(create_model, sagem ) create_model.assert_called_with( - vpc_config_override=new_vpc_config, model_kms_key=kms_key, enable_network_isolation=True + vpc_config_override=new_vpc_config, + model_kms_key=kms_key, + enable_network_isolation=True, ) assert transformer.strategy == strategy @@ -2635,14 +2685,7 @@ def test_unsupported_type_in_dict(): "input_config": None, "input_mode": "File", "output_config": {"S3OutputPath": OUTPUT_PATH}, - "profiler_config": {"S3OutputPath": OUTPUT_PATH}, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], + "profiler_config": {"DisableProfiler": False, "S3OutputPath": OUTPUT_PATH}, "resource_config": { "InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE, @@ -2749,7 +2792,11 @@ def test_fit_deploy_tags_in_estimator(name_from_base, sagemaker_session): @patch("sagemaker.estimator.name_from_base") def 
test_fit_deploy_tags(name_from_base, sagemaker_session): estimator = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) estimator.fit() @@ -3197,7 +3244,10 @@ def test_generic_training_job_analytics(sagemaker_session): "TrainingInputMode": "File", "MetricDefinitions": [ {"Name": "train:loss", "Regex": "train_loss=([0-9]+\\.[0-9]+)"}, - {"Name": "validation:loss", "Regex": "valid_loss=([0-9]+\\.[0-9]+)"}, + { + "Name": "validation:loss", + "Regex": "valid_loss=([0-9]+\\.[0-9]+)", + }, ], }, }, @@ -3228,7 +3278,11 @@ def test_generic_create_model_vpc_config_override(sagemaker_session): vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]} e = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) e.fit({"train": "s3://bucket/training-prefix"}) assert e.get_vpc_config() is None @@ -3254,7 +3308,11 @@ def test_generic_deploy_vpc_config_override(sagemaker_session): vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]} e = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) e.fit({"train": "s3://bucket/training-prefix"}) e.deploy(INSTANCE_COUNT, INSTANCE_TYPE) @@ -3274,7 +3332,11 @@ def test_generic_deploy_vpc_config_override(sagemaker_session): def test_generic_deploy_accelerator_type(sagemaker_session): e = Estimator( - IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + sagemaker_session=sagemaker_session, ) e.fit({"train": "s3://bucket/training-prefix"}) e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) @@ -3617,7 +3679,13 @@ def test_file_output_path_not_supported_outside_local_mode(session_class): session_class.return_value = session with pytest.raises(RuntimeError): - Estimator(IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path="file:///tmp/model") + Estimator( + IMAGE_URI, + ROLE, + INSTANCE_COUNT, + INSTANCE_TYPE, + output_path="file:///tmp/model", + ) def test_prepare_init_params_from_job_description_with_image_training_job(): @@ -3726,7 +3794,10 @@ def test_prepare_for_training_with_name_based_on_image(sagemaker_session): @patch("sagemaker.algorithm.AlgorithmEstimator.validate_train_spec", Mock()) -@patch("sagemaker.algorithm.AlgorithmEstimator._parse_hyperparameters", Mock(return_value={})) +@patch( + "sagemaker.algorithm.AlgorithmEstimator._parse_hyperparameters", + Mock(return_value={}), +) def test_prepare_for_training_with_name_based_on_algorithm(sagemaker_session): estimator = AlgorithmEstimator( algorithm_arn="arn:aws:sagemaker:us-west-2:1234:algorithm/scikit-decision-trees-1542410022", @@ -3741,7 +3812,9 @@ def test_prepare_for_training_with_name_based_on_algorithm(sagemaker_session): @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -def test_prepare_for_training_with_pipeline_name_in_s3_path_no_source_dir(pipeline_session): +def test_prepare_for_training_with_pipeline_name_in_s3_path_no_source_dir( + pipeline_session, +): # script_uri is NOT provided -> use new cache key behavior that builds path using pipeline name + code_hash image_uri = 
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.9.0-gpu-py38" model_uri = "s3://someprefix2/models/model.tar.gz" @@ -4211,7 +4284,10 @@ def test_script_mode_estimator_tags_jumpstart_models_with_no_estimator_js_tags( @patch("sagemaker.model.Model._upload_code") @patch("sagemaker.utils.repack_model") def test_all_framework_estimators_add_jumpstart_tags( - patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_repack_model, + patched_upload_code, + patched_tar_and_upload_dir, + sagemaker_session, ): sagemaker_session.boto_region_name = REGION @@ -4240,13 +4316,20 @@ def test_all_framework_estimators_add_jumpstart_tags( "transformers_version": "4.6.1", "instance_type": "ml.p2.xlarge", }, - MXNet: {"framework_version": "1.7.0", "py_version": "py3", "instance_type": "ml.p2.xlarge"}, + MXNet: { + "framework_version": "1.7.0", + "py_version": "py3", + "instance_type": "ml.p2.xlarge", + }, SKLearn: {"framework_version": "0.23-1", "instance_type": "ml.m2.xlarge"}, XGBoost: {"framework_version": "1.3-1", "instance_type": "ml.m2.xlarge"}, } jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz" - for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items(): + for ( + framework_estimator_class, + kwargs, + ) in framework_estimator_classes_to_kwargs.items(): estimator = framework_estimator_class( entry_point=ENTRY_POINT, role=ROLE, @@ -4362,7 +4445,10 @@ def test_script_mode_estimator_uses_jumpstart_base_name_with_js_models( @patch("sagemaker.model.Model._upload_code") @patch("sagemaker.utils.repack_model") def test_all_framework_estimators_add_jumpstart_base_name( - patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_repack_model, + patched_upload_code, + patched_tar_and_upload_dir, + sagemaker_session, ): sagemaker_session.boto_region_name = REGION @@ -4391,13 +4477,20 @@ def test_all_framework_estimators_add_jumpstart_base_name( "transformers_version": "4.6.1", "instance_type": "ml.p2.xlarge", }, - MXNet: {"framework_version": "1.7.0", "py_version": "py3", "instance_type": "ml.p2.xlarge"}, + MXNet: { + "framework_version": "1.7.0", + "py_version": "py3", + "instance_type": "ml.p2.xlarge", + }, SKLearn: {"framework_version": "0.23-1", "instance_type": "ml.m2.xlarge"}, XGBoost: {"framework_version": "1.3-1", "instance_type": "ml.m2.xlarge"}, } jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz" - for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items(): + for ( + framework_estimator_class, + kwargs, + ) in framework_estimator_classes_to_kwargs.items(): estimator = framework_estimator_class( entry_point=ENTRY_POINT, role=ROLE, diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 9ba3e17ff3..f12d8e160f 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -160,14 +160,8 @@ def _get_train_args(job_name): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.4.0-cpu-py3", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], 
"profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index c8aad13774..5691834c3a 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -158,14 +158,8 @@ def _create_train_job(version, py_version): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 2035636e76..0c0a9c6d64 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -153,14 +153,8 @@ def _create_train_job(toolkit, toolkit_version, framework): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, "retry_strategy": None, diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index c3e984e0b7..430cb484b4 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -140,14 +140,8 @@ def _create_train_job(version): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index d58c4992cd..87a853d5d0 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -154,14 +154,8 @@ def _create_train_job(version, instance_count=1, instance_type="ml.c4.4xlarge"): "CollectionConfigurations": [], "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, - "profiler_rule_configs": [ - { - "RuleConfigurationName": "ProfilerReport-1510006209", - "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, - } - ], "profiler_config": { + "DisableProfiler": False, "S3OutputPath": "s3://{}/".format(BUCKET_NAME), }, } From 097e82947590cc9b2c68d01f155c4bb486e526b8 Mon Sep 17 00:00:00 2001 From: Shreya Pandit Date: Thu, 15 Dec 2022 12:39:35 -0800 Subject: [PATCH 47/58] Use Async Inference Config when available for endpoint update (#3124) Co-authored-by: Navin Soni --- src/sagemaker/session.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index ce6a3b99cd..602cd1fd9f 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -3324,6 +3324,11 @@ def create_endpoint_config_from_existing( if request_data_capture_config_dict is not None: request["DataCaptureConfig"] = request_data_capture_config_dict + if 
existing_endpoint_config_desc.get("AsyncInferenceConfig") is not None: + request["AsyncInferenceConfig"] = existing_endpoint_config_desc.get( + "AsyncInferenceConfig", None + ) + self.sagemaker_client.create_endpoint_config(**request) def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True): From be6111b011b1045e68b18ec1bc84c0dbd9f8fb6a Mon Sep 17 00:00:00 2001 From: Carolyn Wang <32006339+carolynwang@users.noreply.github.com> Date: Thu, 15 Dec 2022 15:43:11 -0500 Subject: [PATCH 48/58] feature: Add p4de to smddp supported instance types (#3483) --- src/sagemaker/fw_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 3ba918ea2c..a91aff1761 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -80,6 +80,7 @@ "ml.p3.16xlarge", "ml.p3dn.24xlarge", "ml.p4d.24xlarge", + "ml.p4de.24xlarge", "local_gpu", ) SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = { From a0258bbaa715295ae15f2bf4c59cbe6eed054b07 Mon Sep 17 00:00:00 2001 From: Miyoung Date: Thu, 15 Dec 2022 13:08:57 -0800 Subject: [PATCH 49/58] documentation: smdistributed libraries release notes (#3543) --- doc/api/training/sdp_versions/latest.rst | 4 +- .../smd_data_parallel_change_log.rst | 50 +++++++++++++--- .../smd_model_parallel_change_log.rst | 60 ++++++++++++++++--- doc/api/training/smp_versions/latest.rst | 4 +- 4 files changed, 100 insertions(+), 18 deletions(-) diff --git a/doc/api/training/sdp_versions/latest.rst b/doc/api/training/sdp_versions/latest.rst index c3fcc5f78e..461f58998f 100644 --- a/doc/api/training/sdp_versions/latest.rst +++ b/doc/api/training/sdp_versions/latest.rst @@ -26,8 +26,8 @@ depending on the version of the library you use. `_ for more information. -Version 1.4.0, 1.4.1, 1.5.0 (Latest) -==================================== +Version 1.4.0, 1.4.1, 1.5.0, 1.6.0 (Latest) +=========================================== .. toctree:: :maxdepth: 1 diff --git a/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst b/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst index 05eb7220e0..8ff7fabf1c 100644 --- a/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst +++ b/doc/api/training/smd_data_parallel_release_notes/smd_data_parallel_change_log.rst @@ -7,9 +7,51 @@ Release Notes New features, bug fixes, and improvements are regularly made to the SageMaker distributed data parallel library. -SageMaker Distributed Data Parallel 1.5.0 Release Notes +SageMaker Distributed Data Parallel 1.6.0 Release Notes ======================================================= +*Date: Dec. 15. 2022* + +**New Features** + +* New optimized SMDDP AllGather collective to complement the sharded data parallelism technique + in the SageMaker model parallelism library. For more information, see `Sharded data parallelism with SMDDP Collectives + `_ + in the *Amazon SageMaker Developer Guide*. +* Added support for Amazon EC2 ``ml.p4de.24xlarge`` instances. You can run data parallel training jobs + on ``ml.p4de.24xlarge`` instances with the SageMaker data parallelism library’s AllReduce collective. + +**Improvements** + +* General performance improvements of the SMDDP AllReduce collective communication operation. + +**Migration to AWS Deep Learning Containers** + +This version passed benchmark testing and is migrated to the following AWS Deep Learning Containers (DLC): + +- SageMaker training container for PyTorch v1.12.1 + + .. 
code:: + + 763104351884.dkr.ecr.<region>.amazonaws.com/pytorch-training:1.12.1-gpu-py38-cu113-ubuntu20.04-sagemaker + + +Binary file of this version of the library for `custom container +`_ users: + + .. code:: + + https://smdataparallel.s3.amazonaws.com/binary/pytorch/1.12.1/cu113/2022-12-05/smdistributed_dataparallel-1.6.0-cp38-cp38-linux_x86_64.whl + + +---- + +Release History +=============== + +SageMaker Distributed Data Parallel 1.5.0 Release Notes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + *Date: Jul. 26. 2022* **Currency Updates** @@ -38,12 +80,6 @@ Binary file of this version of the library for `custom container https://smdataparallel.s3.amazonaws.com/binary/pytorch/1.12.0/cu113/2022-07-01/smdistributed_dataparallel-1.5.0-cp38-cp38-linux_x86_64.whl - ----- - -Release History -=============== - SageMaker Distributed Data Parallel 1.4.1 Release Notes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst b/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst index 6f89fa45a5..92ccc8c407 100644 --- a/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst +++ b/doc/api/training/smd_model_parallel_release_notes/smd_model_parallel_change_log.rst @@ -6,9 +6,60 @@ New features, bug fixes, and improvements are regularly made to the SageMaker distributed model parallel library. -SageMaker Distributed Model Parallel 1.11.0 Release Notes +SageMaker Distributed Model Parallel 1.13.0 Release Notes ========================================================= +*Date: Dec. 15. 2022* + +**New Features** + +* Sharded data parallelism now supports a new backend for collectives called *SMDDP Collectives*. + For supported scenarios, SMDDP Collectives are on by default for the AllGather operation. + For more information, see + `Sharded data parallelism with SMDDP Collectives + `_ + in the *Amazon SageMaker Developer Guide*. +* Introduced FlashAttention for DistributedTransformer to improve memory usage and computational + performance of models such as GPT2, GPTNeo, GPTJ, GPTNeoX, BERT, and RoBERTa. + +**Bug Fixes** + +* Fixed initialization of ``lm_head`` in DistributedTransformer to use a provided range + for initialization, when weights are not tied with the embeddings. + +**Improvements** + +* When a module has no parameters, we have introduced an optimization to execute + such a module on the same rank as its parent during pipeline parallelism. + +**Migration to AWS Deep Learning Containers** + +This version passed benchmark testing and is migrated to the following AWS Deep Learning Containers (DLC): + +- SageMaker training container for PyTorch v1.12.1 + + .. code:: + + 763104351884.dkr.ecr.<region>.amazonaws.com/pytorch-training:1.12.1-gpu-py38-cu113-ubuntu20.04-sagemaker + + +Binary file of this version of the library for `custom container +`_ users: + +- For PyTorch 1.12.0 + + .. code:: + + https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.12.1/build-artifacts/2022-12-08-21-34/smdistributed_modelparallel-1.13.0-cp38-cp38-linux_x86_64.whl + +---- + +Release History +=============== + +SageMaker Distributed Model Parallel 1.11.0 Release Notes +--------------------------------------------------------- + *Date: August. 17. 2022* **New Features** @@ -41,12 +92,7 @@ Binary file of this version of the library for `custom container ..
code:: - https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.12.0/build-artifacts/2022-08-12-16-58/smdistributed_modelparallel-1.11.0-cp38-cp38-linux_x86_64.whl - ---- - -Release History -=============== + https://sagemaker-distributed-model-parallel.s3.us-west-2.amazonaws.com/pytorch-1.12.0/build-artifacts/2022-08-12-16-58/smdistributed_modelparallel-1.11.0-cp38-cp38-linux_x86_64.whl SageMaker Distributed Model Parallel 1.10.1 Release Notes --------------------------------------------------------- diff --git a/doc/api/training/smp_versions/latest.rst b/doc/api/training/smp_versions/latest.rst index 1a2032c9ed..1eb358b2a3 100644 --- a/doc/api/training/smp_versions/latest.rst +++ b/doc/api/training/smp_versions/latest.rst @@ -10,8 +10,8 @@ depending on which version of the library you need to use. To use the library, reference the **Common API** documentation alongside the framework specific API documentation. -Version 1.11.0 (Latest) -=========================================== +Version 1.11.0, 1.13.0 (Latest) +=============================== To use the library, reference the Common API documentation alongside the framework specific API documentation. From 442227bdfcd852e07f0574dd94ad0b6614b12a08 Mon Sep 17 00:00:00 2001 From: Md Mizanur Rahman <105268921+mizanfiu@users.noreply.github.com> Date: Thu, 15 Dec 2022 13:22:09 -0800 Subject: [PATCH 50/58] feature: Doc update for TableFormatEnum (#3542) * Updated doc for table format Enum --- doc/api/prep_data/feature_store.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/api/prep_data/feature_store.rst b/doc/api/prep_data/feature_store.rst index 0e9bf25586..838558c0a4 100644 --- a/doc/api/prep_data/feature_store.rst +++ b/doc/api/prep_data/feature_store.rst @@ -73,6 +73,10 @@ Inputs :members: :show-inheritance: +.. autoclass:: sagemaker.feature_store.inputs.TableFormatEnum + :members: + :show-inheritance: + Dataset Builder *************** From 146f6bbcc5ddec990e90fba6fcd4548781b7d994 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 16 Dec 2022 00:23:36 +0000 Subject: [PATCH 51/58] prepare release v2.124.0 --- CHANGELOG.md | 17 +++++++++++++++++ VERSION | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a05b64c96f..e5cd9826ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ # Changelog +## v2.124.0 (2022-12-16) + +### Features + + * Doc update for TableFormatEnum + * Add p4de to smddp supported instance types + * Add disable_profiler field in config and propagate changes + * Added doc update for dataset builder + +### Bug Fixes and Other Changes + + * Use Async Inference Config when available for endpoint update + +### Documentation Changes + + * smdistributed libraries release notes + ## v2.123.0 (2022-12-15) ### Features diff --git a/VERSION b/VERSION index ea5085760e..67d5c2730e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.123.1.dev0 +2.124.0 From e07f94414385f6b513e249a61ab64b1664d49b42 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 16 Dec 2022 00:23:37 +0000 Subject: [PATCH 52/58] update development version to v2.124.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 67d5c2730e..97d160799c 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.124.0 +2.124.1.dev0 From 53108b6a7b7af9f08cd00e86e34a5afeb43f8715 Mon Sep 17 00:00:00 2001 From: Xiaoguang Chen <68292680+xgchena@users.noreply.github.com> Date: Fri, 16 Dec 2022 09:29:45 -0800 Subject: [PATCH 53/58] fix: Correct SageMaker Clarify API docstrings by changing JSONPath to JMESPath (#3511) --- src/sagemaker/clarify.py | 30 +++++++++----------
.../model_monitor/clarify_model_monitoring.py | 23 +++++++------- .../model_monitor/model_monitoring.py | 16 ++++++---- src/sagemaker/workflow/clarify_check_step.py | 4 +-- 4 files changed, 39 insertions(+), 34 deletions(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index f082679401..18fed12042 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -330,11 +330,11 @@ def __init__( s3_analysis_config_output_path (str): S3 prefix to store the analysis config output. If this field is None, then the ``s3_output_path`` will be used to store the ``analysis_config`` output. - label (str): Target attribute of the model required by bias metrics. - Specified as column name or index for CSV dataset or as JSONPath for JSONLines. + label (str): Target attribute of the model required by bias metrics. Specified as + column name or index for CSV dataset or as JMESPath expression for JSONLines. *Required parameter* except for when the input dataset does not contain the label. - features (List[str]): JSONPath for locating the feature columns for bias metrics if the - dataset format is JSONLines. + features (List[str]): JMESPath expression to locate the feature columns for + bias metrics if the dataset format is JSONLines. dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV, ``"application/jsonlines"`` for JSONLines, and ``"application/x-parquet"`` for Parquet. @@ -716,11 +716,11 @@ def __init__( ``label_headers=['cat','dog','fish']`` and infer the predicted label to be ``'fish'``. Args: - label (str or int): Index or JSONPath location in the model output for the prediction. - In case, this is a predicted label of the same type as the label in the dataset, - no further arguments need to be specified. - probability (str or int): Index or JSONPath location in the model output - for the predicted score(s). + label (str or int): Index or JMESPath expression to locate the prediction + in the model output. In case, this is a predicted label of the same type + as the label in the dataset, no further arguments need to be specified. + probability (str or int): Index or JMESPath expression to locate the predicted score(s) + in the model output. probability_threshold (float): An optional value for binary prediction tasks in which the model returns a probability, to indicate the threshold to convert the prediction to a boolean value. Default is ``0.5``. @@ -1645,9 +1645,9 @@ def run_explainability( You can request multiple methods at once by passing in a list of `~sagemaker.clarify.ExplainabilityConfig`. model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`): - Index or JSONPath to locate the predicted scores in the model output. This is not - required if the model output is a single score. Alternatively, it can be an instance - of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` + Index or JMESPath expression to locate the predicted scores in the model output. + This is not required if the model output is a single score. Alternatively, + it can be an instance of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` to provide more parameters like ``label_headers``. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. @@ -1774,9 +1774,9 @@ def run_bias_and_explainability( str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig` ): - Index or JSONPath to locate the predicted scores in the model output. 
This is not - required if the model output is a single score. Alternatively, it can be an instance - of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` + Index or JMESPath expression to locate the predicted scores in the model output. + This is not required if the model output is a single score. Alternatively, + it can be an instance of :class:`~sagemaker.clarify.SageMakerClarifyProcessor` to provide more parameters like ``label_headers``. wait (bool): Whether the call should wait until the job completes (default: True). logs (bool): Whether to show the logs produced by the job. diff --git a/src/sagemaker/model_monitor/clarify_model_monitoring.py b/src/sagemaker/model_monitor/clarify_model_monitoring.py index 1a788a0d53..030de7c6db 100644 --- a/src/sagemaker/model_monitor/clarify_model_monitoring.py +++ b/src/sagemaker/model_monitor/clarify_model_monitoring.py @@ -842,8 +842,8 @@ def __init__(self, bias_config, headers=None, label=None): bias_config (sagemaker.clarify.BiasConfig): Config object related to bias configurations. headers (list[str]): A list of column names in the input dataset. - label (str): Target attribute for the model required by bias metrics. - Specified as column name or index for CSV dataset, or as JSONPath for JSONLines. + label (str): Target attribute for the model required by bias metrics. Specified as + column name or index for CSV dataset, or as JMESPath expression for JSONLines. """ self.analysis_config = bias_config.get_config() if headers is not None: @@ -889,9 +889,10 @@ def suggest_baseline( model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its endpoint to be created. model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`): - Index or JSONPath to locate the predicted scores in the model output. This is not - required if the model output is a single score. Alternatively, it can be an instance - of ModelPredictedLabelConfig to provide more parameters like label_headers. + Index or JMESPath expression to locate the predicted scores in the model output. + This is not required if the model output is a single score. Alternatively, + it can be an instance of ModelPredictedLabelConfig to provide more parameters + like label_headers. wait (bool): Whether the call should wait until the job completes (default: False). logs (bool): Whether to show the logs produced by the job. Only meaningful when wait is True (default: False). @@ -1302,12 +1303,12 @@ def __init__( Args: analysis_config (BiasAnalysisConfig or ExplainabilityAnalysisConfig): analysis config from configurations of the baselining job. - features_attribute (str): JSONpath to locate features in predictor request payload. - Only required when predictor content type is JSONlines. - inference_attribute (str): Index, header or JSONpath to locate predicted label in - predictor response payload. - probability_attribute (str): Index or JSONpath location in the model output for - probabilities or scores to be used for explainability. + features_attribute (str): JMESPath expression to locate features in predictor request + payload. Only required when predictor content type is JSONlines. + inference_attribute (str): Index, header or JMESPath expression to locate predicted + label in predictor response payload. + probability_attribute (str): Index or JMESPath expression to locate probabilities or + scores in the model output for computing feature attribution. 
probability_threshold_attribute (float): Value to indicate the threshold to select the binary label in the case of binary classification. Default is 0.5. """ diff --git a/src/sagemaker/model_monitor/model_monitoring.py b/src/sagemaker/model_monitor/model_monitoring.py index 817d951255..2f8266a43a 100644 --- a/src/sagemaker/model_monitor/model_monitoring.py +++ b/src/sagemaker/model_monitor/model_monitoring.py @@ -1061,12 +1061,13 @@ def _generate_env_map( dataset_format (dict): The format of the baseline_dataset. dataset_source_container_path (str): The path to the dataset source. inference_attribute (str): Index or JSONpath to locate predicted label(s). - Only used for ModelQualityMonitor, ModelBiasMonitor, and ModelExplainabilityMonitor + Only used for ModelQualityMonitor. probability_attribute (str or int): Index or JSONpath to locate probabilities. - Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor - ground_truth_attribute (str): Index or JSONpath to locate actual label(s). + Only used for ModelQualityMonitor. + ground_truth_attribute (str): Index to locate actual label(s). + Only used for ModelQualityMonitor. probability_threshold_attribute (float): threshold to convert probabilities to binaries - Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor + Only used for ModelQualityMonitor. Returns: dict: Dictionary of environment keys and values. @@ -2600,10 +2601,13 @@ def suggest_baseline( problem_type (str): The type of problem of this model quality monitoring. Valid values are "Regression", "BinaryClassification", "MulticlassClassification". inference_attribute (str): Index or JSONpath to locate predicted label(s). + Only used for ModelQualityMonitor. probability_attribute (str or int): Index or JSONpath to locate probabilities. - ground_truth_attribute (str): Index or JSONpath to locate actual label(s). + Only used for ModelQualityMonitor. + ground_truth_attribute (str): Index to locate actual label(s). + Only used for ModelQualityMonitor. probability_threshold_attribute (float): threshold to convert probabilities to binaries - Only used for ModelQualityMonitor, ModelBiasMonitor and ModelExplainabilityMonitor + Only used for ModelQualityMonitor. post_analytics_processor_script (str): The path to the record post-analytics processor script. This can be a local path or an S3 uri. output_s3_uri (str): Desired S3 destination Destination of the constraint_violations diff --git a/src/sagemaker/workflow/clarify_check_step.py b/src/sagemaker/workflow/clarify_check_step.py index 9d350b01f3..22b6fc2051 100644 --- a/src/sagemaker/workflow/clarify_check_step.py +++ b/src/sagemaker/workflow/clarify_check_step.py @@ -132,8 +132,8 @@ class ModelExplainabilityCheckConfig(ClarifyCheckConfig): model_config (ModelConfig): Config of the model and its endpoint to be created. explainability_config (SHAPConfig): Config of the specific explainability method. Currently, only SHAP is supported. - model_scores (str or int or ModelPredictedLabelConfig): Index or JSONPath location - in the model output for the predicted scores to be explained (default: None). + model_scores (str or int or ModelPredictedLabelConfig): Index or JMESPath expression + to locate the predicted scores in the model output (default: None). This is not required if the model output is a single score. Alternatively, an instance of ModelPredictedLabelConfig can be provided but this field CANNOT be any type of the `PipelineVariable`. 
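For reference, a minimal sketch of the corrected semantics in PATCH 53/58: the ``label`` and ``probability`` arguments are JMESPath expressions (not JSONPath) that locate fields in the model response. The response keys ``predicted_label`` and ``score`` below are hypothetical placeholders, not values taken from the patch.

.. code:: python

    from sagemaker.clarify import ModelPredictedLabelConfig

    # Assumed JSONLines endpoint response: {"predicted_label": 1, "score": 0.92}
    # Both arguments below are JMESPath expressions into that response.
    predicted_label_config = ModelPredictedLabelConfig(
        label="predicted_label",    # locates the predicted label
        probability="score",        # locates the predicted score(s)
        probability_threshold=0.8,  # converts the score to a boolean prediction
    )

A CSV model output would instead pass integer indexes for ``label`` and ``probability``, as the corrected docstrings note.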
From ea0d053266c20ceadd1e047c0ad01e40121583ca Mon Sep 17 00:00:00 2001 From: Tim Song <4277459+timyber@users.noreply.github.com> Date: Fri, 16 Dec 2022 18:52:29 +0100 Subject: [PATCH 54/58] feature: add RandomSeed to support reproducible HPO (#3519) * feature: add RandomSeed to support reproducible HPO * fix pylint Co-authored-by: Tim Song Co-authored-by: Rajanikant Tenguria --- src/sagemaker/session.py | 12 ++++++++++++ src/sagemaker/tuner.py | 18 ++++++++++++++++++ tests/unit/test_session.py | 6 ++++++ tests/unit/test_tuner.py | 1 + tests/unit/tuner_test_utils.py | 1 + 5 files changed, 38 insertions(+) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 602cd1fd9f..5404978200 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2146,6 +2146,7 @@ def tune( # noqa: C901 use_spot_instances=False, checkpoint_s3_uri=None, checkpoint_local_path=None, + random_seed=None, ): """Create an Amazon SageMaker hyperparameter tuning job. @@ -2226,6 +2227,9 @@ def tune( # noqa: C901 started. If the path is unset then SageMaker assumes the checkpoints will be provided under `/opt/ml/checkpoints/`. (default: ``None``). + random_seed (int): An initial value used to initialize a pseudo-random number generator. + Setting a random seed will make the hyperparameter tuning search strategies to + produce more consistent configurations for the same tuning job. (default: ``None``). """ tune_request = { @@ -2238,6 +2242,7 @@ def tune( # noqa: C901 objective_metric_name=objective_metric_name, parameter_ranges=parameter_ranges, early_stopping_type=early_stopping_type, + random_seed=random_seed, strategy_config=strategy_config, ), "TrainingJobDefinition": self._map_training_config( @@ -2394,6 +2399,7 @@ def _map_tuning_config( objective_type=None, objective_metric_name=None, parameter_ranges=None, + random_seed=None, strategy_config=None, ): """Construct tuning job configuration dictionary. @@ -2412,6 +2418,9 @@ def _map_tuning_config( objective_metric_name (str): Name of the metric for evaluating training jobs. parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be one of three types: Continuous, Integer, or Categorical. + random_seed (int): An initial value used to initialize a pseudo-random number generator. + Setting a random seed will make the hyperparameter tuning search strategies to + produce more consistent configurations for the same tuning job. strategy_config (dict): A configuration for the hyperparameter tuning job optimisation strategy. @@ -2430,6 +2439,9 @@ def _map_tuning_config( "TrainingJobEarlyStoppingType": early_stopping_type, } + if random_seed is not None: + tuning_config["RandomSeed"] = random_seed + tuning_objective = cls._map_tuning_objective(objective_type, objective_metric_name) if tuning_objective is not None: tuning_config["HyperParameterTuningJobObjective"] = tuning_objective diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 9a694cbec9..45a6467c1f 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -413,6 +413,7 @@ def __init__( strategy_config: Optional[StrategyConfig] = None, early_stopping_type: Union[str, PipelineVariable] = "Off", estimator_name: Optional[str] = None, + random_seed: Optional[int] = None, ): """Creates a ``HyperparameterTuner`` instance. @@ -470,6 +471,9 @@ def __init__( estimator_name (str): A unique name to identify an estimator within the hyperparameter tuning job, when more than one estimator is used with the same tuning job (default: None). 
+ random_seed (int): An initial value used to initialize a pseudo-random number generator. + Setting a random seed will make the hyperparameter tuning search strategies to + produce more consistent configurations for the same tuning job. """ if hyperparameter_ranges is None or len(hyperparameter_ranges) == 0: raise ValueError("Need to specify hyperparameter ranges") @@ -516,6 +520,7 @@ def __init__( self.latest_tuning_job = None self.warm_start_config = warm_start_config self.early_stopping_type = early_stopping_type + self.random_seed = random_seed def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False): """Prepare the tuner instance for tuning (fit).""" @@ -1222,6 +1227,9 @@ def _prepare_init_params_from_job_description(cls, job_details): "base_tuning_job_name": base_from_name(job_details["HyperParameterTuningJobName"]), } + if "RandomSeed" in tuning_config: + params["random_seed"] = tuning_config["RandomSeed"] + if "HyperParameterTuningJobObjective" in tuning_config: params["objective_metric_name"] = tuning_config["HyperParameterTuningJobObjective"][ "MetricName" @@ -1483,6 +1491,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato warm_start_type=warm_start_type, parents=all_parents ), early_stopping_type=self.early_stopping_type, + random_seed=self.random_seed, ) if len(self.estimator_dict) > 1: @@ -1508,6 +1517,7 @@ def _create_warm_start_tuner(self, additional_parents, warm_start_type, estimato max_parallel_jobs=self.max_parallel_jobs, warm_start_config=WarmStartConfig(warm_start_type=warm_start_type, parents=all_parents), early_stopping_type=self.early_stopping_type, + random_seed=self.random_seed, ) @classmethod @@ -1526,6 +1536,7 @@ def create( tags=None, warm_start_config=None, early_stopping_type="Off", + random_seed=None, ): """Factory method to create a ``HyperparameterTuner`` instance. @@ -1586,6 +1597,9 @@ def create( Can be either 'Auto' or 'Off' (default: 'Off'). If set to 'Off', early stopping will not be attempted. If set to 'Auto', early stopping of some training jobs may happen, but is not guaranteed to. + random_seed (int): An initial value used to initialize a pseudo-random number generator. + Setting a random seed will make the hyperparameter tuning search strategies to + produce more consistent configurations for the same tuning job. 
Returns: sagemaker.tuner.HyperparameterTuner: a new ``HyperparameterTuner`` object that can @@ -1624,6 +1638,7 @@ def create( tags=tags, warm_start_config=warm_start_config, early_stopping_type=early_stopping_type, + random_seed=random_seed, ) for estimator_name in estimator_names[1:]: @@ -1775,6 +1790,9 @@ def _get_tuner_args(cls, tuner, inputs): "early_stopping_type": tuner.early_stopping_type, } + if tuner.random_seed is not None: + tuning_config["random_seed"] = tuner.random_seed + if tuner.strategy_config is not None: tuning_config["strategy_config"] = tuner.strategy_config.to_input_req() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index ec4a21cbc9..119d08cef4 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -897,6 +897,7 @@ def test_train_pack_to_request(sagemaker_session): "ResourceLimits": {"MaxNumberOfTrainingJobs": 100, "MaxParallelTrainingJobs": 5}, "ParameterRanges": SAMPLE_PARAM_RANGES, "TrainingJobEarlyStoppingType": "Off", + "RandomSeed": 0, }, "TrainingJobDefinition": { "StaticHyperParameters": STATIC_HPs, @@ -989,6 +990,7 @@ def assert_create_tuning_job_request(**kwrags): sagemaker_session.tune( job_name="dummy-tuning-1", strategy="Bayesian", + random_seed=0, objective_type="Maximize", objective_metric_name="val-score", max_jobs=100, @@ -1080,6 +1082,7 @@ def assert_create_tuning_job_request(**kwrags): "max_jobs": 100, "max_parallel_jobs": 5, "parameter_ranges": SAMPLE_PARAM_RANGES, + "random_seed": 0, }, training_config={ "static_hyperparameters": STATIC_HPs, @@ -1170,6 +1173,7 @@ def assert_create_tuning_job_request(**kwrags): sagemaker_session.tune( job_name="dummy-tuning-1", strategy="Bayesian", + random_seed=0, objective_type="Maximize", objective_metric_name="val-score", max_jobs=100, @@ -1246,6 +1250,7 @@ def assert_create_tuning_job_request(**kwrags): sagemaker_session.tune( job_name="dummy-tuning-1", strategy="Bayesian", + random_seed=0, objective_type="Maximize", objective_metric_name="val-score", max_jobs=100, @@ -1289,6 +1294,7 @@ def assert_create_tuning_job_request(**kwargs): sagemaker_session.tune( job_name="dummy-tuning-1", strategy="Bayesian", + random_seed=0, objective_type="Maximize", objective_metric_name="val-score", max_jobs=100, diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 9bbc882dfa..7e556c7d23 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -545,6 +545,7 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session assert tuner.strategy == "Bayesian" assert tuner.objective_type == "Minimize" assert tuner.early_stopping_type == "Off" + assert tuner.random_seed == 0 assert isinstance(tuner.estimator, PCA) assert tuner.estimator.role == ROLE diff --git a/tests/unit/tuner_test_utils.py b/tests/unit/tuner_test_utils.py index be0dba2ccd..5cf7ba2fc2 100644 --- a/tests/unit/tuner_test_utils.py +++ b/tests/unit/tuner_test_utils.py @@ -112,6 +112,7 @@ ], }, "TrainingJobEarlyStoppingType": "Off", + "RandomSeed": 0, }, "HyperParameterTuningJobName": JOB_NAME, "TrainingJobDefinition": { From b7594402ba45579bccd8dfe1e016e55f0e659ef6 Mon Sep 17 00:00:00 2001 From: Harsha Balluru Date: Fri, 16 Dec 2022 20:28:34 +0000 Subject: [PATCH 55/58] tensorflow inference 2.10.1 release --- src/sagemaker/image_uri_config/tensorflow.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 6bb36057fa..35c77f1d1d 100644 --- 
a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -307,7 +307,7 @@ "2.7": "2.7.0", "2.8": "2.8.0", "2.9": "2.9.2", - "2.10": "2.10.0", + "2.10": "2.10.1", "2.11": "2.11.0" }, "versions": { @@ -1672,7 +1672,7 @@ }, "repository": "tensorflow-inference" }, - "2.10.0": { + "2.10.1": { "registries": { "af-south-1": "626614931356", "ap-east-1": "871362719292", From 62f77444346202bd9619f7f2fc7e8b56b5c4539e Mon Sep 17 00:00:00 2001 From: Harsha Balluru Date: Sat, 17 Dec 2022 00:47:09 +0000 Subject: [PATCH 56/58] added back 2.10.0 images --- .../image_uri_config/tensorflow.json | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 35c77f1d1d..46b3c02f14 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -1672,6 +1672,41 @@ }, "repository": "tensorflow-inference" }, + "2.10.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, "2.10.1": { "registries": { "af-south-1": "626614931356", From 27ec1f06e18e6edc0761abad1ba5f012dc6a533b Mon Sep 17 00:00:00 2001 From: Harsha Balluru Date: Wed, 21 Dec 2022 01:29:45 +0000 Subject: [PATCH 57/58] restoring original file to prevent conflicts --- .../image_uri_config/tensorflow.json | 37 +------------------ 1 file changed, 1 insertion(+), 36 deletions(-) diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 46b3c02f14..6bb36057fa 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -307,7 +307,7 @@ "2.7": "2.7.0", "2.8": "2.8.0", "2.9": "2.9.2", - "2.10": "2.10.1", + "2.10": "2.10.0", "2.11": "2.11.0" }, "versions": { @@ -1707,41 +1707,6 @@ }, "repository": "tensorflow-inference" }, - "2.10.1": { - "registries": { - "af-south-1": "626614931356", - "ap-east-1": "871362719292", - "ap-northeast-1": "763104351884", - "ap-northeast-2": "763104351884", - "ap-northeast-3": "364406365360", - "ap-south-1": "763104351884", - "ap-south-2": "772153158452", - "ap-southeast-1": "763104351884", - "ap-southeast-2": "763104351884", - "ap-southeast-3": "907027046896", - "ca-central-1": "763104351884", - "cn-north-1": "727897471807", - "cn-northwest-1": "727897471807", - "eu-central-1": "763104351884", - "eu-central-2": "380420809688", - "eu-north-1": "763104351884", - "eu-south-1": "692866216735", - "eu-south-2": "503227376785", - 
"eu-west-1": "763104351884", - "eu-west-2": "763104351884", - "eu-west-3": "763104351884", - "me-south-1": "217643126080", - "sa-east-1": "763104351884", - "us-east-1": "763104351884", - "us-east-2": "763104351884", - "us-gov-east-1": "446045086412", - "us-gov-west-1": "442386744353", - "us-iso-east-1": "886529160074", - "us-west-1": "763104351884", - "us-west-2": "763104351884" - }, - "repository": "tensorflow-inference" - }, "2.11.0": { "registries": { "af-south-1": "626614931356", From 72c987703961191e7fcc05c76481b3785d9a1cfd Mon Sep 17 00:00:00 2001 From: Harsha Balluru Date: Wed, 21 Dec 2022 01:55:19 +0000 Subject: [PATCH 58/58] adding tf 2.10.1 release images --- .../image_uri_config/tensorflow.json | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 6bb36057fa..cb206c31a4 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -307,7 +307,7 @@ "2.7": "2.7.0", "2.8": "2.8.0", "2.9": "2.9.2", - "2.10": "2.10.0", + "2.10": "2.10.1", "2.11": "2.11.0" }, "versions": { @@ -1707,6 +1707,41 @@ }, "repository": "tensorflow-inference" }, + "2.10.1": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, "2.11.0": { "registries": { "af-south-1": "626614931356", @@ -3347,3 +3382,4 @@ } } } +