Skip to content

Commit 502e051

Browse files
akuma12sage-maker
andauthored
feature: Add optional CodeArtifact login to FrameworkProcessing job script (#4145)
* feature: Add optional CodeArtifact login to FrameworkProcessing job script * Add unit test for _get_codeartifact_index * Fixed docstring * Convert CodeArtifact integration to simply generate an AWS CLI command to log into CodeArtifact * Fix lint issues * More lint fixes * Lint fix * Yet Another Lint Fix * Black fix --------- Co-authored-by: sage-maker <[email protected]>
1 parent 52934e2 commit 502e051

File tree

2 files changed

+218
-7
lines changed

2 files changed

+218
-7
lines changed

src/sagemaker/processing.py

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from textwrap import dedent
2525
from typing import Dict, List, Optional, Union
2626
from copy import copy
27+
import re
2728

2829
import attr
2930

@@ -1658,6 +1659,7 @@ def run( # type: ignore[override]
16581659
job_name: Optional[str] = None,
16591660
experiment_config: Optional[Dict[str, str]] = None,
16601661
kms_key: Optional[str] = None,
1662+
codeartifact_repo_arn: Optional[str] = None,
16611663
):
16621664
"""Runs a processing job.
16631665
@@ -1758,12 +1760,21 @@ def run( # type: ignore[override]
17581760
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
17591761
kms_key (str): The ARN of the KMS key that is used to encrypt the
17601762
user code file (default: None).
1763+
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
1764+
logged into before installing dependencies (default: None).
17611765
Returns:
17621766
None or pipeline step arguments in case the Processor instance is built with
17631767
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
17641768
"""
17651769
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
1766-
code, source_dir, dependencies, git_config, job_name, inputs, kms_key
1770+
code,
1771+
source_dir,
1772+
dependencies,
1773+
git_config,
1774+
job_name,
1775+
inputs,
1776+
kms_key,
1777+
codeartifact_repo_arn,
17671778
)
17681779

17691780
# Submit a processing job.
@@ -1780,7 +1791,15 @@ def run( # type: ignore[override]
17801791
)
17811792

17821793
def _pack_and_upload_code(
1783-
self, code, source_dir, dependencies, git_config, job_name, inputs, kms_key=None
1794+
self,
1795+
code,
1796+
source_dir,
1797+
dependencies,
1798+
git_config,
1799+
job_name,
1800+
inputs,
1801+
kms_key=None,
1802+
codeartifact_repo_arn=None,
17841803
):
17851804
"""Pack local code bundle and upload to Amazon S3."""
17861805
if code.startswith("s3://"):
@@ -1821,12 +1840,53 @@ def _pack_and_upload_code(
18211840
script = estimator.uploaded_code.script_name
18221841
evaluated_kms_key = kms_key if kms_key else self.output_kms_key
18231842
s3_runproc_sh = self._create_and_upload_runproc(
1824-
script, evaluated_kms_key, entrypoint_s3_uri
1843+
script, evaluated_kms_key, entrypoint_s3_uri, codeartifact_repo_arn
18251844
)
18261845

18271846
return s3_runproc_sh, inputs, job_name
18281847

1829-
def _generate_framework_script(self, user_script: str) -> str:
1848+
def _get_codeartifact_command(self, codeartifact_repo_arn: str) -> str:
1849+
"""Build an AWS CLI CodeArtifact command to configure pip.
1850+
1851+
The codeartifact_repo_arn property must follow the form
1852+
# `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}`
1853+
https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
1854+
https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies
1855+
1856+
Args:
1857+
codeartifact_repo_arn: arn of the codeartifact repository
1858+
Returns:
1859+
codeartifact command string
1860+
"""
1861+
1862+
arn_regex = (
1863+
"arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)"
1864+
":repository/(?P<domain>[^/]+)/(?P<repository>.+)"
1865+
)
1866+
m = re.match(arn_regex, codeartifact_repo_arn)
1867+
if not m:
1868+
raise ValueError("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn))
1869+
domain = m.group("domain")
1870+
owner = m.group("account")
1871+
repository = m.group("repository")
1872+
region = m.group("region")
1873+
1874+
logger.info(
1875+
"configuring pip to use codeartifact "
1876+
"(domain: %s, domain owner: %s, repository: %s, region: %s)",
1877+
domain,
1878+
owner,
1879+
repository,
1880+
region,
1881+
)
1882+
1883+
return "aws codeartifact login --tool pip --domain {} --domain-owner {} --repository {} --region {}".format( # noqa: E501 pylint: disable=line-too-long
1884+
domain, owner, repository, region
1885+
)
1886+
1887+
def _generate_framework_script(
1888+
self, user_script: str, codeartifact_repo_arn: str = None
1889+
) -> str:
18301890
"""Generate the framework entrypoint file (as text) for a processing job.
18311891
18321892
This script implements the "framework" functionality for setting up your code:
@@ -1837,7 +1897,16 @@ def _generate_framework_script(self, user_script: str) -> str:
18371897
Args:
18381898
user_script (str): Relative path to ```code``` in the source bundle
18391899
- e.g. 'process.py'.
1900+
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
1901+
logged into before installing dependencies (default: None).
18401902
"""
1903+
if codeartifact_repo_arn:
1904+
codeartifact_login_command = self._get_codeartifact_command(codeartifact_repo_arn)
1905+
else:
1906+
codeartifact_login_command = (
1907+
"echo 'CodeArtifact repository not specified. Skipping login.'"
1908+
)
1909+
18411910
return dedent(
18421911
"""\
18431912
#!/bin/bash
@@ -1849,6 +1918,13 @@ def _generate_framework_script(self, user_script: str) -> str:
18491918
set -e
18501919
18511920
if [[ -f 'requirements.txt' ]]; then
1921+
# Optionally log into CodeArtifact
1922+
if ! hash aws 2>/dev/null; then
1923+
echo "AWS CLI is not installed. Skipping CodeArtifact login."
1924+
else
1925+
{codeartifact_login_command}
1926+
fi
1927+
18521928
# Some py3 containers has typing, which may breaks pip install
18531929
pip uninstall --yes typing
18541930
@@ -1858,6 +1934,7 @@ def _generate_framework_script(self, user_script: str) -> str:
18581934
{entry_point_command} {entry_point} "$@"
18591935
"""
18601936
).format(
1937+
codeartifact_login_command=codeartifact_login_command,
18611938
entry_point_command=" ".join(self.command),
18621939
entry_point=user_script,
18631940
)
@@ -1933,7 +2010,9 @@ def _set_entrypoint(self, command, user_script_name):
19332010
)
19342011
self.entrypoint = self.framework_entrypoint_command + [user_script_location]
19352012

1936-
def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
2013+
def _create_and_upload_runproc(
2014+
self, user_script, kms_key, entrypoint_s3_uri, codeartifact_repo_arn=None
2015+
):
19372016
"""Create runproc shell script and upload to S3 bucket.
19382017
19392018
If leveraging a pipeline session with optimized S3 artifact paths,
@@ -1949,7 +2028,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
19492028
from sagemaker.workflow.utilities import _pipeline_config, hash_object
19502029

19512030
if _pipeline_config and _pipeline_config.pipeline_name:
1952-
runproc_file_str = self._generate_framework_script(user_script)
2031+
runproc_file_str = self._generate_framework_script(user_script, codeartifact_repo_arn)
19532032
runproc_file_hash = hash_object(runproc_file_str)
19542033
s3_uri = s3.s3_path_join(
19552034
"s3://",
@@ -1968,7 +2047,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
19682047
)
19692048
else:
19702049
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
1971-
self._generate_framework_script(user_script),
2050+
self._generate_framework_script(user_script, codeartifact_repo_arn),
19722051
desired_s3_uri=entrypoint_s3_uri,
19732052
kms_key=kms_key,
19742053
sagemaker_session=self.sagemaker_session,

tests/unit/test_processing.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import copy
16+
from textwrap import dedent
1617

1718
import pytest
1819
from mock import Mock, patch, MagicMock
@@ -1102,6 +1103,137 @@ def test_pyspark_processor_configuration_path_pipeline_config(
11021103
)
11031104

11041105

1106+
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
1107+
def test_get_codeartifact_command(pipeline_session):
1108+
codeartifact_repo_arn = (
1109+
"arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
1110+
)
1111+
1112+
processor = PyTorchProcessor(
1113+
role=ROLE,
1114+
instance_type="ml.m4.xlarge",
1115+
framework_version="2.0.1",
1116+
py_version="py310",
1117+
instance_count=1,
1118+
sagemaker_session=pipeline_session,
1119+
)
1120+
1121+
codeartifact_command = processor._get_codeartifact_command(
1122+
codeartifact_repo_arn=codeartifact_repo_arn
1123+
)
1124+
1125+
assert (
1126+
codeartifact_command
1127+
== "aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2" # noqa: E501 # pylint: disable=line-too-long
1128+
)
1129+
1130+
1131+
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
1132+
def test_get_codeartifact_command_bad_repo_arn(pipeline_session):
1133+
codeartifact_repo_arn = "arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain"
1134+
1135+
processor = PyTorchProcessor(
1136+
role=ROLE,
1137+
instance_type="ml.m4.xlarge",
1138+
framework_version="2.0.1",
1139+
py_version="py310",
1140+
instance_count=1,
1141+
sagemaker_session=pipeline_session,
1142+
)
1143+
1144+
with pytest.raises(ValueError):
1145+
processor._get_codeartifact_command(codeartifact_repo_arn=codeartifact_repo_arn)
1146+
1147+
1148+
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
1149+
def test_generate_framework_script(pipeline_session):
1150+
processor = PyTorchProcessor(
1151+
role=ROLE,
1152+
instance_type="ml.m4.xlarge",
1153+
framework_version="2.0.1",
1154+
py_version="py310",
1155+
instance_count=1,
1156+
sagemaker_session=pipeline_session,
1157+
)
1158+
1159+
framework_script = processor._generate_framework_script(user_script="process.py")
1160+
1161+
assert framework_script == dedent(
1162+
"""\
1163+
#!/bin/bash
1164+
1165+
cd /opt/ml/processing/input/code/
1166+
tar -xzf sourcedir.tar.gz
1167+
1168+
# Exit on any error. SageMaker uses error code to mark failed job.
1169+
set -e
1170+
1171+
if [[ -f 'requirements.txt' ]]; then
1172+
# Optionally log into CodeArtifact
1173+
if ! hash aws 2>/dev/null; then
1174+
echo "AWS CLI is not installed. Skipping CodeArtifact login."
1175+
else
1176+
echo 'CodeArtifact repository not specified. Skipping login.'
1177+
fi
1178+
1179+
# Some py3 containers has typing, which may breaks pip install
1180+
pip uninstall --yes typing
1181+
1182+
pip install -r requirements.txt
1183+
fi
1184+
1185+
python process.py "$@"
1186+
"""
1187+
)
1188+
1189+
1190+
@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
1191+
def test_generate_framework_script_with_codeartifact(pipeline_session):
1192+
processor = PyTorchProcessor(
1193+
role=ROLE,
1194+
instance_type="ml.m4.xlarge",
1195+
framework_version="2.0.1",
1196+
py_version="py310",
1197+
instance_count=1,
1198+
sagemaker_session=pipeline_session,
1199+
)
1200+
1201+
framework_script = processor._generate_framework_script(
1202+
user_script="process.py",
1203+
codeartifact_repo_arn=(
1204+
"arn:aws:codeartifact:us-west-2:012345678901:repository/test-domain/test-repository"
1205+
),
1206+
)
1207+
1208+
assert framework_script == dedent(
1209+
"""\
1210+
#!/bin/bash
1211+
1212+
cd /opt/ml/processing/input/code/
1213+
tar -xzf sourcedir.tar.gz
1214+
1215+
# Exit on any error. SageMaker uses error code to mark failed job.
1216+
set -e
1217+
1218+
if [[ -f 'requirements.txt' ]]; then
1219+
# Optionally log into CodeArtifact
1220+
if ! hash aws 2>/dev/null; then
1221+
echo "AWS CLI is not installed. Skipping CodeArtifact login."
1222+
else
1223+
aws codeartifact login --tool pip --domain test-domain --domain-owner 012345678901 --repository test-repository --region us-west-2
1224+
fi
1225+
1226+
# Some py3 containers has typing, which may breaks pip install
1227+
pip uninstall --yes typing
1228+
1229+
pip install -r requirements.txt
1230+
fi
1231+
1232+
python process.py "$@"
1233+
""" # noqa: E501 # pylint: disable=line-too-long
1234+
)
1235+
1236+
11051237
def _get_script_processor(sagemaker_session):
11061238
return ScriptProcessor(
11071239
role=ROLE,

0 commit comments

Comments
 (0)