diff --git a/doc/overview.rst b/doc/overview.rst index d5bef646f5..7d39f746fb 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -183,38 +183,43 @@ Here is an example: # When you are done using your endpoint algo.delete_endpoint() -Git Support ------------ -If you have your training scripts in your GitHub repository, you can use them directly without the trouble to download -them to local machine. Git support can be enabled simply by providing ``git_config`` parameter when initializing an -estimator. If Git support is enabled, then ``entry_point``, ``source_dir`` and ``dependencies`` should all be relative -paths in the Git repo. Note that if you decided to use Git support, then everything you need for ``entry_point``, -``source_dir`` and ``dependencies`` should be in a single Git repo. +Use Scripts Stored in a Git Repository +-------------------------------------- +When you create an estimator, you can specify a training script that is stored in a GitHub or other Git repository as the entry point for the estimator, so that you don't have to download the scripts locally. +If you do so, source directory and dependencies should be in the same repo if they are needed. Git support can be enabled simply by providing ``git_config`` parameter +when creating an ``Estimator`` object. If Git support is enabled, then ``entry_point``, ``source_dir`` and ``dependencies`` +should be relative paths in the Git repo if provided. -Here are ways to specify ``git_config``: +The ``git_config`` parameter includes fields ``repo``, ``branch``, ``commit``, ``2FA_enabled``, ``username``, +``password`` and ``token``. The ``repo`` field is required. All other fields are optional. ``repo`` specifies the Git +repository where your training script is stored. If you don't provide ``branch``, the default value 'master' is used. +If you don't provide ``commit``, the latest commit in the specified branch is used. -.. code:: python +``2FA_enabled``, ``username``, ``password`` and ``token`` are used for authentication. Set ``2FA_enabled`` to 'True' if +two-factor authentication is enabled for the GitHub (or other Git) account, otherwise set it to 'False'. +If you do not provide a value for ``2FA_enabled``, a default value of 'False' is used. - # Specifies the git_config parameter - git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git', - 'branch': 'branch1', - 'commit': '4893e528afa4a790331e1b5286954f073b0f14a2'} - - # Alternatively, you can also specify git_config by providing only 'repo' and 'branch'. - # If this is the case, the latest commit in the branch will be used. - git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git', - 'branch': 'branch1'} +If ``repo`` is an SSH URL, you should either have no passphrase for the SSH key pairs, or have the ``ssh-agent`` configured +so that you are not prompted for the SSH passphrase when you run a ``git clone`` command with SSH URLs. For SSH URLs, it +does not matter whether two-factor authentication is enabled. - # Only providing 'repo' is also allowed. If this is the case, latest commit in - # 'master' branch will be used. - git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git'} +If ``repo`` is an https URL, 2FA matters. When 2FA is disabled, either ``token`` or ``username``+``password`` will be +used for authentication if provided (``token`` prioritized). When 2FA is enabled, only token will be used for +authentication if provided. If required authentication info is not provided, python SDK will try to use local +credentials storage to authenticate. If that fails either, an error message will be thrown. -The following are some examples to define estimators with Git support: +Here are some examples of creating estimators with Git support: .. code:: python + # Specifies the git_config parameter. This example does not provide Git credentials, so python SDK will try + # to use local credential storage. + git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git', + 'branch': 'branch1', + 'commit': '4893e528afa4a790331e1b5286954f073b0f14a2'} + # In this example, the source directory 'pytorch' contains the entry point 'mnist.py' and other source code. - # and it is relative path inside the Git repo. + # and it is relative path inside the Git repo. pytorch_estimator = PyTorch(entry_point='mnist.py', role='SageMakerRole', source_dir='pytorch', @@ -222,6 +227,13 @@ The following are some examples to define estimators with Git support: train_instance_count=1, train_instance_type='ml.c4.xlarge') +.. code:: python + + # You can also specify git_config by providing only 'repo' and 'branch'. + # If this is the case, the latest commit in that branch will be used. + git_config = {'repo': 'git@github.com:username/repo-with-training-scripts.git', + 'branch': 'branch1'} + # In this example, the entry point 'mnist.py' is all we need for source code. # We need to specify the path to it in the Git repo. mx_estimator = MXNet(entry_point='mxnet/mnist.py', @@ -230,6 +242,15 @@ The following are some examples to define estimators with Git support: train_instance_count=1, train_instance_type='ml.c4.xlarge') +.. code:: python + + # Only providing 'repo' is also allowed. If this is the case, latest commit in 'master' branch will be used. + # This example does not provide '2FA_enabled', so 2FA is treated as disabled by default. 'username' and + # 'password' are provided for authentication + git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git', + 'username': 'username', + 'password': 'passw0rd!'} + # In this example, besides entry point and other source code in source directory, we still need some # dependencies for the training job. Dependencies should also be paths inside the Git repo. pytorch_estimator = PyTorch(entry_point='mnist.py', @@ -240,7 +261,23 @@ The following are some examples to define estimators with Git support: train_instance_count=1, train_instance_type='ml.c4.xlarge') -When Git support is enabled, users can still use local mode in the same way. +.. code:: python + + # This example specifies that 2FA is enabled, and token is provided for authentication + git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git', + '2FA_enabled': True, + 'token': 'your-token'} + + # In this exmaple, besides entry point, we also need some dependencies for the training job. + pytorch_estimator = PyTorch(entry_point='pytorch/mnist.py', + role='SageMakerRole', + dependencies=['dep.py'], + git_config=git_config, + train_instance_count=1, + train_instance_type='local') + +Git support can be used not only for training jobs, but also for hosting models. The usage is the same as the above, +and ``git_config`` should be provided when creating model objects, e.g. ``TensorFlowModel``, ``MXNetModel``, ``PyTorchModel``. Training Metrics ---------------- diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 287d203912..4a3aeb9649 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -20,10 +20,10 @@ from abc import abstractmethod from six import with_metaclass from six import string_types - import sagemaker from sagemaker import git_utils from sagemaker.analytics import TrainingJobAnalytics + from sagemaker.fw_utils import ( create_image_uri, tar_and_upload_dir, @@ -975,10 +975,12 @@ def __init__( >>> |----- test.py You can assign entry_point='src/train.py'. - git_config (dict[str, str]): Git configurations used for cloning files, including 'repo', 'branch' - and 'commit' (default: None). - 'branch' and 'commit' are optional. If 'branch' is not specified, 'master' branch will be used. If - 'commit' is not specified, the latest commit in the required branch will be used. + git_config (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch``, + ``commit``, ``2FA_enabled``, ``username``, ``password`` and ``token`` (default: None). The fields are + optional except ``repo``. If ``branch`` is not specified, master branch will be used. If ``commit`` + is not specified, the latest commit in the required branch will be used. 'branch' and 'commit' are + optional. If 'branch' is not specified, 'master' branch will be used. If 'commit' is not specified, + the latest commit in the required branch will be used. Example: The following config: @@ -989,6 +991,15 @@ def __init__( results in cloning the repo specified in 'repo', then checkout the 'master' branch, and checkout the specified commit. + ``2FA_enabled``, ``username``, ``password`` and ``token`` are for authentication purpose. + ``2FA_enabled`` must be ``True`` or ``False`` if it is provided. If ``2FA_enabled`` is not provided, + we consider 2FA as disabled. For GitHub and other Git repos, when ssh urls are provided, it does not + make a difference whether 2FA is enabled or disabled; an ssh passphrase should be in local storage. + When https urls are provided: if 2FA is disabled, then either token or username+password will + be used for authentication if provided (token prioritized); if 2FA is enabled, only token will + be used for authentication if provided. If required authentication info is not provided, python SDK + will try to use local credentials storage to authenticate. If that fails either, an error message will + be thrown. source_dir (str): Path (absolute or relative) to a directory with any other training source code dependencies aside from the entry point file (default: None). Structure within this directory are preserved when training on Amazon SageMaker. If 'git_config' is provided, diff --git a/src/sagemaker/git_utils.py b/src/sagemaker/git_utils.py index 7360c31e55..fa4e104cfc 100644 --- a/src/sagemaker/git_utils.py +++ b/src/sagemaker/git_utils.py @@ -16,6 +16,8 @@ import six import subprocess import tempfile +import warnings +from six.moves import urllib def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None): @@ -23,9 +25,18 @@ def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None): and set ``entry_point``, ``source_dir`` and ``dependencies`` to the right file or directory in the repo cloned. Args: - git_config (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch`` - and ``commit``. ``branch`` and ``commit`` are optional. If ``branch`` is not specified, master branch - will be used. If ``commit`` is not specified, the latest commit in the required branch will be used. + git_config (dict[str, object]): Git configurations used for cloning files, including ``repo``, ``branch``, + ``commit``, ``2FA_enabled``, ``username``, ``password`` and ``token``. The fields are optional except + ``repo``. If ``branch`` is not specified, master branch will be used. If ``commit`` is not specified, + the latest commit in the required branch will be used. ``2FA_enabled``, ``username``, ``password`` and + ``token`` are for authentication purpose. + ``2FA_enabled`` must be ``True`` or ``False`` if it is provided. If ``2FA_enabled`` is not provided, we + consider 2FA as disabled. For GitHub and other Git repos, when ssh urls are provided, it does not make a + difference whether 2FA is enabled or disabled; an ssh passphrase should be in local storage. When + https urls are provided: if 2FA is disabled, then either token or username+password will be used for + authentication if provided (token prioritized); if 2FA is enabled, only token will be used for + authentication if provided. If required authentication info is not provided, python SDK will try to use + local credentials storage to authenticate. If that fails either, an error message will be thrown. entry_point (str): A relative location to the Python source file which should be executed as the entry point to training or model hosting in the Git repo. source_dir (str): A relative location to a directory with other training or model hosting source code @@ -41,18 +52,18 @@ def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None): ValueError: If 1. entry point specified does not exist in the repo 2. source dir specified does not exist in the repo 3. dependencies specified do not exist in the repo - 4. git_config is in bad format + 4. wrong format is provided for git_config Returns: - dict: A dict that contains the updated values of entry_point, source_dir and dependencies + dict: A dict that contains the updated values of entry_point, source_dir and dependencies. """ if entry_point is None: raise ValueError("Please provide an entry point.") _validate_git_config(git_config) - repo_dir = tempfile.mkdtemp() - subprocess.check_call(["git", "clone", git_config["repo"], repo_dir]) + dest_dir = tempfile.mkdtemp() + _generate_and_run_clone_command(git_config, dest_dir) - _checkout_branch_and_commit(git_config, repo_dir) + _checkout_branch_and_commit(git_config, dest_dir) updated_paths = { "entry_point": entry_point, @@ -62,62 +73,180 @@ def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None): # check if the cloned repo contains entry point, source directory and dependencies if source_dir: - if not os.path.isdir(os.path.join(repo_dir, source_dir)): + if not os.path.isdir(os.path.join(dest_dir, source_dir)): raise ValueError("Source directory does not exist in the repo.") - if not os.path.isfile(os.path.join(repo_dir, source_dir, entry_point)): + if not os.path.isfile(os.path.join(dest_dir, source_dir, entry_point)): raise ValueError("Entry point does not exist in the repo.") - updated_paths["source_dir"] = os.path.join(repo_dir, source_dir) + updated_paths["source_dir"] = os.path.join(dest_dir, source_dir) else: - if os.path.isfile(os.path.join(repo_dir, entry_point)): - updated_paths["entry_point"] = os.path.join(repo_dir, entry_point) + if os.path.isfile(os.path.join(dest_dir, entry_point)): + updated_paths["entry_point"] = os.path.join(dest_dir, entry_point) else: raise ValueError("Entry point does not exist in the repo.") - - updated_paths["dependencies"] = [] - for path in dependencies: - if os.path.exists(os.path.join(repo_dir, path)): - updated_paths["dependencies"].append(os.path.join(repo_dir, path)) - else: - raise ValueError("Dependency {} does not exist in the repo.".format(path)) + if dependencies is not None: + updated_paths["dependencies"] = [] + for path in dependencies: + if os.path.exists(os.path.join(dest_dir, path)): + updated_paths["dependencies"].append(os.path.join(dest_dir, path)) + else: + raise ValueError("Dependency {} does not exist in the repo.".format(path)) return updated_paths def _validate_git_config(git_config): - """check if a git_config param is valid + if "repo" not in git_config: + raise ValueError("Please provide a repo for git_config.") + for key in git_config: + if key == "2FA_enabled": + if not isinstance(git_config["2FA_enabled"], bool): + raise ValueError("Please enter a bool type for 2FA_enabled'.") + elif not isinstance(git_config[key], six.string_types): + raise ValueError("'{}' must be a string.".format(key)) + + +def _generate_and_run_clone_command(git_config, dest_dir): + """check if a git_config param is valid, if it is, create the command to git clone the repo, and run it. Args: git_config ((dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch`` and ``commit``. + dest_dir (str): The local directory to clone the Git repo into. Raises: - ValueError: If: - 1. git_config has no key 'repo' - 2. git_config['repo'] is in the wrong format. + CalledProcessError: If failed to clone git repo. """ - if "repo" not in git_config: - raise ValueError("Please provide a repo for git_config.") - allowed_keys = ["repo", "branch", "commit"] - for key in allowed_keys: - if key in git_config and not isinstance(git_config[key], six.string_types): - raise ValueError("'{}' should be a string".format(key)) - for key in git_config: - if key not in allowed_keys: - raise ValueError("Unexpected argument(s) provided for git_config!") + _clone_command_for_github_like(git_config, dest_dir) + + +def _clone_command_for_github_like(git_config, dest_dir): + """check if a git_config param representing a GitHub (or like) repo is valid, if it is, create the command to + git clone the repo, and run it. + + Args: + git_config ((dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch`` + and ``commit``. + dest_dir (str): The local directory to clone the Git repo into. + + Raises: + ValueError: If git_config['repo'] is in the wrong format. + CalledProcessError: If failed to clone git repo. + """ + is_https = git_config["repo"].startswith("https://") + is_ssh = git_config["repo"].startswith("git@") + if not is_https and not is_ssh: + raise ValueError("Invalid Git url provided.") + if is_ssh: + _clone_command_for_github_like_ssh(git_config, dest_dir) + elif "2FA_enabled" in git_config and git_config["2FA_enabled"] is True: + _clone_command_for_github_like_https_2fa_enabled(git_config, dest_dir) + else: + _clone_command_for_github_like_https_2fa_disabled(git_config, dest_dir) + + +def _clone_command_for_github_like_ssh(git_config, dest_dir): + if "username" in git_config or "password" in git_config or "token" in git_config: + warnings.warn("SSH cloning, authentication information in git config will be ignored.") + _run_clone_command(git_config["repo"], dest_dir) + +def _clone_command_for_github_like_https_2fa_disabled(git_config, dest_dir): + updated_url = git_config["repo"] + if "token" in git_config: + if "username" in git_config or "password" in git_config: + warnings.warn("Using token for authentication, " "other credentials will be ignored.") + updated_url = _insert_token_to_repo_url(url=git_config["repo"], token=git_config["token"]) + elif "username" in git_config and "password" in git_config: + updated_url = _insert_username_and_password_to_repo_url( + url=git_config["repo"], username=git_config["username"], password=git_config["password"] + ) + elif "username" in git_config or "password" in git_config: + warnings.warn("Credentials provided in git config will be ignored.") + _run_clone_command(updated_url, dest_dir) -def _checkout_branch_and_commit(git_config, repo_dir): + +def _clone_command_for_github_like_https_2fa_enabled(git_config, dest_dir): + updated_url = git_config["repo"] + if "token" in git_config: + if "username" in git_config or "password" in git_config: + warnings.warn("Using token for authentication, " "other credentials will be ignored.") + updated_url = _insert_token_to_repo_url(url=git_config["repo"], token=git_config["token"]) + _run_clone_command(updated_url, dest_dir) + + +def _run_clone_command(repo_url, dest_dir): + """Run the 'git clone' command with the repo url and the directory to clone the repo into. + + Args: + repo_url (str): Git repo url to be cloned. + dest_dir: (str): Local path where the repo should be cloned into. + + Raises: + CalledProcessError: If failed to clone git repo. + """ + my_env = os.environ.copy() + if repo_url.startswith("https://"): + my_env["GIT_TERMINAL_PROMPT"] = "0" + subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) + elif repo_url.startswith("git@"): + with tempfile.NamedTemporaryFile() as sshnoprompt: + write_pipe = open(sshnoprompt.name, "w") + write_pipe.write("ssh -oBatchMode=yes $@") + write_pipe.close() + # 511 in decimal is same as 777 in octal + os.chmod(sshnoprompt.name, 511) + my_env["GIT_SSH"] = sshnoprompt.name + subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) + + +def _insert_token_to_repo_url(url, token): + """Insert the token to the Git repo url, to make a component of the git clone command. This method can + only be called when repo_url is an https url. + + Args: + url (str): Git repo url where the token should be inserted into. + token (str): Token to be inserted. + + Returns: + str: the component needed fot the git clone command. + """ + index = len("https://") + if url.find(token) == index: + return url + return url.replace("https://", "https://" + token + "@") + + +def _insert_username_and_password_to_repo_url(url, username, password): + """Insert the username and the password to the Git repo url, to make a component of the git clone command. + This method can only be called when repo_url is an https url. + + Args: + url (str): Git repo url where the token should be inserted into. + username (str): Username to be inserted. + password (str): Password to be inserted. + + Returns: + str: the component needed for the git clone command. + """ + password = urllib.parse.quote_plus(password) + # urllib parses ' ' as '+', but what we need is '%20' here + password = password.replace("+", "%20") + index = len("https://") + return url[:index] + username + ":" + password + "@" + url[index:] + + +def _checkout_branch_and_commit(git_config, dest_dir): """Checkout the required branch and commit. Args: - git_config: (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch`` + git_config (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch`` and ``commit``. - repo_dir (str): the directory where the repo is cloned + dest_dir (str): the directory where the repo is cloned Raises: CalledProcessError: If 1. failed to checkout the required branch 2. failed to checkout the required commit """ if "branch" in git_config: - subprocess.check_call(args=["git", "checkout", git_config["branch"]], cwd=str(repo_dir)) + subprocess.check_call(args=["git", "checkout", git_config["branch"]], cwd=str(dest_dir)) if "commit" in git_config: - subprocess.check_call(args=["git", "checkout", git_config["commit"]], cwd=str(repo_dir)) + subprocess.check_call(args=["git", "checkout", git_config["commit"]], cwd=str(dest_dir)) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index e2352c5f41..b49ab5e7f5 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -516,10 +516,12 @@ def __init__( >>> |----- test.py You can assign entry_point='src/inference.py'. - git_config (dict[str, str]): Git configurations used for cloning files, including 'repo', 'branch' - and 'commit' (default: None). - 'branch' and 'commit' are optional. If 'branch' is not specified, 'master' branch will be used. If - 'commit' is not specified, the latest commit in the required branch will be used. + git_config (dict[str, str]): Git configurations used for cloning files, including ``repo``, ``branch``, + ``commit``, ``2FA_enabled``, ``username``, ``password`` and ``token`` (default: None). The fields are + optional except ``repo``. If ``branch`` is not specified, master branch will be used. If ``commit`` + is not specified, the latest commit in the required branch will be used. 'branch' and 'commit' are + optional. If 'branch' is not specified, 'master' branch will be used. If 'commit' is not specified, + the latest commit in the required branch will be used. Example: The following config: @@ -530,6 +532,15 @@ def __init__( results in cloning the repo specified in 'repo', then checkout the 'master' branch, and checkout the specified commit. + ``2FA_enabled``, ``username``, ``password`` and ``token`` are for authentication purpose. + ``2FA_enabled`` must be ``True`` or ``False`` if it is provided. If ``2FA_enabled`` is not provided, + we consider 2FA as disabled. For GitHub and other Git repos, when ssh urls are provided, it does not + make a difference whether 2FA is enabled or disabled; an ssh passphrase should be in local storage. + When https urls are provided: if 2FA is disabled, then either token or username+password will + be used for authentication if provided (token prioritized); if 2FA is enabled, only token will + be used for authentication if provided. If required authentication info is not provided, python SDK + will try to use local credentials storage to authenticate. If that fails either, an error message will + be thrown. source_dir (str): Path (absolute or relative) to a directory with any other training source code dependencies aside from the entry point file (default: None). Structure within this directory will be preserved when training on SageMaker. If 'git_config' is provided, @@ -554,13 +565,13 @@ def __init__( Example: The following call - >>> Estimator(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env']) + >>> Estimator(entry_point='inference.py', dependencies=['my/libs/common', 'virtual-env']) results in the following inside the container: >>> $ ls >>> opt/ml/code - >>> |------ train.py + >>> |------ inference.py >>> |------ common >>> |------ virtual-env diff --git a/tests/integ/test_git.py b/tests/integ/test_git.py index a941269122..da5579f9b9 100644 --- a/tests/integ/test_git.py +++ b/tests/integ/test_git.py @@ -16,6 +16,7 @@ import numpy import pytest +import subprocess import tempfile from tests.integ import lock as lock @@ -30,6 +31,20 @@ BRANCH = "test-branch-git-config" COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" +PRIVATE_GIT_REPO = "https://github.com/git-support-test/test-git.git" +PRIVATE_BRANCH = "master" +PRIVATE_COMMIT = "a46d6f9add3532ca3e4e231e4108b6bad15b7373" + +PRIVATE_GIT_REPO_2FA = "https://github.com/git-support-test-2fa/test-git.git" +PRIVATE_GIT_REPO_2FA_SSH = "git@github.com:git-support-test-2fa/test-git.git" +PRIVATE_BRANCH_2FA = "master" +PRIVATE_COMMIT_2FA = "52381dee030eb332a7e42d9992878d7261eb21d4" + +# Since personal access tokens will delete themselves if they are committed to GitHub repos, +# we cannot hard code them here, but have to encrypt instead +ENCRYPTED_PRIVATE_REPO_TOKEN = "e-4_1-1dc_71-f0e_f7b54a0f3b7db2757163da7b5e8c3" +PRIVATE_REPO_TOKEN = ENCRYPTED_PRIVATE_REPO_TOKEN.replace("-", "").replace("_", "") + # endpoint tests all use the same port, so we use this lock to prevent concurrent execution LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_git_lock") @@ -56,7 +71,6 @@ def test_git_support_with_pytorch(sagemaker_local_session): with lock.lock(LOCK_PATH): try: predictor = pytorch.deploy(initial_instance_count=1, instance_type="local") - data = numpy.zeros(shape=(1, 1, 28, 28)).astype(numpy.float32) result = predictor.predict(data) assert result is not None @@ -66,9 +80,17 @@ def test_git_support_with_pytorch(sagemaker_local_session): @pytest.mark.local_mode def test_git_support_with_mxnet(sagemaker_local_session): + script_path = "mnist.py" data_path = os.path.join(DATA_DIR, "mxnet_mnist") - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "2FA_enabled": False, + "username": "git-support-test", + "password": "passw0rd@ %", + } source_dir = "mxnet" dependencies = ["foo/bar.py"] mx = MXNet( @@ -114,7 +136,6 @@ def test_git_support_with_mxnet(sagemaker_local_session): git_config=git_config, ) predictor = model.deploy(initial_instance_count=1, instance_type="local") - data = numpy.zeros(shape=(1, 1, 28, 28)) result = predictor.predict(data) assert result is not None @@ -128,9 +149,11 @@ def test_git_support_with_sklearn(sagemaker_local_session, sklearn_full_version) script_path = "mnist.py" data_path = os.path.join(DATA_DIR, "sklearn_mnist") git_config = { - "repo": "https://github.com/GaryTu1020/python-sdk-testing.git", - "branch": "branch1", - "commit": "aafa4e96237dd78a015d5df22bfcfef46845c3c5", + "repo": PRIVATE_GIT_REPO_2FA, + "branch": PRIVATE_BRANCH_2FA, + "commit": PRIVATE_COMMIT_2FA, + "2FA_enabled": True, + "token": PRIVATE_REPO_TOKEN, } source_dir = "sklearn" sklearn = SKLearn( @@ -171,3 +194,34 @@ def test_git_support_with_sklearn(sagemaker_local_session, sklearn_full_version) assert result is not None finally: predictor.delete_endpoint() + + +@pytest.mark.local_mode +def test_git_support_with_sklearn_ssh_passphrase_not_configured( + sagemaker_local_session, sklearn_full_version +): + script_path = "mnist.py" + data_path = os.path.join(DATA_DIR, "sklearn_mnist") + git_config = { + "repo": PRIVATE_GIT_REPO_2FA_SSH, + "branch": PRIVATE_BRANCH_2FA, + "commit": PRIVATE_COMMIT_2FA, + } + source_dir = "sklearn" + sklearn = SKLearn( + entry_point=script_path, + role="SageMakerRole", + source_dir=source_dir, + py_version=PYTHON_VERSION, + train_instance_count=1, + train_instance_type="local", + sagemaker_session=sagemaker_local_session, + framework_version=sklearn_full_version, + hyperparameters={"epochs": 1}, + git_config=git_config, + ) + train_input = "file://" + os.path.join(data_path, "train") + test_input = "file://" + os.path.join(data_path, "test") + with pytest.raises(subprocess.CalledProcessError) as error: + sklearn.fit({"train": train_input, "test": test_input}) + assert "returned non-zero exit status" in str(error) diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index e7bc6d2edf..addcdc5025 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -51,16 +51,11 @@ GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" BRANCH = "test-branch-git-config" COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" - -DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}} -INSTANCE_TYPE = "c4.4xlarge" -ACCELERATOR_TYPE = "ml.eia.medium" -ROLE = "DummyRole" -IMAGE_NAME = "fakeimage" -REGION = "us-west-2" -JOB_NAME = "{}-{}".format(IMAGE_NAME, TIMESTAMP) -TAGS = [{"Name": "some-tag", "Value": "value-for-tag"}] -OUTPUT_PATH = "s3://bucket/prefix" +PRIVATE_GIT_REPO_SSH = "git@github.com:testAccount/private-repo.git" +PRIVATE_GIT_REPO = "https://github.com/testAccount/private-repo.git" +PRIVATE_BRANCH = "test-branch" +PRIVATE_COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a" +REPO_DIR = "/tmp/repo_dir" DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}} @@ -892,9 +887,9 @@ def test_git_support_bad_repo_url_format(sagemaker_session): train_instance_type=INSTANCE_TYPE, enable_cloudwatch_metrics=True, ) - with pytest.raises(subprocess.CalledProcessError) as error: + with pytest.raises(ValueError) as error: fw.fit() - assert "returned non-zero exit status" in str(error) + assert "Invalid Git url provided." in str(error) @patch( @@ -1026,6 +1021,116 @@ def test_git_support_dependencies_not_exist(sagemaker_session): assert "Dependency", "does not exist in the repo." in str(error) +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + }, +) +def test_git_support_with_username_password_no_2fa(git_clone_repo, sagemaker_session): + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "username": "username", + "password": "passw0rd!", + } + entry_point = "entry_point" + fw = DummyFramework( + entry_point=entry_point, + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + fw.fit() + git_clone_repo.assert_called_once_with(git_config, entry_point, None, []) + assert fw.entry_point == "/tmp/repo_dir/entry_point" + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + }, +) +def test_git_support_with_token_2fa(git_clone_repo, sagemaker_session): + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "token": "my-token", + "2FA_enabled": True, + } + entry_point = "entry_point" + fw = DummyFramework( + entry_point=entry_point, + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + fw.fit() + git_clone_repo.assert_called_once_with(git_config, entry_point, None, []) + assert fw.entry_point == "/tmp/repo_dir/entry_point" + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + }, +) +def test_git_support_ssh_no_passphrase_needed(git_clone_repo, sagemaker_session): + git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + entry_point = "entry_point" + fw = DummyFramework( + entry_point=entry_point, + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + fw.fit() + git_clone_repo.assert_called_once_with(git_config, entry_point, None, []) + assert fw.entry_point == "/tmp/repo_dir/entry_point" + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git clone {} {}".format(PRIVATE_GIT_REPO_SSH, REPO_DIR) + ), +) +def test_git_support_ssh_passphrase_required(git_clone_repo, sagemaker_session): + git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + entry_point = "entry_point" + fw = DummyFramework( + entry_point=entry_point, + git_config=git_config, + role=ROLE, + sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, + train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True, + ) + with pytest.raises(subprocess.CalledProcessError) as error: + fw.fit() + assert "returned non-zero exit status" in str(error) + + @patch("time.strftime", return_value=TIMESTAMP) def test_init_with_source_dir_s3(strftime, sagemaker_session): fw = DummyFramework( diff --git a/tests/unit/test_git_utils.py b/tests/unit/test_git_utils.py index a862e76704..c97b34207c 100644 --- a/tests/unit/test_git_utils.py +++ b/tests/unit/test_git_utils.py @@ -13,15 +13,20 @@ from __future__ import absolute_import import pytest +import os import subprocess from mock import patch from sagemaker import git_utils REPO_DIR = "/tmp/repo_dir" -GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" -BRANCH = "test-branch-git-config" -COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" +PUBLIC_GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" +PUBLIC_BRANCH = "test-branch-git-config" +PUBLIC_COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" +PRIVATE_GIT_REPO_SSH = "git@github.com:testAccount/private-repo.git" +PRIVATE_GIT_REPO = "https://github.com/testAccount/private-repo.git" +PRIVATE_BRANCH = "test-branch" +PRIVATE_COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a" @patch("subprocess.check_call") @@ -30,55 +35,58 @@ @patch("os.path.isdir", return_value=True) @patch("os.path.exists", return_value=True) def test_git_clone_repo_succeed(exists, isdir, isfile, mkdtemp, check_call): - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir" dependencies = ["foo", "bar"] + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" ret = git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) - check_call.assert_any_call(["git", "clone", git_config["repo"], REPO_DIR]) - check_call.assert_any_call(args=["git", "checkout", BRANCH], cwd=REPO_DIR) - check_call.assert_any_call(args=["git", "checkout", COMMIT], cwd=REPO_DIR) + check_call.assert_any_call(["git", "clone", git_config["repo"], REPO_DIR], env=env) + check_call.assert_any_call(args=["git", "checkout", PUBLIC_BRANCH], cwd=REPO_DIR) + check_call.assert_any_call(args=["git", "checkout", PUBLIC_COMMIT], cwd=REPO_DIR) mkdtemp.assert_called_once() assert ret["entry_point"] == "entry_point" assert ret["source_dir"] == "/tmp/repo_dir/source_dir" assert ret["dependencies"] == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"] -def test_git_clone_repo_entry_point_not_provided(): - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} +def test_git_clone_repo_repo_not_provided(): + git_config = {"branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} + entry_point = "entry_point_that_does_not_exist" source_dir = "source_dir" + dependencies = ["foo", "bar"] with pytest.raises(ValueError) as error: - git_utils.git_clone_repo(git_config=git_config, entry_point=None, source_dir=source_dir) - assert "Please provide an entry point." in str(error) + git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) + assert "Please provide a repo for git_config." in str(error) -@patch("subprocess.check_call") -@patch("tempfile.mkdtemp", return_value=REPO_DIR) -@patch("os.path.isfile", return_value=True) -@patch("os.path.isdir", return_value=True) -@patch("os.path.exists", return_value=True) -def test_git_clone_repo_repo_not_provided(exists, isdir, isfile, mkdtemp, check_call): - git_config = {"branch": BRANCH, "commit": COMMIT} - entry_point = "entry_point_that_does_not_exist" +def test_git_clone_repo_git_argument_wrong_format(): + git_config = { + "repo": PUBLIC_GIT_REPO, + "branch": PUBLIC_BRANCH, + "commit": PUBLIC_COMMIT, + "token": 42, + } + entry_point = "entry_point" source_dir = "source_dir" dependencies = ["foo", "bar"] + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" with pytest.raises(ValueError) as error: git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) - assert "Please provide a repo for git_config." in str(error) + assert "'token' must be a string." in str(error) @patch( "subprocess.check_call", side_effect=subprocess.CalledProcessError( - returncode=1, cmd="git clone {} {}".format(GIT_REPO, REPO_DIR) + returncode=1, cmd="git clone {} {}".format(PUBLIC_GIT_REPO, REPO_DIR) ), ) @patch("tempfile.mkdtemp", return_value=REPO_DIR) -@patch("os.path.isfile", return_value=True) -@patch("os.path.isdir", return_value=True) -@patch("os.path.exists", return_value=True) -def test_git_clone_repo_clone_fail(exists, isdir, isfile, mkdtemp, check_call): - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} +def test_git_clone_repo_clone_fail(mkdtemp, check_call): + git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir" dependencies = ["foo", "bar"] @@ -92,11 +100,8 @@ def test_git_clone_repo_clone_fail(exists, isdir, isfile, mkdtemp, check_call): side_effect=[True, subprocess.CalledProcessError(returncode=1, cmd="git checkout banana")], ) @patch("tempfile.mkdtemp", return_value=REPO_DIR) -@patch("os.path.isfile", return_value=True) -@patch("os.path.isdir", return_value=True) -@patch("os.path.exists", return_value=True) -def test_git_clone_repo_branch_not_exist(exists, isdir, isfile, mkdtemp, check_call): - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} +def test_git_clone_repo_branch_not_exist(mkdtemp, check_call): + git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir" dependencies = ["foo", "bar"] @@ -110,15 +115,12 @@ def test_git_clone_repo_branch_not_exist(exists, isdir, isfile, mkdtemp, check_c side_effect=[ True, True, - subprocess.CalledProcessError(returncode=1, cmd="git checkout {}".format(COMMIT)), + subprocess.CalledProcessError(returncode=1, cmd="git checkout {}".format(PUBLIC_COMMIT)), ], ) @patch("tempfile.mkdtemp", return_value=REPO_DIR) -@patch("os.path.isfile", return_value=True) -@patch("os.path.isdir", return_value=True) -@patch("os.path.exists", return_value=True) -def test_git_clone_repo_commit_not_exist(exists, isdir, isfile, mkdtemp, check_call): - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} +def test_git_clone_repo_commit_not_exist(mkdtemp, check_call): + git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir" dependencies = ["foo", "bar"] @@ -132,8 +134,8 @@ def test_git_clone_repo_commit_not_exist(exists, isdir, isfile, mkdtemp, check_c @patch("os.path.isfile", return_value=False) @patch("os.path.isdir", return_value=True) @patch("os.path.exists", return_value=True) -def test_git_clone_repo_entry_point_not_exist(exists, isdir, isfile, mkdtemp, check_call): - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} +def test_git_clone_repo_entry_point_not_exist(exists, isdir, isfile, mkdtemp, heck_call): + git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point_that_does_not_exist" source_dir = "source_dir" dependencies = ["foo", "bar"] @@ -148,7 +150,7 @@ def test_git_clone_repo_entry_point_not_exist(exists, isdir, isfile, mkdtemp, ch @patch("os.path.isdir", return_value=False) @patch("os.path.exists", return_value=True) def test_git_clone_repo_source_dir_not_exist(exists, isdir, isfile, mkdtemp, check_call): - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir_that_does_not_exist" dependencies = ["foo", "bar"] @@ -163,10 +165,260 @@ def test_git_clone_repo_source_dir_not_exist(exists, isdir, isfile, mkdtemp, che @patch("os.path.isdir", return_value=True) @patch("os.path.exists", side_effect=[True, False]) def test_git_clone_repo_dependencies_not_exist(exists, isdir, isfile, mkdtemp, check_call): - git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + git_config = {"repo": PUBLIC_GIT_REPO, "branch": PUBLIC_BRANCH, "commit": PUBLIC_COMMIT} entry_point = "entry_point" source_dir = "source_dir" dependencies = ["foo", "dep_that_does_not_exist"] with pytest.raises(ValueError) as error: git_utils.git_clone_repo(git_config, entry_point, source_dir, dependencies) assert "does not exist in the repo." in str(error) + + +@patch("subprocess.check_call") +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +def test_git_clone_repo_with_username_password_no_2fa(sfile, mkdtemp, check_call): + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "username": "username", + "password": "passw0rd!", + } + entry_point = "entry_point" + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" + ret = git_utils.git_clone_repo(git_config=git_config, entry_point=entry_point) + check_call.assert_any_call( + [ + "git", + "clone", + "https://username:passw0rd%21@github.com/testAccount/private-repo.git", + REPO_DIR, + ], + env=env, + ) + check_call.assert_any_call(args=["git", "checkout", PRIVATE_BRANCH], cwd=REPO_DIR) + check_call.assert_any_call(args=["git", "checkout", PRIVATE_COMMIT], cwd=REPO_DIR) + assert ret["entry_point"] == "/tmp/repo_dir/entry_point" + assert ret["source_dir"] is None + assert ret["dependencies"] is None + + +@patch("subprocess.check_call") +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +def test_git_clone_repo_with_token_no_2fa(isfile, mkdtemp, check_call): + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "token": "08c13d80a861f37150cb5c64520bfe14a85ca191", + "2FA_enabled": False, + } + entry_point = "entry_point" + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" + ret = git_utils.git_clone_repo(git_config=git_config, entry_point=entry_point) + check_call.assert_any_call( + [ + "git", + "clone", + "https://08c13d80a861f37150cb5c64520bfe14a85ca191@github.com/testAccount/private-repo.git", + REPO_DIR, + ], + env=env, + ) + check_call.assert_any_call(args=["git", "checkout", PRIVATE_BRANCH], cwd=REPO_DIR) + check_call.assert_any_call(args=["git", "checkout", PRIVATE_COMMIT], cwd=REPO_DIR) + assert ret["entry_point"] == "/tmp/repo_dir/entry_point" + assert ret["source_dir"] is None + assert ret["dependencies"] is None + + +@patch("subprocess.check_call") +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +def test_git_clone_repo_with_token_2fa(isfile, mkdtemp, check_call): + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "2FA_enabled": True, + "username": "username", + "token": "08c13d80a861f37150cb5c64520bfe14a85ca191", + } + entry_point = "entry_point" + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" + ret = git_utils.git_clone_repo(git_config=git_config, entry_point=entry_point) + check_call.assert_any_call( + [ + "git", + "clone", + "https://08c13d80a861f37150cb5c64520bfe14a85ca191@github.com/testAccount/private-repo.git", + REPO_DIR, + ], + env=env, + ) + check_call.assert_any_call(args=["git", "checkout", PRIVATE_BRANCH], cwd=REPO_DIR) + check_call.assert_any_call(args=["git", "checkout", PRIVATE_COMMIT], cwd=REPO_DIR) + assert ret["entry_point"] == "/tmp/repo_dir/entry_point" + assert ret["source_dir"] is None + assert ret["dependencies"] is None + + +@patch("subprocess.check_call") +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +def test_git_clone_repo_ssh(isfile, mkdtemp, check_call): + git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + entry_point = "entry_point" + ret = git_utils.git_clone_repo(git_config, entry_point) + assert ret["entry_point"] == "/tmp/repo_dir/entry_point" + assert ret["source_dir"] is None + assert ret["dependencies"] is None + + +@patch("subprocess.check_call") +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +def test_git_clone_repo_with_token_no_2fa_unnecessary_creds_provided(isfile, mkdtemp, check_call): + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "username": "username", + "password": "passw0rd!", + "token": "08c13d80a861f37150cb5c64520bfe14a85ca191", + } + entry_point = "entry_point" + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" + with pytest.warns(UserWarning) as warn: + ret = git_utils.git_clone_repo(git_config=git_config, entry_point=entry_point) + assert ( + "Using token for authentication, other credentials will be ignored." + in warn[0].message.args[0] + ) + check_call.assert_any_call( + [ + "git", + "clone", + "https://08c13d80a861f37150cb5c64520bfe14a85ca191@github.com/testAccount/private-repo.git", + REPO_DIR, + ], + env=env, + ) + check_call.assert_any_call(args=["git", "checkout", PRIVATE_BRANCH], cwd=REPO_DIR) + check_call.assert_any_call(args=["git", "checkout", PRIVATE_COMMIT], cwd=REPO_DIR) + assert ret["entry_point"] == "/tmp/repo_dir/entry_point" + assert ret["source_dir"] is None + assert ret["dependencies"] is None + + +@patch("subprocess.check_call") +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +@patch("os.path.isfile", return_value=True) +def test_git_clone_repo_with_token_2fa_unnecessary_creds_provided(isfile, mkdtemp, check_call): + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "2FA_enabled": True, + "username": "username", + "token": "08c13d80a861f37150cb5c64520bfe14a85ca191", + } + entry_point = "entry_point" + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" + with pytest.warns(UserWarning) as warn: + ret = git_utils.git_clone_repo(git_config=git_config, entry_point=entry_point) + assert ( + "Using token for authentication, other credentials will be ignored." + in warn[0].message.args[0] + ) + check_call.assert_any_call( + [ + "git", + "clone", + "https://08c13d80a861f37150cb5c64520bfe14a85ca191@github.com/testAccount/private-repo.git", + REPO_DIR, + ], + env=env, + ) + check_call.assert_any_call(args=["git", "checkout", PRIVATE_BRANCH], cwd=REPO_DIR) + check_call.assert_any_call(args=["git", "checkout", PRIVATE_COMMIT], cwd=REPO_DIR) + assert ret["entry_point"] == "/tmp/repo_dir/entry_point" + assert ret["source_dir"] is None + assert ret["dependencies"] is None + + +@patch( + "subprocess.check_call", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git clone {} {}".format(PRIVATE_GIT_REPO, REPO_DIR) + ), +) +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +def test_git_clone_repo_with_username_and_password_wrong_creds(mkdtemp, check_call): + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "2FA_enabled": False, + "username": "username", + "password": "wrong-password", + } + entry_point = "entry_point" + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" + with pytest.raises(subprocess.CalledProcessError) as error: + git_utils.git_clone_repo(git_config=git_config, entry_point=entry_point) + assert "returned non-zero exit status" in str(error) + + +@patch( + "subprocess.check_call", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git clone {} {}".format(PRIVATE_GIT_REPO, REPO_DIR) + ), +) +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +def test_git_clone_repo_with_token_wrong_creds(mkdtemp, check_call): + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "2FA_enabled": False, + "token": "wrong-token", + } + entry_point = "entry_point" + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" + with pytest.raises(subprocess.CalledProcessError) as error: + git_utils.git_clone_repo(git_config=git_config, entry_point=entry_point) + assert "returned non-zero exit status" in str(error) + + +@patch( + "subprocess.check_call", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git clone {} {}".format(PRIVATE_GIT_REPO, REPO_DIR) + ), +) +@patch("tempfile.mkdtemp", return_value=REPO_DIR) +def test_git_clone_repo_with_and_token_2fa_wrong_creds(mkdtemp, check_call): + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "2FA_enabled": False, + "token": "wrong-token", + } + entry_point = "entry_point" + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" + with pytest.raises(subprocess.CalledProcessError) as error: + git_utils.git_clone_repo(git_config=git_config, entry_point=entry_point) + assert "returned non-zero exit status" in str(error) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index fa6253f925..0090137a91 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -43,6 +43,11 @@ GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" BRANCH = "test-branch-git-config" COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" +PRIVATE_GIT_REPO_SSH = "git@github.com:testAccount/private-repo.git" +PRIVATE_GIT_REPO = "https://github.com/testAccount/private-repo.git" +PRIVATE_BRANCH = "test-branch" +PRIVATE_COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a" +REPO_DIR = "/tmp/repo_dir" DESCRIBE_MODEL_PACKAGE_RESPONSE = { @@ -666,3 +671,97 @@ def test_git_support_dependencies_not_exist(sagemaker_session): ) model.prepare_container_def(instance_type=INSTANCE_TYPE) assert "Dependency", "does not exist in the repo." in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + }, +) +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_with_username_password_no_2fa( + tar_and_upload_dir, git_clone_repo, sagemaker_session +): + entry_point = "entry_point" + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "username": "username", + "password": "passw0rd!", + } + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, None, []) + assert model.entry_point == "/tmp/repo_dir/entry_point" + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + }, +) +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_with_token_2fa(tar_and_upload_dir, git_clone_repo, sagemaker_session): + entry_point = "entry_point" + git_config = { + "repo": PRIVATE_GIT_REPO, + "branch": PRIVATE_BRANCH, + "commit": PRIVATE_COMMIT, + "token": "my-token", + "2FA_enabled": True, + } + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, None, []) + assert model.entry_point == "/tmp/repo_dir/entry_point" + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=lambda gitconfig, entrypoint, source_dir=None, dependencies=None: { + "entry_point": "/tmp/repo_dir/entry_point", + "source_dir": None, + "dependencies": None, + }, +) +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_ssh_no_passphrase_needed( + tar_and_upload_dir, git_clone_repo, sagemaker_session +): + entry_point = "entry_point" + git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, None, []) + assert model.entry_point == "/tmp/repo_dir/entry_point" + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git clone {} {}".format(PRIVATE_GIT_REPO_SSH, REPO_DIR) + ), +) +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_ssh_passphrase_required(tar_and_upload_dir, git_clone_repo, sagemaker_session): + entry_point = "entry_point" + git_config = {"repo": PRIVATE_GIT_REPO_SSH, "branch": PRIVATE_BRANCH, "commit": PRIVATE_COMMIT} + with pytest.raises(subprocess.CalledProcessError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "returned non-zero exit status" in str(error)