24
24
from textwrap import dedent
25
25
from typing import Dict , List , Optional , Union
26
26
from copy import copy
27
+ import re
27
28
28
29
import attr
29
30
@@ -1658,6 +1659,7 @@ def run( # type: ignore[override]
1658
1659
job_name : Optional [str ] = None ,
1659
1660
experiment_config : Optional [Dict [str , str ]] = None ,
1660
1661
kms_key : Optional [str ] = None ,
1662
+ codeartifact_repo_arn : Optional [str ] = None ,
1661
1663
):
1662
1664
"""Runs a processing job.
1663
1665
@@ -1758,12 +1760,21 @@ def run( # type: ignore[override]
1758
1760
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
1759
1761
kms_key (str): The ARN of the KMS key that is used to encrypt the
1760
1762
user code file (default: None).
1763
+ codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
1764
+ logged into before installing dependencies (default: None).
1761
1765
Returns:
1762
1766
None or pipeline step arguments in case the Processor instance is built with
1763
1767
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
1764
1768
"""
1765
1769
s3_runproc_sh , inputs , job_name = self ._pack_and_upload_code (
1766
- code , source_dir , dependencies , git_config , job_name , inputs , kms_key
1770
+ code ,
1771
+ source_dir ,
1772
+ dependencies ,
1773
+ git_config ,
1774
+ job_name ,
1775
+ inputs ,
1776
+ kms_key ,
1777
+ codeartifact_repo_arn ,
1767
1778
)
1768
1779
1769
1780
# Submit a processing job.
@@ -1780,7 +1791,15 @@ def run( # type: ignore[override]
1780
1791
)
1781
1792
1782
1793
def _pack_and_upload_code (
1783
- self , code , source_dir , dependencies , git_config , job_name , inputs , kms_key = None
1794
+ self ,
1795
+ code ,
1796
+ source_dir ,
1797
+ dependencies ,
1798
+ git_config ,
1799
+ job_name ,
1800
+ inputs ,
1801
+ kms_key = None ,
1802
+ codeartifact_repo_arn = None ,
1784
1803
):
1785
1804
"""Pack local code bundle and upload to Amazon S3."""
1786
1805
if code .startswith ("s3://" ):
@@ -1821,12 +1840,53 @@ def _pack_and_upload_code(
1821
1840
script = estimator .uploaded_code .script_name
1822
1841
evaluated_kms_key = kms_key if kms_key else self .output_kms_key
1823
1842
s3_runproc_sh = self ._create_and_upload_runproc (
1824
- script , evaluated_kms_key , entrypoint_s3_uri
1843
+ script , evaluated_kms_key , entrypoint_s3_uri , codeartifact_repo_arn
1825
1844
)
1826
1845
1827
1846
return s3_runproc_sh , inputs , job_name
1828
1847
1829
- def _generate_framework_script (self , user_script : str ) -> str :
1848
+ def _get_codeartifact_command (self , codeartifact_repo_arn : str ) -> str :
1849
+ """Build an AWS CLI CodeArtifact command to configure pip.
1850
+
1851
+ The codeartifact_repo_arn property must follow the form
1852
+ # `arn:${Partition}:codeartifact:${Region}:${Account}:repository/${Domain}/${Repository}`
1853
+ https://docs.aws.amazon.com/codeartifact/latest/ug/python-configure-pip.html
1854
+ https://docs.aws.amazon.com/service-authorization/latest/reference/list_awscodeartifact.html#awscodeartifact-resources-for-iam-policies
1855
+
1856
+ Args:
1857
+ codeartifact_repo_arn: arn of the codeartifact repository
1858
+ Returns:
1859
+ codeartifact command string
1860
+ """
1861
+
1862
+ arn_regex = (
1863
+ "arn:(?P<partition>[^:]+):codeartifact:(?P<region>[^:]+):(?P<account>[^:]+)"
1864
+ ":repository/(?P<domain>[^/]+)/(?P<repository>.+)"
1865
+ )
1866
+ m = re .match (arn_regex , codeartifact_repo_arn )
1867
+ if not m :
1868
+ raise ValueError ("invalid CodeArtifact repository arn {}" .format (codeartifact_repo_arn ))
1869
+ domain = m .group ("domain" )
1870
+ owner = m .group ("account" )
1871
+ repository = m .group ("repository" )
1872
+ region = m .group ("region" )
1873
+
1874
+ logger .info (
1875
+ "configuring pip to use codeartifact "
1876
+ "(domain: %s, domain owner: %s, repository: %s, region: %s)" ,
1877
+ domain ,
1878
+ owner ,
1879
+ repository ,
1880
+ region ,
1881
+ )
1882
+
1883
+ return "aws codeartifact login --tool pip --domain {} --domain-owner {} --repository {} --region {}" .format ( # noqa: E501 pylint: disable=line-too-long
1884
+ domain , owner , repository , region
1885
+ )
1886
+
1887
+ def _generate_framework_script (
1888
+ self , user_script : str , codeartifact_repo_arn : str = None
1889
+ ) -> str :
1830
1890
"""Generate the framework entrypoint file (as text) for a processing job.
1831
1891
1832
1892
This script implements the "framework" functionality for setting up your code:
@@ -1837,7 +1897,16 @@ def _generate_framework_script(self, user_script: str) -> str:
1837
1897
Args:
1838
1898
user_script (str): Relative path to ```code``` in the source bundle
1839
1899
- e.g. 'process.py'.
1900
+ codeartifact_repo_arn (str): The ARN of the CodeArtifact repository that should be
1901
+ logged into before installing dependencies (default: None).
1840
1902
"""
1903
+ if codeartifact_repo_arn :
1904
+ codeartifact_login_command = self ._get_codeartifact_command (codeartifact_repo_arn )
1905
+ else :
1906
+ codeartifact_login_command = (
1907
+ "echo 'CodeArtifact repository not specified. Skipping login.'"
1908
+ )
1909
+
1841
1910
return dedent (
1842
1911
"""\
1843
1912
#!/bin/bash
@@ -1849,6 +1918,13 @@ def _generate_framework_script(self, user_script: str) -> str:
1849
1918
set -e
1850
1919
1851
1920
if [[ -f 'requirements.txt' ]]; then
1921
+ # Optionally log into CodeArtifact
1922
+ if ! hash aws 2>/dev/null; then
1923
+ echo "AWS CLI is not installed. Skipping CodeArtifact login."
1924
+ else
1925
+ {codeartifact_login_command}
1926
+ fi
1927
+
1852
1928
# Some py3 containers has typing, which may breaks pip install
1853
1929
pip uninstall --yes typing
1854
1930
@@ -1858,6 +1934,7 @@ def _generate_framework_script(self, user_script: str) -> str:
1858
1934
{entry_point_command} {entry_point} "$@"
1859
1935
"""
1860
1936
).format (
1937
+ codeartifact_login_command = codeartifact_login_command ,
1861
1938
entry_point_command = " " .join (self .command ),
1862
1939
entry_point = user_script ,
1863
1940
)
@@ -1933,7 +2010,9 @@ def _set_entrypoint(self, command, user_script_name):
1933
2010
)
1934
2011
self .entrypoint = self .framework_entrypoint_command + [user_script_location ]
1935
2012
1936
- def _create_and_upload_runproc (self , user_script , kms_key , entrypoint_s3_uri ):
2013
+ def _create_and_upload_runproc (
2014
+ self , user_script , kms_key , entrypoint_s3_uri , codeartifact_repo_arn = None
2015
+ ):
1937
2016
"""Create runproc shell script and upload to S3 bucket.
1938
2017
1939
2018
If leveraging a pipeline session with optimized S3 artifact paths,
@@ -1949,7 +2028,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
1949
2028
from sagemaker .workflow .utilities import _pipeline_config , hash_object
1950
2029
1951
2030
if _pipeline_config and _pipeline_config .pipeline_name :
1952
- runproc_file_str = self ._generate_framework_script (user_script )
2031
+ runproc_file_str = self ._generate_framework_script (user_script , codeartifact_repo_arn )
1953
2032
runproc_file_hash = hash_object (runproc_file_str )
1954
2033
s3_uri = s3 .s3_path_join (
1955
2034
"s3://" ,
@@ -1968,7 +2047,7 @@ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
1968
2047
)
1969
2048
else :
1970
2049
s3_runproc_sh = S3Uploader .upload_string_as_file_body (
1971
- self ._generate_framework_script (user_script ),
2050
+ self ._generate_framework_script (user_script , codeartifact_repo_arn ),
1972
2051
desired_s3_uri = entrypoint_s3_uri ,
1973
2052
kms_key = kms_key ,
1974
2053
sagemaker_session = self .sagemaker_session ,
0 commit comments