diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 28a4f6ec4c..e012be566a 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -20,6 +20,8 @@ import uuid from abc import ABCMeta, abstractmethod from typing import Any, Dict, Union, Optional, List +from packaging.specifiers import SpecifierSet +from packaging.version import Version from six import string_types, with_metaclass from six.moves.urllib.parse import urlparse @@ -83,10 +85,7 @@ ) from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable -from sagemaker.workflow.pipeline_context import ( - PipelineSession, - runnable_by_pipeline, -) +from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline logger = logging.getLogger(__name__) @@ -106,6 +105,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man LAUNCH_PS_ENV_NAME = "sagemaker_parameter_server_enabled" LAUNCH_MPI_ENV_NAME = "sagemaker_mpi_enabled" LAUNCH_SM_DDP_ENV_NAME = "sagemaker_distributed_dataparallel_enabled" + LAUNCH_MWMS_ENV_NAME = "sagemaker_multi_worker_mirrored_strategy_enabled" INSTANCE_TYPE = "sagemaker_instance_type" MPI_NUM_PROCESSES_PER_HOST = "sagemaker_mpi_num_of_processes_per_host" MPI_CUSTOM_MPI_OPTIONS = "sagemaker_mpi_custom_mpi_options" @@ -557,9 +557,7 @@ def __init__( self.dependencies = dependencies or [] self.uploaded_code = None self.tags = add_jumpstart_tags( - tags=tags, - training_model_uri=self.model_uri, - training_script_uri=self.source_dir, + tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir ) if self.instance_type in ("local", "local_gpu"): if self.instance_type == "local_gpu" and self.instance_count > 1: @@ -680,8 +678,7 @@ def _ensure_base_job_name(self): self.base_job_name or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri) or base_name_from_image( - self.training_image_uri(), - 
default_base_name=EstimatorBase.JOB_CLASS_NAME, + self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME ) ) @@ -744,7 +741,6 @@ def _prepare_for_training(self, job_name=None): self.dependencies = updated_paths["dependencies"] if self.source_dir or self.entry_point or self.dependencies: - # validate source dir will raise a ValueError if there is something wrong with # the source directory. We are intentionally not handling it because this is a # critical error. @@ -1023,10 +1019,7 @@ def _set_source_s3_uri(self, rule): parse_result = urlparse(rule.rule_parameters["source_s3_uri"]) if parse_result.scheme != "s3": desired_s3_uri = os.path.join( - "s3://", - self.sagemaker_session.default_bucket(), - rule.name, - str(uuid.uuid4()), + "s3://", self.sagemaker_session.default_bucket(), rule.name, str(uuid.uuid4()) ) s3_uri = S3Uploader.upload( local_path=rule.rule_parameters["source_s3_uri"], @@ -1439,10 +1432,7 @@ def deploy( self._ensure_base_job_name() jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model( - kwargs.get("source_dir"), - self.source_dir, - kwargs.get("model_data"), - self.model_uri, + kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri ) default_name = ( name_from_base(jumpstart_base_name) @@ -2240,11 +2230,7 @@ def _is_local_channel(cls, input_uri): @classmethod def update( - cls, - estimator, - profiler_rule_configs=None, - profiler_config=None, - resource_config=None, + cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None ): """Update a running Amazon SageMaker training job. 
@@ -3165,6 +3151,34 @@ def _validate_and_set_debugger_configs(self): ) self.debugger_hook_config = False + def _validate_mwms_config(self, distribution): + """Validate Multi Worker Mirrored Strategy configuration.""" + minimum_supported_framework_version = {"tensorflow": {"framework_version": "2.9"}} + if self._framework_name in minimum_supported_framework_version: + for version_argument in minimum_supported_framework_version[self._framework_name]: + current = getattr(self, version_argument) + threshold = minimum_supported_framework_version[self._framework_name][ + version_argument + ] + if Version(current) in SpecifierSet(f"< {threshold}"): + raise ValueError( + "Multi Worker Mirrored Strategy is only supported " + "from {} {} but received {}".format(version_argument, threshold, current) + ) + else: + raise ValueError( + "Multi Worker Mirrored Strategy is currently only supported " + "with {} frameworks but received {}".format( + minimum_supported_framework_version.keys(), self._framework_name + ) + ) + unsupported_distributions = ["smdistributed", "parameter_server"] + if any(i in distribution for i in unsupported_distributions): + raise ValueError( + "Multi Worker Mirrored Strategy is currently not supported with the" + " following distribution strategies: {}".format(unsupported_distributions) + ) + def _model_source_dir(self): """Get the appropriate value to pass as ``source_dir`` to a model constructor. 
@@ -3528,6 +3542,12 @@ def _distribution_configuration(self, distribution): "dataparallel" ].get("custom_mpi_options", "") + if "multi_worker_mirrored_strategy" in distribution: + mwms_enabled = distribution.get("multi_worker_mirrored_strategy").get("enabled", False) + if mwms_enabled: + self._validate_mwms_config(distribution) + distribution_config[self.LAUNCH_MWMS_ENV_NAME] = mwms_enabled + if not (mpi_enabled or smdataparallel_enabled) and distribution_config.get( "sagemaker_distribution_instance_groups" ) not in [None, []]: diff --git a/src/sagemaker/tensorflow/estimator.py b/src/sagemaker/tensorflow/estimator.py index a2507e2bc2..c7463dfc03 100644 --- a/src/sagemaker/tensorflow/estimator.py +++ b/src/sagemaker/tensorflow/estimator.py @@ -137,6 +137,23 @@ def __init__( To find a complete list of parameters for SageMaker model parallelism, see :ref:`sm-sdk-modelparallel-general`. + **To enable Multi Worker Mirrored Strategy:** + + .. code:: python + + { + "multi_worker_mirrored_strategy": { + "enabled": True + } + } + + This distribution strategy option is available for TensorFlow 2.9 and later in + the SageMaker Python SDK v2.xx.yy and later. + To learn more about the mirrored strategy for TensorFlow, + see `TensorFlow Distributed Training + <https://www.tensorflow.org/guide/distributed_training>`_ + in the *TensorFlow documentation*. + **To enable MPI:** .. code:: python diff --git a/src/sagemaker/tensorflow/training_compiler/config.py b/src/sagemaker/tensorflow/training_compiler/config.py index 16c4b1fe70..6c897d1723 100644 --- a/src/sagemaker/tensorflow/training_compiler/config.py +++ b/src/sagemaker/tensorflow/training_compiler/config.py @@ -79,7 +79,7 @@ def validate(cls, estimator): """Checks if SageMaker Training Compiler is configured correctly. Args: - estimator (str): A estimator object + estimator (:class:`sagemaker.tensorflow.estimator.TensorFlow`): An estimator object If SageMaker Training Compiler is enabled, it will validate whether the estimator is configured to be compatible with Training Compiler. 
@@ -102,3 +102,13 @@ def validate(cls, estimator): cls.MIN_SUPPORTED_VERSION, estimator.framework_version ) raise ValueError(error_helper_string) + + if estimator.distribution and "multi_worker_mirrored_strategy" in estimator.distribution: + mwms_enabled = estimator.distribution.get("multi_worker_mirrored_strategy").get( + "enabled", False + ) + if mwms_enabled: + raise ValueError( + "Multi Worker Mirrored Strategy distributed training configuration " + "is currently not compatible with SageMaker Training Compiler." + ) diff --git a/tests/conftest.py b/tests/conftest.py index 113e28138b..ae2a1ca06d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -281,9 +281,7 @@ def huggingface_training_compiler_pytorch_version( huggingface_training_compiler_version, ): versions = _huggingface_base_fm_version( - huggingface_training_compiler_version, - "pytorch", - "huggingface_training_compiler", + huggingface_training_compiler_version, "pytorch", "huggingface_training_compiler" ) if not versions: pytest.skip( @@ -298,9 +296,7 @@ def huggingface_training_compiler_tensorflow_version( huggingface_training_compiler_version, ): versions = _huggingface_base_fm_version( - huggingface_training_compiler_version, - "tensorflow", - "huggingface_training_compiler", + huggingface_training_compiler_version, "tensorflow", "huggingface_training_compiler" ) if not versions: pytest.skip( @@ -526,8 +522,7 @@ def pytorch_ddp_py_version(): @pytest.fixture( - scope="module", - params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"], + scope="module", params=["1.10", "1.10.0", "1.10.2", "1.11", "1.11.0", "1.12", "1.12.0"] ) def pytorch_ddp_framework_version(request): return request.param diff --git a/tests/data/tensorflow_mnist/mnist_mwms.py b/tests/data/tensorflow_mnist/mnist_mwms.py new file mode 100644 index 0000000000..728e479f5d --- /dev/null +++ b/tests/data/tensorflow_mnist/mnist_mwms.py @@ -0,0 +1,57 @@ +# 
https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras + +import json +import os +import tensorflow as tf +import numpy as np + + +def mnist_dataset(batch_size): + (x_train, y_train), _ = tf.keras.datasets.mnist.load_data() + # The `x` arrays are in uint8 and have values in the [0, 255] range. + # You need to convert them to float32 with values in the [0, 1] range. + x_train = x_train / np.float32(255) + y_train = y_train.astype(np.int64) + train_dataset = ( + tf.data.Dataset.from_tensor_slices((x_train, y_train)) + .shuffle(60000) + .repeat() + .batch(batch_size) + ) + return train_dataset + + +def build_and_compile_cnn_model(): + model = tf.keras.Sequential( + [ + tf.keras.layers.InputLayer(input_shape=(28, 28)), + tf.keras.layers.Reshape(target_shape=(28, 28, 1)), + tf.keras.layers.Conv2D(32, 3, activation="relu"), + tf.keras.layers.Flatten(), + tf.keras.layers.Dense(128, activation="relu"), + tf.keras.layers.Dense(10), + ] + ) + model.compile( + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=tf.keras.optimizers.SGD(learning_rate=0.001), + metrics=["accuracy"], + ) + return model + + +per_worker_batch_size = 64 +tf_config = json.loads(os.environ["TF_CONFIG"]) +num_workers = len(tf_config["cluster"]["worker"]) + +strategy = tf.distribute.MultiWorkerMirroredStrategy() + +global_batch_size = per_worker_batch_size * num_workers +multi_worker_dataset = mnist_dataset(global_batch_size) + +with strategy.scope(): + multi_worker_model = build_and_compile_cnn_model() + +multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70) + +print(f"strategy.num_replicas_in_sync={strategy.num_replicas_in_sync}") diff --git a/tests/integ/test_tf.py b/tests/integ/test_tf.py index 86ac20e9bf..88d0bbd3e8 100644 --- a/tests/integ/test_tf.py +++ b/tests/integ/test_tf.py @@ -38,6 +38,7 @@ SCRIPT = "mnist.py" PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}} MPI_DISTRIBUTION = {"mpi": {"enabled": 
True}} +MWMS_DISTRIBUTION = {"multi_worker_mirrored_strategy": {"enabled": True}} TAGS = [{"Key": "some-key", "Value": "some-value"}] ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"} @@ -68,12 +69,7 @@ def test_framework_processing_job_with_deps( sagemaker_session=sagemaker_session, base_job_name="test-tensorflow", ) - processor.run( - code=entry_point, - source_dir=code_path, - inputs=[], - wait=True, - ) + processor.run(code=entry_point, source_dir=code_path, inputs=[], wait=True) def test_mnist_with_checkpoint_config( @@ -110,9 +106,7 @@ def test_mnist_with_checkpoint_config( with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): estimator.fit(inputs=inputs, job_name=training_job_name) assert_s3_file_patterns_exist( - sagemaker_session, - estimator.model_dir, - [r"model\.ckpt-\d+\.index", r"checkpoint"], + sagemaker_session, estimator.model_dir, [r"model\.ckpt-\d+\.index", r"checkpoint"] ) # remove dataframe assertion to unblock PR build # TODO: add independent integration test for `training_job_analytics` @@ -130,9 +124,7 @@ def test_mnist_with_checkpoint_config( ] ) - expected_retry_strategy = { - "MaximumRetryAttempts": 2, - } + expected_retry_strategy = {"MaximumRetryAttempts": 2} actual_retry_strategy = sagemaker_session.sagemaker_client.describe_training_job( TrainingJobName=training_job_name )["RetryStrategy"] @@ -181,6 +173,48 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_v ) +@pytest.mark.release +@pytest.mark.skipif( + tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS + and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS, + reason="no ml.p2 or ml.p3 instances in this region", +) +@retry_with_instance_list(gpu_list(tests.integ.test_region())) +def test_mwms_gpu( + sagemaker_session, + tensorflow_training_latest_version, + tensorflow_training_latest_py_version, + capsys, + **kwargs, +): + instance_count = 2 + estimator = 
TensorFlow( + source_dir=os.path.join(RESOURCE_PATH, "tensorflow_mnist"), + entry_point="mnist_mwms.py", + model_dir=False, + instance_type=kwargs["instance_type"], + instance_count=instance_count, + framework_version=tensorflow_training_latest_version, + py_version=tensorflow_training_latest_py_version, + distribution=MWMS_DISTRIBUTION, + environment={"NCCL_DEBUG": "INFO"}, + max_run=60 * 60 * 1, # 1 hour + role=ROLE, + volume_size=400, + sagemaker_session=sagemaker_session, + disable_profiler=True, + ) + + with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): + estimator.fit(job_name=unique_name_from_base("test-tf-mwms")) + + captured = capsys.readouterr() + logs = captured.out + captured.err + print(logs) + assert "Running distributed training job with multi_worker_mirrored_strategy setup" in logs + assert f"strategy.num_replicas_in_sync={instance_count}" in logs + + @pytest.mark.release def test_mnist_distributed_cpu( sagemaker_session, @@ -237,9 +271,7 @@ def _create_and_fit_estimator(sagemaker_session, tf_version, py_version, instanc with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES): estimator.fit(inputs=inputs, job_name=unique_name_from_base("test-tf-sm-distributed")) assert_s3_file_patterns_exist( - sagemaker_session, - estimator.model_dir, - [r"model\.ckpt-\d+\.index", r"checkpoint"], + sagemaker_session, estimator.model_dir, [r"model\.ckpt-\d+\.index", r"checkpoint"] ) @@ -346,8 +378,7 @@ def test_model_deploy_with_serverless_inference_config( sagemaker_session=sagemaker_session, ) predictor = model.deploy( - serverless_inference_config=ServerlessInferenceConfig(), - endpoint_name=endpoint_name, + serverless_inference_config=ServerlessInferenceConfig(), endpoint_name=endpoint_name ) input_data = {"instances": [1.0, 2.0, 5.0]} diff --git a/tests/integ/test_training_compiler.py b/tests/integ/test_training_compiler.py index 10bd809bc4..689ca66c6e 100644 --- 
a/tests/integ/test_training_compiler.py +++ b/tests/integ/test_training_compiler.py @@ -41,7 +41,8 @@ def instance_count(request): @pytest.fixture(scope="module") def imagenet_val_set(request, sagemaker_session, tmpdir_factory): """ - Copies the dataset from the bucket it's hosted in to the local bucket in the test region + Copies the Imagenet dataset from the bucket it's hosted in to the local bucket in the test region. + Due to licensing issues, access to this dataset is controlled through an allowlist """ local_path = tmpdir_factory.mktemp("trcomp_imagenet_val_set") sagemaker_session.download_data( @@ -148,7 +149,6 @@ def test_pytorch( Test the PyTorch estimator """ with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - hf = PyTorch( py_version="py39", source_dir=os.path.join(DATA_DIR, "huggingface_byoc"), diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index 78e4a0d281..aaadbc98d5 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -42,6 +42,7 @@ "520713654638.dkr.ecr.{}.amazonaws.com/sagemaker-tensorflow-scriptmode:{}-cpu-{}" ) DISTRIBUTION_PS_ENABLED = {"parameter_server": {"enabled": True}} +DISTRIBUTION_MWMS_ENABLED = {"multi_worker_mirrored_strategy": {"enabled": True}} DISTRIBUTION_MPI_ENABLED = { "mpi": {"enabled": True, "custom_mpi_options": "options", "processes_per_host": 2} } @@ -519,6 +520,99 @@ def test_fit_mpi(time, strftime, sagemaker_session): assert actual_train_args == expected_train_args +@patch("time.strftime", return_value=TIMESTAMP) +@patch("time.time", return_value=TIME) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +def test_fit_mwms( + time, strftime, sagemaker_session, tensorflow_training_version, tensorflow_training_py_version +): + if version.Version(tensorflow_training_version) < version.Version("2.11"): + pytest.skip("Multi Worker Mirrored Strategy was added in TF 2.11") + 
framework_version = tensorflow_training_version + py_version = tensorflow_training_py_version + tf = TensorFlow( + entry_point=SCRIPT_FILE, + framework_version=framework_version, + py_version=py_version, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_type=INSTANCE_TYPE, + instance_count=1, + source_dir=DATA_DIR, + distribution=DISTRIBUTION_MWMS_ENABLED, + ) + + inputs = "s3://mybucket/train" + tf.fit(inputs=inputs) + + call_names = [c[0] for c in sagemaker_session.method_calls] + assert call_names == ["train", "logs_for_job"] + + expected_train_args = _create_train_job(framework_version, py_version=py_version) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args[ + "image_uri" + ] = f"763104351884.dkr.ecr.{REGION}.amazonaws.com/tensorflow-training:{framework_version}-cpu-{py_version}" + expected_train_args["job_name"] = f"tensorflow-training-{TIMESTAMP}" + expected_train_args["hyperparameters"][TensorFlow.LAUNCH_MWMS_ENV_NAME] = json.dumps(True) + expected_train_args["hyperparameters"]["sagemaker_job_name"] = json.dumps( + expected_train_args["job_name"] + ) + expected_train_args["hyperparameters"]["sagemaker_submit_directory"] = json.dumps( + f"s3://{BUCKET_NAME}/{expected_train_args['job_name']}/source/sourcedir.tar.gz" + ) + expected_train_args["hyperparameters"]["model_dir"] = json.dumps( + f"s3://{BUCKET_NAME}/{expected_train_args['job_name']}/model" + ) + expected_train_args["enable_sagemaker_metrics"] = True + + actual_train_args = sagemaker_session.method_calls[0][2] + assert actual_train_args == expected_train_args + + +@patch("time.strftime", return_value=TIMESTAMP) +@patch("time.time", return_value=TIME) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +def test_fit_mwms_unsupported(time, strftime, sagemaker_session): + with pytest.raises(ValueError) as error: + tf = TensorFlow( + entry_point=SCRIPT_FILE, + framework_version="2.8", + py_version="py39", + role=ROLE, + 
sagemaker_session=sagemaker_session, + instance_type=INSTANCE_TYPE, + instance_count=1, + source_dir=DATA_DIR, + distribution=DISTRIBUTION_MWMS_ENABLED, + ) + inputs = "s3://mybucket/train" + tf.fit(inputs=inputs) + + assert "only supported from" in str(error) + assert "but received" in str(error) + + with pytest.raises(ValueError) as error: + tf = TensorFlow( + entry_point=SCRIPT_FILE, + framework_version="2.10", + py_version="py39", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_type="ml.p4d.24xlarge", + instance_count=4, + source_dir=DATA_DIR, + distribution={ + **DISTRIBUTION_MWMS_ENABLED, + **{"smdistributed": {"dataparallel": {"enabled": True}}}, + }, + ) + inputs = "s3://mybucket/train" + tf.fit(inputs=inputs) + assert "is currently not supported" in str(error) + assert "following distribution strategies" in str(error) + + def test_hyperparameters_no_model_dir( sagemaker_session, tensorflow_training_version, tensorflow_training_py_version ): @@ -552,10 +646,7 @@ def test_tf_heterogeneous_cluster_distribution_config( framework_version=tensorflow_training_version, py_version=tensorflow_training_py_version, instance_groups=[training_group], - distribution={ - "mpi": {"enabled": True}, - "instance_groups": [training_group], - }, + distribution={"mpi": {"enabled": True}, "instance_groups": [training_group]}, ) assert tf.distribution == expected_return diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 5a8fce34ef..f6700bf51f 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -156,10 +156,7 @@ def _create_train_job(framework_version, instance_type, training_compiler_config class TestUnsupportedConfig: def test_cpu_instance( - self, - cpu_instance_type, - tensorflow_training_version, - tensorflow_training_py_version, + self, cpu_instance_type, 
tensorflow_training_version, tensorflow_training_py_version ): with pytest.raises(ValueError): TensorFlow( @@ -192,10 +189,7 @@ def test_gpu_instance( compiler_config=TrainingCompilerConfig(), ).fit() - def test_framework_version( - self, - tensorflow_training_py_version, - ): + def test_framework_version(self, tensorflow_training_py_version): with pytest.raises(ValueError): TensorFlow( py_version=tensorflow_training_py_version, @@ -208,10 +202,21 @@ def test_framework_version( compiler_config=TrainingCompilerConfig(), ).fit() - def test_python_2( - self, - tensorflow_training_version, - ): + def test_mwms(self, tensorflow_training_version, tensorflow_training_py_version): + with pytest.raises(ValueError): + TensorFlow( + py_version=tensorflow_training_py_version, + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=tensorflow_training_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"multi_worker_mirrored_strategy": {"enabled": True}}, + ).fit() + + def test_python_2(self, tensorflow_training_version): with pytest.raises(ValueError): TensorFlow( py_version="py27", diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 1f2f674eb7..341f5b48ae 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -25,10 +25,7 @@ from botocore.exceptions import ClientError from mock import ANY, MagicMock, Mock, patch, PropertyMock from sagemaker.huggingface.estimator import HuggingFace -from sagemaker.jumpstart.constants import ( - JUMPSTART_BUCKET_NAME_SET, - JUMPSTART_RESOURCE_BASE_NAME, -) +from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME from sagemaker.jumpstart.enums import JumpStartTag import sagemaker.local @@ -112,11 +109,7 @@ "training_steps": "100", }, "RoleArn": "arn:aws:iam::366:role/SageMakerRole", - "ResourceConfig": { - "VolumeSizeInGB": 30, - 
"InstanceCount": 1, - "InstanceType": "ml.c4.xlarge", - }, + "ResourceConfig": {"VolumeSizeInGB": 30, "InstanceCount": 1, "InstanceType": "ml.c4.xlarge"}, "EnableNetworkIsolation": False, "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, "TrainingJobName": "neo", @@ -145,6 +138,7 @@ LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]} DISTRIBUTION_PS_ENABLED = {"parameter_server": {"enabled": True}} +DISTRIBUTION_MWMS_ENABLED = {"multi_worker_mirrored_strategy": {"enabled": True}} DISTRIBUTION_MPI_ENABLED = { "mpi": {"enabled": True, "custom_mpi_options": "options", "processes_per_host": 2} } @@ -153,10 +147,7 @@ } MOCKED_S3_URI = "s3://mocked_s3_uri_from_source_dir" MOCKED_PIPELINE_CONFIG = _PipelineConfig( - "test-pipeline", - "test-training-step", - "code-hash-0123456789", - "config-hash-0123456789", + "test-pipeline", "test-training-step", "code-hash-0123456789", "config-hash-0123456789" ) @@ -261,9 +252,7 @@ def pipeline_session(): session_mock.resource.return_value = resource_mock session_mock.client.return_value = client_mock return PipelineSession( - boto_session=session_mock, - sagemaker_client=client_mock, - default_bucket=BUCKET_NAME, + boto_session=session_mock, sagemaker_client=client_mock, default_bucket=BUCKET_NAME ) @@ -338,11 +327,7 @@ def test_framework_all_init_args(sagemaker_session): }, "metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}], "encrypt_inter_container_traffic": True, - "environment": { - "env_key1": "env_val1", - "env_key2": "env_val2", - "env_key3": "env_val3", - }, + "environment": {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}, "experiment_config": None, "checkpoint_s3_uri": "s3://bucket/checkpoint", "checkpoint_local_path": "file://local/checkpoint", @@ -463,8 +448,7 @@ def test_framework_with_debugger_and_built_in_rule(sagemaker_session): rule_parameters={"threshold": "120", "stop_training_on_fire": "True"}, collections_to_save=[ 
CollectionConfig( - name="losses", - parameters={"train.save_interval": "50", "eval.save_interval": "10"}, + name="losses", parameters={"train.save_interval": "50", "eval.save_interval": "10"} ) ], ) @@ -490,10 +474,7 @@ def test_framework_with_debugger_and_built_in_rule(sagemaker_session): "CollectionConfigurations": [ { "CollectionName": "losses", - "CollectionParameters": { - "train.save_interval": "50", - "eval.save_interval": "10", - }, + "CollectionParameters": {"train.save_interval": "50", "eval.save_interval": "10"}, } ], } @@ -505,8 +486,7 @@ def test_framework_with_debugger_and_built_in_rule(sagemaker_session): def test_framework_with_debugger_and_custom_rule(sagemaker_session): hook_config = DebuggerHookConfig( - s3_output_path="s3://output", - collection_configs=[CollectionConfig(name="weights")], + s3_output_path="s3://output", collection_configs=[CollectionConfig(name="weights")] ) debugger_custom_rule = Rule.custom( name="CustomRule", @@ -626,8 +606,7 @@ def test_framework_with_debugger_rule_and_multiple_actions(sagemaker_session): def test_framework_with_only_debugger_hook_config(sagemaker_session): hook_config = DebuggerHookConfig( - s3_output_path="s3://output", - collection_configs=[CollectionConfig(name="weights")], + s3_output_path="s3://output", collection_configs=[CollectionConfig(name="weights")] ) f = DummyFramework( entry_point=SCRIPT_PATH, @@ -676,8 +655,7 @@ def test_framework_with_debugger_and_profiler_rules(sagemaker_session): rule_parameters={"threshold": "120", "stop_training_on_fire": "True"}, collections_to_save=[ CollectionConfig( - name="losses", - parameters={"train.save_interval": "50", "eval.save_interval": "10"}, + name="losses", parameters={"train.save_interval": "50", "eval.save_interval": "10"} ) ], ) @@ -725,10 +703,7 @@ def test_framework_with_debugger_and_profiler_rules(sagemaker_session): "CollectionConfigurations": [ { "CollectionName": "losses", - "CollectionParameters": { - "train.save_interval": "50", - 
"eval.save_interval": "10", - }, + "CollectionParameters": {"train.save_interval": "50", "eval.save_interval": "10"}, } ], } @@ -740,10 +715,7 @@ def test_framework_with_debugger_and_profiler_rules(sagemaker_session): { "RuleConfigurationName": "CustomProfilerReportRule", "RuleEvaluatorImage": "895741380848.dkr.ecr.us-west-2.amazonaws.com/sagemaker-debugger-rules:latest", - "RuleParameters": { - "rule_to_invoke": "ProfilerReport", - "CPUBottleneck_threshold": "90", - }, + "RuleParameters": {"rule_to_invoke": "ProfilerReport", "CPUBottleneck_threshold": "90"}, }, { "InstanceType": "c4.4xlarge", @@ -1051,10 +1023,7 @@ def test_framework_with_enabling_default_profiling_with_existed_s3_output_path( f.enable_default_profiling() sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args - assert args["profiler_config"] == { - "DisableProfiler": False, - "S3OutputPath": "s3://custom/", - } + assert args["profiler_config"] == {"DisableProfiler": False, "S3OutputPath": "s3://custom/"} def test_framework_with_disabling_profiling_when_profiler_is_already_disabled( @@ -1190,10 +1159,7 @@ def test_framework_with_disable_framework_metrics(sagemaker_session): f.update_profiler(disable_framework_metrics=True) sagemaker_session.update_training_job.assert_called_once() _, args = sagemaker_session.update_training_job.call_args - assert args["profiler_config"] == { - "DisableProfiler": False, - "ProfilingParameters": {}, - } + assert args["profiler_config"] == {"DisableProfiler": False, "ProfilingParameters": {}} assert "profiler_rule_configs" not in args @@ -1759,10 +1725,7 @@ def test_start_new_wait_called(strftime, sagemaker_session): def test_attach_framework(sagemaker_session, training_job_description): - training_job_description["VpcConfig"] = { - "Subnets": ["foo"], - "SecurityGroupIds": ["bar"], - } + training_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} 
training_job_description["EnableNetworkIsolation"] = True framework_estimator = DummyFramework.attach( @@ -1856,8 +1819,7 @@ def test_attach_framework_with_inter_container_traffic_encryption_flag( def test_attach_framework_base_from_generated_name(sagemaker_session, training_job_description): base_job_name = "neo" framework_estimator = DummyFramework.attach( - training_job_name=utils.name_from_base("neo"), - sagemaker_session=sagemaker_session, + training_job_name=utils.name_from_base("neo"), sagemaker_session=sagemaker_session ) assert framework_estimator.base_job_name == base_job_name @@ -2052,8 +2014,7 @@ def test_git_support_bad_repo_url_format(sagemaker_session): @patch( "sagemaker.git_utils.git_clone_repo", side_effect=subprocess.CalledProcessError( - returncode=1, - cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir", + returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir" ), ) def test_git_support_git_clone_fail(git_clone_repo, sagemaker_session): @@ -2078,11 +2039,7 @@ def test_git_support_git_clone_fail(git_clone_repo, sagemaker_session): ), ) def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session): - git_config = { - "repo": GIT_REPO, - "branch": "branch-that-does-not-exist", - "commit": COMMIT, - } + git_config = {"repo": GIT_REPO, "branch": "branch-that-does-not-exist", "commit": COMMIT} fw = DummyFramework( entry_point="entry_point", git_config=git_config, @@ -2103,11 +2060,7 @@ def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session): ), ) def test_git_support_commit_not_exist(git_clone_repo, sagemaker_session): - git_config = { - "repo": GIT_REPO, - "branch": BRANCH, - "commit": "commit-sha-that-does-not-exist", - } + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": "commit-sha-that-does-not-exist"} fw = DummyFramework( entry_point="entry_point", git_config=git_config, @@ -2250,11 +2203,7 @@ def test_git_support_with_token_2fa(git_clone_repo, 
sagemaker_session): }, ) def test_git_support_ssh_no_passphrase_needed(git_clone_repo, sagemaker_session): - git_config = { - "repo": PRIVATE_GIT_REPO_SSH, - "branch": PRIVATE_BRANCH, - "commit": PRIVATE_COMMIT, - } + git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} entry_point = "entry_point" fw = DummyFramework( entry_point=entry_point, @@ -2276,11 +2225,7 @@ def test_git_support_ssh_no_passphrase_needed(git_clone_repo, sagemaker_session) ), ) def test_git_support_ssh_passphrase_required(git_clone_repo, sagemaker_session): - git_config = { - "repo": PRIVATE_GIT_REPO_SSH, - "branch": PRIVATE_BRANCH, - "commit": PRIVATE_COMMIT, - } + git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} entry_point = "entry_point" fw = DummyFramework( entry_point=entry_point, @@ -2578,9 +2523,7 @@ def test_estimator_transformer_creation_with_optional_params(create_model, sagem ) create_model.assert_called_with( - vpc_config_override=new_vpc_config, - model_kms_key=kms_key, - enable_network_isolation=True, + vpc_config_override=new_vpc_config, model_kms_key=kms_key, enable_network_isolation=True ) assert transformer.strategy == strategy @@ -2865,11 +2808,7 @@ def test_fit_deploy_tags_in_estimator(name_from_base, sagemaker_session): @patch("sagemaker.estimator.name_from_base") def test_fit_deploy_tags(name_from_base, sagemaker_session): estimator = Estimator( - IMAGE_URI, - ROLE, - INSTANCE_COUNT, - INSTANCE_TYPE, - sagemaker_session=sagemaker_session, + IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session ) estimator.fit() @@ -3318,10 +3257,7 @@ def test_generic_training_job_analytics(sagemaker_session): "TrainingInputMode": "File", "MetricDefinitions": [ {"Name": "train:loss", "Regex": "train_loss=([0-9]+\\.[0-9]+)"}, - { - "Name": "validation:loss", - "Regex": "valid_loss=([0-9]+\\.[0-9]+)", - }, + {"Name": "validation:loss", "Regex": 
"valid_loss=([0-9]+\\.[0-9]+)"}, ], }, }, @@ -3352,11 +3288,7 @@ def test_generic_create_model_vpc_config_override(sagemaker_session): vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]} e = Estimator( - IMAGE_URI, - ROLE, - INSTANCE_COUNT, - INSTANCE_TYPE, - sagemaker_session=sagemaker_session, + IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session ) e.fit({"train": "s3://bucket/training-prefix"}) assert e.get_vpc_config() is None @@ -3382,11 +3314,7 @@ def test_generic_deploy_vpc_config_override(sagemaker_session): vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]} e = Estimator( - IMAGE_URI, - ROLE, - INSTANCE_COUNT, - INSTANCE_TYPE, - sagemaker_session=sagemaker_session, + IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session ) e.fit({"train": "s3://bucket/training-prefix"}) e.deploy(INSTANCE_COUNT, INSTANCE_TYPE) @@ -3406,11 +3334,7 @@ def test_generic_deploy_vpc_config_override(sagemaker_session): def test_generic_deploy_accelerator_type(sagemaker_session): e = Estimator( - IMAGE_URI, - ROLE, - INSTANCE_COUNT, - INSTANCE_TYPE, - sagemaker_session=sagemaker_session, + IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session ) e.fit({"train": "s3://bucket/training-prefix"}) e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE) @@ -3548,12 +3472,7 @@ def test_register_default_image(sagemaker_session): sagemaker_session.create_model.assert_not_called() expected_create_model_package_request = { - "containers": [ - { - "Image": estimator.image_uri, - "ModelDataUrl": estimator.model_data, - } - ], + "containers": [{"Image": estimator.image_uri, "ModelDataUrl": estimator.model_data}], "content_types": content_types, "response_types": response_types, "inference_instances": inference_instances, @@ -3602,12 +3521,7 @@ def test_register_default_image_without_instance_type_args(sagemaker_session): 
sagemaker_session.create_model.assert_not_called() expected_create_model_package_request = { - "containers": [ - { - "Image": estimator.image_uri, - "ModelDataUrl": estimator.model_data, - } - ], + "containers": [{"Image": estimator.image_uri, "ModelDataUrl": estimator.model_data}], "content_types": content_types, "response_types": response_types, "inference_instances": None, @@ -3662,12 +3576,7 @@ def test_register_inference_image(sagemaker_session): sagemaker_session.create_model.assert_not_called() expected_create_model_package_request = { - "containers": [ - { - "Image": inference_image, - "ModelDataUrl": estimator.model_data, - } - ], + "containers": [{"Image": inference_image, "ModelDataUrl": estimator.model_data}], "content_types": content_types, "response_types": response_types, "inference_instances": inference_instances, @@ -3754,13 +3663,7 @@ def test_file_output_path_not_supported_outside_local_mode(session_class): session_class.return_value = session with pytest.raises(RuntimeError): - Estimator( - IMAGE_URI, - ROLE, - INSTANCE_COUNT, - INSTANCE_TYPE, - output_path="file:///tmp/model", - ) + Estimator(IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path="file:///tmp/model") def test_prepare_init_params_from_job_description_with_image_training_job(): @@ -3854,7 +3757,6 @@ def test_prepare_init_params_from_job_description_with_training_image_config(): def test_prepare_init_params_from_job_description_with_invalid_training_job(): - invalid_job_description = RETURNED_JOB_DESCRIPTION.copy() invalid_job_description["AlgorithmSpecification"] = {"TrainingInputMode": "File"} @@ -3891,10 +3793,7 @@ def test_prepare_for_training_with_name_based_on_image(sagemaker_session): @patch("sagemaker.algorithm.AlgorithmEstimator.validate_train_spec", Mock()) -@patch( - "sagemaker.algorithm.AlgorithmEstimator._parse_hyperparameters", - Mock(return_value={}), -) +@patch("sagemaker.algorithm.AlgorithmEstimator._parse_hyperparameters", Mock(return_value={})) def 
test_prepare_for_training_with_name_based_on_algorithm(sagemaker_session): estimator = AlgorithmEstimator( algorithm_arn="arn:aws:sagemaker:us-west-2:1234:algorithm/scikit-decision-trees-1542410022", @@ -4018,6 +3917,21 @@ def test_framework_distribution_configuration(sagemaker_session): assert actual_ddp == expected_ddp +def test_mwms_distribution_configuration(sagemaker_session): + framework = DummyFramework( + entry_point="script", + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + ) + with pytest.raises(ValueError) as error: + framework._distribution_configuration(distribution=DISTRIBUTION_MWMS_ENABLED) + + assert "only supported with" in str(error) + assert "but received" in str(error) + + def test_image_name_map(sagemaker_session): e = DummyFramework( "my_script.py", @@ -4084,7 +3998,6 @@ def test_script_mode_estimator(patched_stage_user_code, sagemaker_session): def test_script_mode_estimator_same_calls_as_framework( patched_tar_and_upload_dir, sagemaker_session ): - patched_tar_and_upload_dir.return_value = UploadedCode( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" ) @@ -4182,14 +4095,8 @@ def test_script_mode_estimator_tags_jumpstart_estimators_and_models( assert [ {"Key": "some", "Value": "tag"}, - { - "Key": JumpStartTag.TRAINING_MODEL_URI.value, - "Value": jumpstart_source_dir_2, - }, - { - "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, - "Value": jumpstart_source_dir, - }, + {"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": jumpstart_source_dir_2}, + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": jumpstart_source_dir}, ] == sagemaker_session.train.call_args_list[0][1]["tags"] sagemaker_session.reset_mock() @@ -4213,33 +4120,15 @@ def test_script_mode_estimator_tags_jumpstart_estimators_and_models( assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ {"Key": "deploys", "Value": "tag"}, - { - "Key": 
JumpStartTag.TRAINING_MODEL_URI.value, - "Value": jumpstart_source_dir_2, - }, - { - "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, - "Value": jumpstart_source_dir, - }, - { - "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, - "Value": inference_jumpstart_source_dir, - }, + {"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": jumpstart_source_dir_2}, + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": jumpstart_source_dir}, + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": inference_jumpstart_source_dir}, ] assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == [ {"Key": "deploys", "Value": "tag"}, - { - "Key": JumpStartTag.TRAINING_MODEL_URI.value, - "Value": jumpstart_source_dir_2, - }, - { - "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, - "Value": jumpstart_source_dir, - }, - { - "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, - "Value": inference_jumpstart_source_dir, - }, + {"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": jumpstart_source_dir_2}, + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": jumpstart_source_dir}, + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": inference_jumpstart_source_dir}, ] @@ -4275,10 +4164,7 @@ def test_script_mode_estimator_tags_jumpstart_models( generic_estimator.fit(training_data_uri) assert [ - { - "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, - "Value": jumpstart_source_dir, - }, + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": jumpstart_source_dir} ] == sagemaker_session.train.call_args_list[0][1]["tags"] sagemaker_session.reset_mock() @@ -4298,16 +4184,10 @@ def test_script_mode_estimator_tags_jumpstart_models( ) assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ - { - "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, - "Value": jumpstart_source_dir, - }, + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": jumpstart_source_dir} ] assert 
sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == [ - { - "Key": JumpStartTag.TRAINING_SCRIPT_URI.value, - "Value": jumpstart_source_dir, - }, + {"Key": JumpStartTag.TRAINING_SCRIPT_URI.value, "Value": jumpstart_source_dir} ] @@ -4363,16 +4243,10 @@ def test_script_mode_estimator_tags_jumpstart_models_with_no_estimator_js_tags( ) assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ - { - "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, - "Value": inference_jumpstart_source_dir, - }, + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": inference_jumpstart_source_dir} ] assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == [ - { - "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, - "Value": inference_jumpstart_source_dir, - }, + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": inference_jumpstart_source_dir} ] @@ -4381,12 +4255,8 @@ def test_script_mode_estimator_tags_jumpstart_models_with_no_estimator_js_tags( @patch("sagemaker.model.Model._upload_code") @patch("sagemaker.utils.repack_model") def test_all_framework_estimators_add_jumpstart_tags( - patched_repack_model, - patched_upload_code, - patched_tar_and_upload_dir, - sagemaker_session, + patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session ): - sagemaker_session.boto_region_name = REGION sagemaker_session.sagemaker_client.describe_training_job.return_value = { "ModelArtifacts": {"S3ModelArtifacts": "some-uri"} @@ -4413,20 +4283,13 @@ def test_all_framework_estimators_add_jumpstart_tags( "transformers_version": "4.6.1", "instance_type": "ml.p2.xlarge", }, - MXNet: { - "framework_version": "1.7.0", - "py_version": "py3", - "instance_type": "ml.p2.xlarge", - }, + MXNet: {"framework_version": "1.7.0", "py_version": "py3", "instance_type": "ml.p2.xlarge"}, SKLearn: {"framework_version": "0.23-1", "instance_type": "ml.m2.xlarge"}, XGBoost: {"framework_version": "1.3-1", 
"instance_type": "ml.m2.xlarge"}, } jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz" - for ( - framework_estimator_class, - kwargs, - ) in framework_estimator_classes_to_kwargs.items(): + for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items(): estimator = framework_estimator_class( entry_point=ENTRY_POINT, role=ROLE, @@ -4453,24 +4316,12 @@ def test_all_framework_estimators_add_jumpstart_tags( ) assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [ - { - "Key": JumpStartTag.TRAINING_MODEL_URI.value, - "Value": jumpstart_model_uri, - }, - { - "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, - "Value": jumpstart_model_uri_2, - }, + {"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": jumpstart_model_uri}, + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": jumpstart_model_uri_2}, ] assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == [ - { - "Key": JumpStartTag.TRAINING_MODEL_URI.value, - "Value": jumpstart_model_uri, - }, - { - "Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, - "Value": jumpstart_model_uri_2, - }, + {"Key": JumpStartTag.TRAINING_MODEL_URI.value, "Value": jumpstart_model_uri}, + {"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value, "Value": jumpstart_model_uri_2}, ] sagemaker_session.train.reset_mock() @@ -4542,12 +4393,8 @@ def test_script_mode_estimator_uses_jumpstart_base_name_with_js_models( @patch("sagemaker.model.Model._upload_code") @patch("sagemaker.utils.repack_model") def test_all_framework_estimators_add_jumpstart_base_name( - patched_repack_model, - patched_upload_code, - patched_tar_and_upload_dir, - sagemaker_session, + patched_repack_model, patched_upload_code, patched_tar_and_upload_dir, sagemaker_session ): - sagemaker_session.boto_region_name = REGION 
sagemaker_session.sagemaker_client.describe_training_job.return_value = { "ModelArtifacts": {"S3ModelArtifacts": "some-uri"} @@ -4574,20 +4421,13 @@ def test_all_framework_estimators_add_jumpstart_base_name( "transformers_version": "4.6.1", "instance_type": "ml.p2.xlarge", }, - MXNet: { - "framework_version": "1.7.0", - "py_version": "py3", - "instance_type": "ml.p2.xlarge", - }, + MXNet: {"framework_version": "1.7.0", "py_version": "py3", "instance_type": "ml.p2.xlarge"}, SKLearn: {"framework_version": "0.23-1", "instance_type": "ml.m2.xlarge"}, XGBoost: {"framework_version": "1.3-1", "instance_type": "ml.m2.xlarge"}, } jumpstart_model_uri = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/model_dirs/model.tar.gz" jumpstart_model_uri_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/model_dirs/model.tar.gz" - for ( - framework_estimator_class, - kwargs, - ) in framework_estimator_classes_to_kwargs.items(): + for framework_estimator_class, kwargs in framework_estimator_classes_to_kwargs.items(): estimator = framework_estimator_class( entry_point=ENTRY_POINT, role=ROLE,