diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 8ed9b724a5..29e0d250aa 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -155,6 +155,8 @@ def __init__( entry_point: Optional[Union[str, PipelineVariable]] = None, dependencies: Optional[List[Union[str]]] = None, instance_groups: Optional[List[InstanceGroup]] = None, + training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None, + training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None, **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -489,6 +491,18 @@ def __init__( `Train Using a Heterogeneous Cluster `_ in the *Amazon SageMaker developer guide*. + training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the + Docker image that contains the training algorithm (default: None). + Set this to one of the following values: + * 'Platform' - The training image is hosted in Amazon ECR. + * 'Vpc' - The training image is hosted in a private Docker registry in your VPC. + When it's default to None, its behavior will be same as 'Platform' - image is hosted + in ECR. + training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name + (ARN) of an AWS Lambda function that provides credentials to authenticate to the + private Docker registry where your training image is hosted (default: None). + When it's set to None, SageMaker will not do authentication before pulling the image + in the private Docker registry. """ instance_count = renamed_kwargs( "train_instance_count", "instance_count", instance_count, kwargs @@ -536,7 +550,9 @@ def __init__( self.dependencies = dependencies or [] self.uploaded_code = None self.tags = add_jumpstart_tags( - tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir + tags=tags, + training_model_uri=self.model_uri, + training_script_uri=self.source_dir, ) if self.instance_type in ("local", "local_gpu"): if self.instance_type == "local_gpu" and self.instance_count > 1: @@ -571,6 +587,12 @@ def __init__( self.subnets = subnets self.security_group_ids = security_group_ids + # training image configs + self.training_repository_access_mode = training_repository_access_mode + self.training_repository_credentials_provider_arn = ( + training_repository_credentials_provider_arn + ) + self.encrypt_inter_container_traffic = encrypt_inter_container_traffic self.use_spot_instances = use_spot_instances self.max_wait = max_wait @@ -651,7 +673,8 @@ def _ensure_base_job_name(self): self.base_job_name or get_jumpstart_base_name_if_jumpstart_model(self.source_dir, self.model_uri) or base_name_from_image( - self.training_image_uri(), default_base_name=EstimatorBase.JOB_CLASS_NAME + self.training_image_uri(), + default_base_name=EstimatorBase.JOB_CLASS_NAME, ) ) @@ -1405,7 +1428,10 @@ def deploy( self._ensure_base_job_name() jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model( - kwargs.get("source_dir"), self.source_dir, kwargs.get("model_data"), self.model_uri + kwargs.get("source_dir"), + self.source_dir, + kwargs.get("model_data"), + self.model_uri, ) default_name = ( name_from_base(jumpstart_base_name) @@ -1638,6 +1664,15 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na init_params["algorithm_arn"] = job_details["AlgorithmSpecification"]["AlgorithmName"] elif "TrainingImage" in job_details["AlgorithmSpecification"]: init_params["image_uri"] = job_details["AlgorithmSpecification"]["TrainingImage"] + if "TrainingImageConfig" in job_details["AlgorithmSpecification"]: + init_params["training_repository_access_mode"] = job_details[ + "AlgorithmSpecification" + ]["TrainingImageConfig"].get("TrainingRepositoryAccessMode") + init_params["training_repository_credentials_provider_arn"] = ( + job_details["AlgorithmSpecification"]["TrainingImageConfig"] + .get("TrainingRepositoryAuthConfig", {}) + .get("TrainingRepositoryCredentialsProviderArn") + ) else: raise RuntimeError( "Invalid AlgorithmSpecification. Either TrainingImage or " @@ -2118,6 +2153,17 @@ def _get_train_args(cls, estimator, inputs, experiment_config): else: train_args["retry_strategy"] = None + if estimator.training_repository_access_mode is not None: + training_image_config = { + "TrainingRepositoryAccessMode": estimator.training_repository_access_mode + } + if estimator.training_repository_credentials_provider_arn is not None: + training_image_config["TrainingRepositoryAuthConfig"] = {} + training_image_config["TrainingRepositoryAuthConfig"][ + "TrainingRepositoryCredentialsProviderArn" + ] = estimator.training_repository_credentials_provider_arn + train_args["training_image_config"] = training_image_config + # encrypt_inter_container_traffic may be a pipeline variable place holder object # which is parsed in execution time if estimator.encrypt_inter_container_traffic: @@ -2182,7 +2228,11 @@ def _is_local_channel(cls, input_uri): @classmethod def update( - cls, estimator, profiler_rule_configs=None, profiler_config=None, resource_config=None + cls, + estimator, + profiler_rule_configs=None, + profiler_config=None, + resource_config=None, ): """Update a running Amazon SageMaker training job. @@ -2321,6 +2371,8 @@ def __init__( entry_point: Optional[Union[str, PipelineVariable]] = None, dependencies: Optional[List[str]] = None, instance_groups: Optional[List[InstanceGroup]] = None, + training_repository_access_mode: Optional[Union[str, PipelineVariable]] = None, + training_repository_credentials_provider_arn: Optional[Union[str, PipelineVariable]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. @@ -2654,6 +2706,18 @@ def __init__( `Train Using a Heterogeneous Cluster `_ in the *Amazon SageMaker developer guide*. + training_repository_access_mode (str): Optional. Specifies how SageMaker accesses the + Docker image that contains the training algorithm (default: None). + Set this to one of the following values: + * 'Platform' - The training image is hosted in Amazon ECR. + * 'Vpc' - The training image is hosted in a private Docker registry in your VPC. + When it's default to None, its behavior will be same as 'Platform' - image is hosted + in ECR. + training_repository_credentials_provider_arn (str): Optional. The Amazon Resource Name + (ARN) of an AWS Lambda function that provides credentials to authenticate to the + private Docker registry where your training image is hosted (default: None). + When it's set to None, SageMaker will not do authentication before pulling the image + in the private Docker registry. """ self.image_uri = image_uri self._hyperparameters = hyperparameters.copy() if hyperparameters else {} @@ -2698,6 +2762,8 @@ def __init__( dependencies=dependencies, hyperparameters=hyperparameters, instance_groups=instance_groups, + training_repository_access_mode=training_repository_access_mode, + training_repository_credentials_provider_arn=training_repository_credentials_provider_arn, # noqa: E501 # pylint: disable=line-too-long **kwargs, ) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 0df2996352..f7a03202f5 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -488,6 +488,7 @@ def train( # noqa: C901 metric_definitions, enable_network_isolation=False, image_uri=None, + training_image_config=None, algorithm_arn=None, encrypt_inter_container_traffic=False, use_spot_instances=False, @@ -548,6 +549,28 @@ def train( # noqa: C901 enable_network_isolation (bool): Whether to request for the training job to run with network isolation or not. image_uri (str): Docker image containing training code. + training_image_config(dict): Training image configuration. + Optionally, the dict can contain 'TrainingRepositoryAccessMode' and + 'TrainingRepositoryCredentialsProviderArn' (under 'TrainingRepositoryAuthConfig'). + For example, + + .. code:: python + + training_image_config = { + "TrainingRepositoryAccessMode": "Vpc", + "TrainingRepositoryAuthConfig": { + "TrainingRepositoryCredentialsProviderArn": + "arn:aws:lambda:us-west-2:1234567890:function:test" + }, + } + + If TrainingRepositoryAccessMode is set to Vpc, the training image is accessed + through a private Docker registry in customer Vpc. If it's set to Platform or None, + the training image is accessed through ECR. + If TrainingRepositoryCredentialsProviderArn is provided, the credentials to + authenticate to the private Docker registry will be retrieved from this AWS Lambda + function. (default: ``None``). When it's set to None, SageMaker will not do + authentication before pulling the image in the private Docker registry. algorithm_arn (str): Algorithm Arn from Marketplace. encrypt_inter_container_traffic (bool): Specifies whether traffic between training containers is encrypted for the training job (default: ``False``). @@ -606,6 +629,7 @@ def train( # noqa: C901 metric_definitions=metric_definitions, enable_network_isolation=enable_network_isolation, image_uri=image_uri, + training_image_config=training_image_config, algorithm_arn=algorithm_arn, encrypt_inter_container_traffic=encrypt_inter_container_traffic, use_spot_instances=use_spot_instances, @@ -644,6 +668,7 @@ def _get_train_request( # noqa: C901 metric_definitions, enable_network_isolation=False, image_uri=None, + training_image_config=None, algorithm_arn=None, encrypt_inter_container_traffic=False, use_spot_instances=False, @@ -704,6 +729,28 @@ def _get_train_request( # noqa: C901 enable_network_isolation (bool): Whether to request for the training job to run with network isolation or not. image_uri (str): Docker image containing training code. + training_image_config(dict): Training image configuration. + Optionally, the dict can contain 'TrainingRepositoryAccessMode' and + 'TrainingRepositoryCredentialsProviderArn' (under 'TrainingRepositoryAuthConfig'). + For example, + + .. code:: python + + training_image_config = { + "TrainingRepositoryAccessMode": "Vpc", + "TrainingRepositoryAuthConfig": { + "TrainingRepositoryCredentialsProviderArn": + "arn:aws:lambda:us-west-2:1234567890:function:test" + }, + } + + If TrainingRepositoryAccessMode is set to Vpc, the training image is accessed + through a private Docker registry in customer Vpc. If it's set to Platform or None, + the training image is accessed through ECR. + If TrainingRepositoryCredentialsProviderArn is provided, the credentials to + authenticate to the private Docker registry will be retrieved from this AWS Lambda + function. (default: ``None``). When it's set to None, SageMaker will not do + authentication before pulling the image in the private Docker registry. algorithm_arn (str): Algorithm Arn from Marketplace. encrypt_inter_container_traffic (bool): Specifies whether traffic between training containers is encrypted for the training job (default: ``False``). @@ -768,6 +815,9 @@ def _get_train_request( # noqa: C901 if image_uri is not None: train_request["AlgorithmSpecification"]["TrainingImage"] = image_uri + if training_image_config is not None: + train_request["AlgorithmSpecification"]["TrainingImageConfig"] = training_image_config + if algorithm_arn is not None: train_request["AlgorithmSpecification"]["AlgorithmName"] = algorithm_arn diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 45a944ce1a..687b24dff7 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -93,6 +93,8 @@ CODECOMMIT_BRANCH = "master" REPO_DIR = "/tmp/repo_dir" ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"} +TRAINING_REPOSITORY_ACCESS_MODE = "VPC" +TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN = "arn:aws:lambda:us-west-2:1234567890:function:test" DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}} @@ -391,6 +393,70 @@ def test_framework_with_keep_alive_period(sagemaker_session): assert args["resource_config"]["KeepAlivePeriodInSeconds"] == KEEP_ALIVE_PERIOD_IN_SECONDS +def test_framework_with_both_training_repository_config(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE, + training_repository_credentials_provider_arn=TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN, + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert ( + args["training_image_config"]["TrainingRepositoryAccessMode"] + == TRAINING_REPOSITORY_ACCESS_MODE + ) + assert ( + args["training_image_config"]["TrainingRepositoryAuthConfig"][ + "TrainingRepositoryCredentialsProviderArn" + ] + == TRAINING_REPOSITORY_CREDENTIALS_PROVIDER_ARN + ) + + +def test_framework_with_training_repository_access_mode(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + training_repository_access_mode=TRAINING_REPOSITORY_ACCESS_MODE, + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert ( + args["training_image_config"]["TrainingRepositoryAccessMode"] + == TRAINING_REPOSITORY_ACCESS_MODE + ) + assert "TrainingRepositoryAuthConfig" not in args["training_image_config"] + + +def test_framework_without_training_repository_config(sagemaker_session): + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + ) + f.fit("s3://mydata") + sagemaker_session.train.assert_called_once() + _, args = sagemaker_session.train.call_args + assert args.get("training_image_config") is None + + def test_framework_with_debugger_and_built_in_rule(sagemaker_session): debugger_built_in_rule_with_custom_args = Rule.sagemaker( base_config=rule_configs.stalled_training_rule(), @@ -3763,6 +3829,28 @@ def test_prepare_init_params_from_job_description_with_retry_strategy(): assert init_params["max_retry_attempts"] == 2 +def test_prepare_init_params_from_job_description_with_training_image_config(): + job_description = RETURNED_JOB_DESCRIPTION.copy() + job_description["AlgorithmSpecification"]["TrainingImageConfig"] = { + "TrainingRepositoryAccessMode": "Vpc", + "TrainingRepositoryAuthConfig": { + "TrainingRepositoryCredentialsProviderArn": "arn:aws:lambda:us-west-2:1234567890:function:test" + }, + } + + init_params = EstimatorBase._prepare_init_params_from_job_description( + job_details=job_description + ) + + assert init_params["role"] == "arn:aws:iam::366:role/SageMakerRole" + assert init_params["instance_count"] == 1 + assert init_params["training_repository_access_mode"] == "Vpc" + assert ( + init_params["training_repository_credentials_provider_arn"] + == "arn:aws:lambda:us-west-2:1234567890:function:test" + ) + + def test_prepare_init_params_from_job_description_with_invalid_training_job(): invalid_job_description = RETURNED_JOB_DESCRIPTION.copy() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 4f951dfcfe..54a024bdc4 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -296,7 +296,10 @@ def test_get_execution_role_throws_exception_if_arn_is_not_role_with_role_in_nam assert "The current AWS identity is not a role" in str(error.value) -@patch("six.moves.builtins.open", mock_open(read_data='{"ResourceName": "SageMakerInstance"}')) +@patch( + "six.moves.builtins.open", + mock_open(read_data='{"ResourceName": "SageMakerInstance"}'), +) @patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True)) def test_get_caller_identity_arn_from_describe_notebook_instance(boto_session): sess = Session(boto_session) @@ -419,7 +422,10 @@ def test_get_caller_identity_arn_from_describe_domain_for_space(boto_session): sess.sagemaker_client.describe_domain.assert_called_once_with(DomainId="d-kbnw5yk6tg8j") -@patch("six.moves.builtins.open", mock_open(read_data='{"ResourceName": "SageMakerInstance"}')) +@patch( + "six.moves.builtins.open", + mock_open(read_data='{"ResourceName": "SageMakerInstance"}'), +) @patch("os.path.exists", side_effect=mock_exists(NOTEBOOK_METADATA_FILE, True)) @patch("sagemaker.session.sts_regional_endpoint", return_value=STS_ENDPOINT) def test_get_caller_identity_arn_from_a_role_after_describe_notebook_exception( @@ -427,7 +433,8 @@ def test_get_caller_identity_arn_from_a_role_after_describe_notebook_exception( ): sess = Session(boto_session) exception = ClientError( - {"Error": {"Code": "ValidationException", "Message": "RecordNotFound"}}, "Operation" + {"Error": {"Code": "ValidationException", "Message": "RecordNotFound"}}, + "Operation", ) sess.sagemaker_client.describe_notebook_instance.side_effect = exception @@ -771,10 +778,17 @@ def test_training_input_all_arguments(): "TransformJobStatus": "Completed", "ModelName": "some-model", "TransformJobName": JOB_NAME, - "TransformResources": {"InstanceCount": INSTANCE_COUNT, "InstanceType": INSTANCE_TYPE}, + "TransformResources": { + "InstanceCount": INSTANCE_COUNT, + "InstanceType": INSTANCE_TYPE, + }, "TransformEndTime": datetime.datetime(2018, 2, 17, 7, 19, 34, 953000), "TransformStartTime": datetime.datetime(2018, 2, 17, 7, 15, 0, 103000), - "TransformOutput": {"AssembleWith": "None", "KmsKeyId": "", "S3OutputPath": S3_OUTPUT}, + "TransformOutput": { + "AssembleWith": "None", + "KmsKeyId": "", + "S3OutputPath": S3_OUTPUT, + }, "TransformInput": { "CompressionType": "None", "ContentType": "text/csv", @@ -894,7 +908,10 @@ def test_train_pack_to_request(sagemaker_session): "HyperParameterTuningJobConfig": { "Strategy": "Bayesian", "HyperParameterTuningJobObjective": SAMPLE_OBJECTIVE, - "ResourceLimits": {"MaxNumberOfTrainingJobs": 100, "MaxParallelTrainingJobs": 5}, + "ResourceLimits": { + "MaxNumberOfTrainingJobs": 100, + "MaxParallelTrainingJobs": 5, + }, "ParameterRanges": SAMPLE_PARAM_RANGES, "TrainingJobEarlyStoppingType": "Off", "RandomSeed": 0, @@ -918,7 +935,10 @@ def test_train_pack_to_request(sagemaker_session): "HyperParameterTuningJobName": "dummy-tuning-1", "HyperParameterTuningJobConfig": { "Strategy": "Bayesian", - "ResourceLimits": {"MaxNumberOfTrainingJobs": 100, "MaxParallelTrainingJobs": 5}, + "ResourceLimits": { + "MaxNumberOfTrainingJobs": 100, + "MaxParallelTrainingJobs": 5, + }, "TrainingJobEarlyStoppingType": "Off", }, "TrainingJobDefinitions": [ @@ -967,7 +987,10 @@ def test_train_pack_to_request(sagemaker_session): @pytest.mark.parametrize( "warm_start_type, parents", - [("IdenticalDataAndAlgorithm", {"p1", "p2", "p3"}), ("TransferLearning", {"p1", "p2", "p3"})], + [ + ("IdenticalDataAndAlgorithm", {"p1", "p2", "p3"}), + ("TransferLearning", {"p1", "p2", "p3"}), + ], ) def test_tune_warm_start(sagemaker_session, warm_start_type, parents): def assert_create_tuning_job_request(**kwrags): @@ -1014,7 +1037,8 @@ def assert_create_tuning_job_request(**kwrags): def test_create_tuning_job_without_training_config_or_list(sagemaker_session): with pytest.raises( - ValueError, match="Either training_config or training_config_list should be provided." + ValueError, + match="Either training_config or training_config_list should be provided.", ): sagemaker_session.create_tuning_job( job_name="dummy-tuning-1", @@ -1031,7 +1055,8 @@ def test_create_tuning_job_without_training_config_or_list(sagemaker_session): def test_create_tuning_job_with_both_training_config_and_list(sagemaker_session): with pytest.raises( - ValueError, match="Only one of training_config and training_config_list should be provided." + ValueError, + match="Only one of training_config and training_config_list should be provided.", ): sagemaker_session.create_tuning_job( job_name="dummy-tuning-1", @@ -1043,7 +1068,10 @@ def test_create_tuning_job_with_both_training_config_and_list(sagemaker_session) "max_parallel_jobs": 5, "parameter_ranges": SAMPLE_PARAM_RANGES, }, - training_config={"static_hyperparameters": STATIC_HPs, "image_uri": "dummy-image-1"}, + training_config={ + "static_hyperparameters": STATIC_HPs, + "image_uri": "dummy-image-1", + }, training_config_list=[ { "static_hyperparameters": STATIC_HPs, @@ -1389,6 +1417,12 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): stop_cond = {"MaxRuntimeInSeconds": MAX_TIME} RETRY_STRATEGY = {"MaximumRetryAttempts": 2} hyperparameters = {"foo": "bar"} + TRAINING_IMAGE_CONFIG = { + "TrainingRepositoryAccessMode": "Vpc", + "TrainingRepositoryAuthConfig": { + "TrainingRepositoryCredentialsProviderArn": "arn:aws:lambda:us-west-2:1234567897:function:test" + }, + } sagemaker_session.train( image_uri=IMAGE, @@ -1410,6 +1444,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): enable_sagemaker_metrics=True, environment=ENV_INPUT, retry_strategy=RETRY_STRATEGY, + training_image_config=TRAINING_IMAGE_CONFIG, ) _, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0] @@ -1425,6 +1460,9 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session): assert actual_train_args["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints" assert actual_train_args["Environment"] == ENV_INPUT assert actual_train_args["RetryStrategy"] == RETRY_STRATEGY + assert ( + actual_train_args["AlgorithmSpecification"]["TrainingImageConfig"] == TRAINING_IMAGE_CONFIG + ) def test_transform_pack_to_request(sagemaker_session): @@ -1765,7 +1803,9 @@ def test_create_model(expand_container_def, sagemaker_session): assert model == MODEL_NAME sagemaker_session.sagemaker_client.create_model.assert_called_with( - ExecutionRoleArn=EXPANDED_ROLE, ModelName=MODEL_NAME, PrimaryContainer=PRIMARY_CONTAINER + ExecutionRoleArn=EXPANDED_ROLE, + ModelName=MODEL_NAME, + PrimaryContainer=PRIMARY_CONTAINER, ) @@ -1790,7 +1830,9 @@ def test_create_model_with_primary_container(expand_container_def, sagemaker_ses assert model == MODEL_NAME sagemaker_session.sagemaker_client.create_model.assert_called_with( - ExecutionRoleArn=EXPANDED_ROLE, ModelName=MODEL_NAME, PrimaryContainer=PRIMARY_CONTAINER + ExecutionRoleArn=EXPANDED_ROLE, + ModelName=MODEL_NAME, + PrimaryContainer=PRIMARY_CONTAINER, ) @@ -1798,7 +1840,10 @@ def test_create_model_with_primary_container(expand_container_def, sagemaker_ses def test_create_model_with_both(expand_container_def, sagemaker_session): with pytest.raises(ValueError): sagemaker_session.create_model( - MODEL_NAME, ROLE, container_defs=PRIMARY_CONTAINER, primary_container=PRIMARY_CONTAINER + MODEL_NAME, + ROLE, + container_defs=PRIMARY_CONTAINER, + primary_container=PRIMARY_CONTAINER, ) @@ -1857,7 +1902,10 @@ def test_create_pipeline_model_vpc_config(expand_container_def, sagemaker_sessio @patch("sagemaker.session._expand_container_def", return_value=PRIMARY_CONTAINER) def test_create_model_already_exists(expand_container_def, sagemaker_session, caplog): error_response = { - "Error": {"Code": "ValidationException", "Message": "Cannot create already existing model"} + "Error": { + "Code": "ValidationException", + "Message": "Cannot create already existing model", + } } exception = ClientError(error_response, "Operation") sagemaker_session.sagemaker_client.create_model.side_effect = exception @@ -1961,7 +2009,13 @@ def test_endpoint_from_production_variants(sagemaker_session): sagemaker.production_variant("B", "p299.4096xlarge"), ] ex = ClientError( - {"Error": {"Code": "ValidationException", "Message": "Could not find your thing"}}, "b" + { + "Error": { + "Code": "ValidationException", + "Message": "Could not find your thing", + } + }, + "b", ) ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex) sagemaker_session.endpoint_from_production_variants("some-endpoint", pvs) @@ -1991,7 +2045,13 @@ def test_endpoint_from_production_variants_with_tags(sagemaker_session): sagemaker.production_variant("B", "p299.4096xlarge"), ] ex = ClientError( - {"Error": {"Code": "ValidationException", "Message": "Could not find your thing"}}, "b" + { + "Error": { + "Code": "ValidationException", + "Message": "Could not find your thing", + } + }, + "b", ) ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex) tags = [{"ModelName": "TestModel"}] @@ -2012,7 +2072,13 @@ def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_sessi sagemaker.production_variant("B", "p299.4096xlarge", accelerator_type=ACCELERATOR_TYPE), ] ex = ClientError( - {"Error": {"Code": "ValidationException", "Message": "Could not find your thing"}}, "b" + { + "Error": { + "Code": "ValidationException", + "Message": "Could not find your thing", + } + }, + "b", ) ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex) tags = [{"ModelName": "TestModel"}] @@ -2025,7 +2091,9 @@ def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_sessi ) -def test_endpoint_from_production_variants_with_serverless_inference_config(sagemaker_session): +def test_endpoint_from_production_variants_with_serverless_inference_config( + sagemaker_session, +): ims = sagemaker_session ims.sagemaker_client.describe_endpoint = Mock(return_value={"EndpointStatus": "InService"}) pvs = [ @@ -2033,11 +2101,19 @@ def test_endpoint_from_production_variants_with_serverless_inference_config(sage "A", "ml.p2.xlarge", serverless_inference_config=SERVERLESS_INFERENCE_CONFIG ), sagemaker.production_variant( - "B", "p299.4096xlarge", serverless_inference_config=SERVERLESS_INFERENCE_CONFIG + "B", + "p299.4096xlarge", + serverless_inference_config=SERVERLESS_INFERENCE_CONFIG, ), ] ex = ClientError( - {"Error": {"Code": "ValidationException", "Message": "Could not find your thing"}}, "b" + { + "Error": { + "Code": "ValidationException", + "Message": "Could not find your thing", + } + }, + "b", ) ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex) tags = [{"ModelName": "TestModel"}] @@ -2058,7 +2134,13 @@ def test_endpoint_from_production_variants_with_async_config(sagemaker_session): sagemaker.production_variant("B", "p299.4096xlarge"), ] ex = ClientError( - {"Error": {"Code": "ValidationException", "Message": "Could not find your thing"}}, "b" + { + "Error": { + "Code": "ValidationException", + "Message": "Could not find your thing", + } + }, + "b", ) ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex) sagemaker_session.endpoint_from_production_variants( @@ -2100,7 +2182,8 @@ def test_update_endpoint_no_wait(sagemaker_session): def test_update_endpoint_non_existing_endpoint(sagemaker_session): error = ClientError( - {"Error": {"Code": "ValidationException", "Message": "Could not find entity"}}, "foo" + {"Error": {"Code": "ValidationException", "Message": "Could not find entity"}}, + "foo", ) expected_error_message = ( "Endpoint with name 'non-existing-endpoint' does not exist; " @@ -2145,7 +2228,8 @@ def test_create_endpoint_config_from_existing(sagemaker_session): def test_wait_for_tuning_job(sleep, sagemaker_session): hyperparameter_tuning_job_desc = {"HyperParameterTuningJobStatus": "Completed"} sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( - name="describe_hyper_parameter_tuning_job", return_value=hyperparameter_tuning_job_desc + name="describe_hyper_parameter_tuning_job", + return_value=hyperparameter_tuning_job_desc, ) result = sagemaker_session.wait_for_tuning_job(JOB_NAME) @@ -2155,7 +2239,8 @@ def test_wait_for_tuning_job(sleep, sagemaker_session): def test_tune_job_status(sagemaker_session): hyperparameter_tuning_job_desc = {"HyperParameterTuningJobStatus": "Completed"} sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( - name="describe_hyper_parameter_tuning_job", return_value=hyperparameter_tuning_job_desc + name="describe_hyper_parameter_tuning_job", + return_value=hyperparameter_tuning_job_desc, ) result = _tuning_job_status(sagemaker_session.sagemaker_client, JOB_NAME) @@ -2166,7 +2251,8 @@ def test_tune_job_status(sagemaker_session): def test_tune_job_status_none(sagemaker_session): hyperparameter_tuning_job_desc = {"HyperParameterTuningJobStatus": "InProgress"} sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( - name="describe_hyper_parameter_tuning_job", return_value=hyperparameter_tuning_job_desc + name="describe_hyper_parameter_tuning_job", + return_value=hyperparameter_tuning_job_desc, ) result = _tuning_job_status(sagemaker_session.sagemaker_client, JOB_NAME) @@ -2307,7 +2393,10 @@ def test_train_done_in_progress(sagemaker_session): "SecurityConfig": { "VolumeKmsKeyId": "volume-kms-key-id-string", "EnableInterContainerTrafficEncryption": False, - "VpcConfig": {"SecurityGroupIds": ["security-group-id"], "Subnets": ["subnet"]}, + "VpcConfig": { + "SecurityGroupIds": ["security-group-id"], + "Subnets": ["subnet"], + }, }, }, "RoleArn": EXPANDED_ROLE, @@ -2393,7 +2482,10 @@ def test_auto_ml_pack_to_request_with_optional_args(sagemaker_session): "SecurityConfig": { "VolumeKmsKeyId": "volume-kms-key-id-string", "EnableInterContainerTrafficEncryption": False, - "VpcConfig": {"SecurityGroupIds": ["security-group-id"], "Subnets": ["subnet"]}, + "VpcConfig": { + "SecurityGroupIds": ["security-group-id"], + "Subnets": ["subnet"], + }, }, } @@ -2499,7 +2591,9 @@ def test_create_model_package_from_containers_incomplete_args(sagemaker_session) ) -def test_create_model_package_from_containers_without_model_package_group_name(sagemaker_session): +def test_create_model_package_from_containers_without_model_package_group_name( + sagemaker_session, +): model_package_name = "sagemaker-model-package" containers = ["dummy-container"] content_types = ["application/json"] @@ -2654,7 +2748,9 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s sagemaker_session.sagemaker_client.create_model_package.assert_called_with(**expected_args) -def test_create_model_package_from_containers_with_one_instance_types(sagemaker_session): +def test_create_model_package_from_containers_with_one_instance_types( + sagemaker_session, +): model_package_group_name = "sagemaker-model-package-group-name-1.0" containers = ["dummy-container"] content_types = ["application/json"]