Skip to content

Commit 61190de

Browse files
akuma12goelakash
authored andcommitted
feature: Add optional CodeArtifact login to FrameworkProcessing job script
1 parent a9ac311 commit 61190de

File tree

1 file changed

+91
-8
lines changed

1 file changed

+91
-8
lines changed

src/sagemaker/processing.py

Lines changed: 91 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from textwrap import dedent
2525
from typing import Dict, List, Optional, Union
2626
from copy import copy
27+
import re
2728

2829
import attr
2930

@@ -1659,6 +1660,7 @@ def run( # type: ignore[override]
16591660
job_name: Optional[str] = None,
16601661
experiment_config: Optional[Dict[str, str]] = None,
16611662
kms_key: Optional[str] = None,
1663+
codeartifact_repo_arn: Optional[str] = None,
16621664
):
16631665
"""Runs a processing job.
16641666
@@ -1759,12 +1761,21 @@ def run( # type: ignore[override]
17591761
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
17601762
kms_key (str): The ARN of the KMS key that is used to encrypt the
17611763
user code file (default: None).
1764+
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
1765+
logged into before installing dependencies (default: None).
17621766
Returns:
17631767
None or pipeline step arguments in case the Processor instance is built with
17641768
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
17651769
"""
17661770
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
1767-
code, source_dir, dependencies, git_config, job_name, inputs, kms_key
1771+
code,
1772+
source_dir,
1773+
dependencies,
1774+
git_config,
1775+
job_name,
1776+
inputs,
1777+
kms_key,
1778+
codeartifact_repo_arn,
17681779
)
17691780

17701781
# Submit a processing job.
@@ -1781,7 +1792,15 @@ def run( # type: ignore[override]
17811792
)
17821793

17831794
def _pack_and_upload_code(
1784-
self, code, source_dir, dependencies, git_config, job_name, inputs, kms_key=None
1795+
self,
1796+
code,
1797+
source_dir,
1798+
dependencies,
1799+
git_config,
1800+
job_name,
1801+
inputs,
1802+
kms_key=None,
1803+
codeartifact_repo_arn=None,
17851804
):
17861805
"""Pack local code bundle and upload to Amazon S3."""
17871806
if code.startswith("s3://"):
@@ -1822,12 +1841,65 @@ def _pack_and_upload_code(
18221841
script = estimator.uploaded_code.script_name
18231842
evaluated_kms_key = kms_key if kms_key else self.output_kms_key
18241843
s3_runproc_sh = self._create_and_upload_runproc(
1825-
script, evaluated_kms_key, entrypoint_s3_uri
1844+
script, evaluated_kms_key, entrypoint_s3_uri, codeartifact_repo_arn
18261845
)
18271846

18281847
return s3_runproc_sh, inputs, job_name
18291848

1830-
def _generate_framework_script(self, user_script: str) -> str:
1849+
def _get_codeartifact_index(self, codeartifact_repo_arn: str):
1850+
"""
1851+
Build the authenticated codeartifact index url based on the arn provided
1852+
via codeartifact_repo_arn property following the form
1853+
# `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}`
1854+
https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
1855+
https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies
1856+
:return: authenticated codeartifact index url
1857+
"""
1858+
1859+
arn_regex = (
1860+
"arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)"
1861+
":repository/(?P<domain>[^/]+)/(?P<repository>.+)"
1862+
)
1863+
m = re.match(arn_regex, codeartifact_repo_arn)
1864+
if not m:
1865+
raise Exception("invalid CodeArtifact repository arn {}".format(codeartifact_repo_arn))
1866+
domain = m.group("domain")
1867+
owner = m.group("account")
1868+
repository = m.group("repository")
1869+
region = m.group("region")
1870+
1871+
logger.info(
1872+
"configuring pip to use codeartifact "
1873+
"(domain: %s, domain owner: %s, repository: %s, region: %s)",
1874+
domain,
1875+
owner,
1876+
repository,
1877+
region,
1878+
)
1879+
try:
1880+
client = self.sagemaker_session.boto_session.client("codeartifact", region_name=region)
1881+
auth_token_response = client.get_authorization_token(domain=domain, domainOwner=owner)
1882+
token = auth_token_response["authorizationToken"]
1883+
endpoint_response = client.get_repository_endpoint(
1884+
domain=domain, domainOwner=owner, repository=repository, format="pypi"
1885+
)
1886+
unauthenticated_index = endpoint_response["repositoryEndpoint"]
1887+
return re.sub(
1888+
"https://",
1889+
"https://aws:{}@".format(token),
1890+
re.sub(
1891+
"{}/?$".format(repository),
1892+
"{}/simple/".format(repository),
1893+
unauthenticated_index,
1894+
),
1895+
)
1896+
except Exception:
1897+
logger.error("failed to configure pip to use codeartifact")
1898+
raise Exception("failed to configure pip to use codeartifact")
1899+
1900+
def _generate_framework_script(
1901+
self, user_script: str, codeartifact_repo_arn: str = None
1902+
) -> str:
18311903
"""Generate the framework entrypoint file (as text) for a processing job.
18321904
18331905
This script implements the "framework" functionality for setting up your code:
@@ -1838,7 +1910,15 @@ def _generate_framework_script(self, user_script: str) -> str:
18381910
Args:
18391911
user_script (str): Relative path to ```code``` in the source bundle
18401912
- e.g. 'process.py'.
1913+
codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
1914+
logged into before installing dependencies (default: None).
18411915
"""
1916+
if codeartifact_repo_arn:
1917+
index = self._get_codeartifact_index(codeartifact_repo_arn)
1918+
index_option = "-i {}".format(index)
1919+
else:
1920+
index_option = ""
1921+
18421922
return dedent(
18431923
"""\
18441924
#!/bin/bash
@@ -1853,12 +1933,13 @@ def _generate_framework_script(self, user_script: str) -> str:
18531933
# Some py3 containers has typing, which may breaks pip install
18541934
pip uninstall --yes typing
18551935
1856-
pip install -r requirements.txt
1936+
pip install -r requirements.txt {index_option}
18571937
fi
18581938
18591939
{entry_point_command} {entry_point} "$@"
18601940
"""
18611941
).format(
1942+
index_option=index_option,
18621943
entry_point_command=" ".join(self.command),
18631944
entry_point=user_script,
18641945
)
@@ -1934,7 +2015,9 @@ def _set_entrypoint(self, command, user_script_name):
19342015
)
19352016
self.entrypoint = self.framework_entrypoint_command + [user_script_location]
19362017

1937-
def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
2018+
def _create_and_upload_runproc(
2019+
self, user_script, kms_key, entrypoint_s3_uri, codeartifact_repo_arn=None
2020+
):
19382021
"""Create runproc shell script and upload to S3 bucket.
19392022
19402023
If leveraging a pipeline session with optimized S3 artifact paths,
@@ -1950,7 +2033,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
19502033
from sagemaker.workflow.utilities import _pipeline_config, hash_object
19512034

19522035
if _pipeline_config and _pipeline_config.pipeline_name:
1953-
runproc_file_str = self._generate_framework_script(user_script)
2036+
runproc_file_str = self._generate_framework_script(user_script, codeartifact_repo_arn)
19542037
runproc_file_hash = hash_object(runproc_file_str)
19552038
s3_uri = s3.s3_path_join(
19562039
"s3://",
@@ -1969,7 +2052,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
19692052
)
19702053
else:
19712054
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
1972-
self._generate_framework_script(user_script),
2055+
self._generate_framework_script(user_script, codeartifact_repo_arn),
19732056
desired_s3_uri=entrypoint_s3_uri,
19742057
kms_key=kms_key,
19752058
sagemaker_session=self.sagemaker_session,

0 commit comments

Comments
 (0)