diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 7b16e3cba3..36cb920dde 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -24,6 +24,7 @@ from textwrap import dedent from typing import Dict, List, Optional, Union from copy import copy +import re import attr @@ -1658,6 +1659,7 @@ def run( # type: ignore[override] job_name: Optional[str] = None, experiment_config: Optional[Dict[str, str]] = None, kms_key: Optional[str] = None, + codeartifact_repo_arn: Optional[str] = None, ): """Runs a processing job. @@ -1758,12 +1760,21 @@ def run( # type: ignore[override] However, the value of `TrialComponentDisplayName` is honored for display in Studio. kms_key (str): The ARN of the KMS key that is used to encrypt the user code file (default: None). + codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be + logged into before installing dependencies (default: None). Returns: None or pipeline step arguments in case the Processor instance is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession` """ s3_runproc_sh, inputs, job_name = self._pack_and_upload_code( - code, source_dir, dependencies, git_config, job_name, inputs, kms_key + code, + source_dir, + dependencies, + git_config, + job_name, + inputs, + kms_key, + codeartifact_repo_arn, ) # Submit a processing job. @@ -1780,7 +1791,15 @@ def run( # type: ignore[override] ) def _pack_and_upload_code( - self, code, source_dir, dependencies, git_config, job_name, inputs, kms_key=None + self, + code, + source_dir, + dependencies, + git_config, + job_name, + inputs, + kms_key=None, + codeartifact_repo_arn=None, ): """Pack local code bundle and upload to Amazon S3.""" if code.startswith("s3://"): @@ -1821,12 +1840,53 @@ def _pack_and_upload_code( script = estimator.uploaded_code.script_name evaluated_kms_key = kms_key if kms_key else self.output_kms_key s3_runproc_sh = self._create_and_upload_runproc( - script, evaluated_kms_key, entrypoint_s3_uri + script, evaluated_kms_key, entrypoint_s3_uri, codeartifact_repo_arn ) return s3_runproc_sh, inputs, job_name - def _generate_framework_script(self, user_script: str) -> str: + def _get_codeartifact_command(self, codeartifact_repo_arn: str) -> str: + """Build an AWS CLI CodeArtifact command to configure pip. + + The codeartifact_repo_arn property must follow the form + # `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}` + https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html + https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies + + Args: + codeartifact_repo_arn: arn of the codeartifact repository + Returns: + codeartifact command string + """ + + arn_regex = ( + "arn:(?P[^:]+):codeartifact:(?P[^:]+):(?P[^:]+)" + ":repository/(?P[^/]+)/(?P.+)" + ) + m = re.match(arn_regex, codeartifact_repo_arn) + if not m: + raise ValueError("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn)) + domain = m.group("domain") + owner = m.group("account") + repository = m.group("repository") + region = m.group("region") + + logger.info( + "configuring pip to use codeartifact " + "(domain: %s, domain owner: %s, repository: %s, region: %s)", + domain, + owner, + repository, + region, + ) + + return "aws codeartifact login --tool pip --domain {} --domain-owner {} --repository {} --region {}".format( # noqa: E501 pylint: disable=line-too-long + domain, owner, repository, region + ) + + def _generate_framework_script( + self, user_script: str, codeartifact_repo_arn: str = None + ) -> str: """Generate the framework entrypoint file (as text) for a processing job. This script implements the "framework" functionality for setting up your code: @@ -1837,7 +1897,16 @@ def _generate_framework_script(self, user_script: str) -> str: Args: user_script (str): Relative path to ```code``` in the source bundle - e.g. 'process.py'. + codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be + logged into before installing dependencies (default: None). """ + if codeartifact_repo_arn: + codeartifact_login_command = self._get_codeartifact_command(codeartifact_repo_arn) + else: + codeartifact_login_command = ( + "echo 'CodeArtifact repository not specified. Skipping login.'" + ) + return dedent( """\ #!/bin/bash @@ -1849,6 +1918,13 @@ def _generate_framework_script(self, user_script: str) -> str: set -e if [[ -f 'requirements.txt' ]]; then + # Optionally log into CodeArtifact + if ! hash aws 2>/dev/null; then + echo "AWS CLI is not installed. Skipping CodeArtifact login." + else + {codeartifact_login_command} + fi + # Some py3 containers has typing, which may breaks pip install pip uninstall --yes typing @@ -1858,6 +1934,7 @@ def _generate_framework_script(self, user_script: str) -> str: {entry_point_command} {entry_point} "$@" """ ).format( + codeartifact_login_command=codeartifact_login_command, entry_point_command=" ".join(self.command), entry_point=user_script, ) @@ -1933,7 +2010,9 @@ def _set_entrypoint(self, command, user_script_name): ) self.entrypoint = self.framework_entrypoint_command + [user_script_location] - def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): + def _create_and_upload_runproc( + self, user_script, kms_key, entrypoint_s3_uri, codeartifact_repo_arn=None + ): """Create runproc shell script and upload to S3 bucket. If leveraging a pipeline session with optimized S3 artifact paths, @@ -1949,7 +2028,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): from sagemaker.workflow.utilities import _pipeline_config, hash_object if _pipeline_config and _pipeline_config.pipeline_name: - runproc_file_str = self._generate_framework_script(user_script) + runproc_file_str = self._generate_framework_script(user_script, codeartifact_repo_arn) runproc_file_hash = hash_object(runproc_file_str) s3_uri = s3.s3_path_join( "s3://", @@ -1968,7 +2047,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri): ) else: s3_runproc_sh = S3Uploader.upload_string_as_file_body( - self._generate_framework_script(user_script), + self._generate_framework_script(user_script, codeartifact_repo_arn), desired_s3_uri=entrypoint_s3_uri, kms_key=kms_key, sagemaker_session=self.sagemaker_session, diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 93e3d91f87..06d2cde02e 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import copy +from textwrap import dedent import pytest from mock import Mock, patch, MagicMock @@ -1102,6 +1103,137 @@ def test_pyspark_processor_configuration_path_pipeline_config( ) +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +def test_get_codeartifact_command(pipeline_session): + codeartifact_repo_arn = ( + "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" + ) + + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + codeartifact_command = processor._get_codeartifact_command( + codeartifact_repo_arn=codeartifact_repo_arn + ) + + assert ( + codeartifact_command + == "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" # noqa: E501 # pylint: disable=line-too-long + ) + + +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +def test_get_codeartifact_command_bad_repo_arn(pipeline_session): + codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain" + + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + with pytest.raises(ValueError): + processor._get_codeartifact_command(codeartifact_repo_arn=codeartifact_repo_arn) + + +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +def test_generate_framework_script(pipeline_session): + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + framework_script = processor._generate_framework_script(user_script="process.py") + + assert framework_script == dedent( + """\ + #!/bin/bash + + cd /opt/ml/processing/input/code/ + tar -xzf sourcedir.tar.gz + + # Exit on any error. SageMaker uses error code to mark failed job. + set -e + + if [[ -f 'requirements.txt' ]]; then + # Optionally log into CodeArtifact + if ! hash aws 2>/dev/null; then + echo "AWS CLI is not installed. Skipping CodeArtifact login." + else + echo 'CodeArtifact repository not specified. Skipping login.' + fi + + # Some py3 containers has typing, which may breaks pip install + pip uninstall --yes typing + + pip install -r requirements.txt + fi + + python process.py "$@" + """ + ) + + +@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) +def test_generate_framework_script_with_codeartifact(pipeline_session): + processor = PyTorchProcessor( + role=ROLE, + instance_type="ml.m4.xlarge", + framework_version="2.0.1", + py_version="py310", + instance_count=1, + sagemaker_session=pipeline_session, + ) + + framework_script = processor._generate_framework_script( + user_script="process.py", + codeartifact_repo_arn=( + "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository" + ), + ) + + assert framework_script == dedent( + """\ + #!/bin/bash + + cd /opt/ml/processing/input/code/ + tar -xzf sourcedir.tar.gz + + # Exit on any error. SageMaker uses error code to mark failed job. + set -e + + if [[ -f 'requirements.txt' ]]; then + # Optionally log into CodeArtifact + if ! hash aws 2>/dev/null; then + echo "AWS CLI is not installed. Skipping CodeArtifact login." + else + aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2 + fi + + # Some py3 containers has typing, which may breaks pip install + pip uninstall --yes typing + + pip install -r requirements.txt + fi + + python process.py "$@" + """ # noqa: E501 # pylint: disable=line-too-long + ) + + def _get_script_processor(sagemaker_session): return ScriptProcessor( role=ROLE,