diff --git a/doc/overview.rst b/doc/overview.rst index c9d7ce3c63..d5bef646f5 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -799,6 +799,23 @@ After that, invoke the ``deploy()`` method on the ``Model``: This returns a predictor the same way an ``Estimator`` does when ``deploy()`` is called. You can now get inferences just like with any other model deployed on Amazon SageMaker. +Git support is also available when you bring your own model, through which you can use inference scripts stored in your +Git repositories. The process is similar to using Git support for training jobs. You can simply provide ``git_config`` +when create the ``Model`` object, and let ``entry_point``, ``source_dir`` and ``dependencies`` (if needed) be relative +paths inside the Git repository: + +.. code:: python + + git_config = {'repo': 'https://github.com/username/repo-with-training-scripts.git', + 'branch': 'branch1', + 'commit': '4893e528afa4a790331e1b5286954f073b0f14a2'} + + sagemaker_model = MXNetModel(model_data='s3://path/to/model.tar.gz', + role='arn:aws:iam::accid:sagemaker-role', + entry_point='inference.py', + source_dir='mxnet', + git_config=git_config) + A full example is available in the `Amazon SageMaker examples repository `__. You can also find this notebook in the **Advanced Functionality** section of the **SageMaker Examples** section in a notebook instance. diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index caa1a05c5b..e558822dd7 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -955,8 +955,8 @@ def __init__( code_location=None, image_name=None, dependencies=None, - git_config=None, enable_network_isolation=False, + git_config=None, **kwargs ): """Base class initializer. Subclasses which override ``__init__`` should invoke ``super()`` @@ -993,7 +993,7 @@ def __init__( source_dir (str): Path (absolute or relative) to a directory with any other training source code dependencies aside from the entry point file (default: None). Structure within this directory are preserved when training on Amazon SageMaker. If 'git_config' is provided, - source_dir should be a relative location to a directory in the Git repo. + 'source_dir' should be a relative location to a directory in the Git repo. Example: With the following GitHub repo directory structure: @@ -1023,6 +1023,8 @@ def __init__( dependencies (list[str]): A list of paths to directories (absolute or relative) with any additional libraries that will be exported to the container (default: []). The library folders will be copied to SageMaker in the same folder where the entrypoint is copied. + If 'git_config' is provided, 'dependencies' should be a list of relative locations to directories + with any additional libraries needed in the Git repo. Example: The following call @@ -1085,12 +1087,12 @@ def _prepare_for_training(self, job_name=None): super(Framework, self)._prepare_for_training(job_name=job_name) if self.git_config: - updates = git_utils.git_clone_repo( + updated_paths = git_utils.git_clone_repo( self.git_config, self.entry_point, self.source_dir, self.dependencies ) - self.entry_point = updates["entry_point"] - self.source_dir = updates["source_dir"] - self.dependencies = updates["dependencies"] + self.entry_point = updated_paths["entry_point"] + self.source_dir = updated_paths["source_dir"] + self.dependencies = updated_paths["dependencies"] # validate source dir will raise a ValueError if there is something wrong with the # source directory. We are intentionally not handling it because this is a critical error. diff --git a/src/sagemaker/git_utils.py b/src/sagemaker/git_utils.py index 44c91d61ff..7360c31e55 100644 --- a/src/sagemaker/git_utils.py +++ b/src/sagemaker/git_utils.py @@ -13,6 +13,7 @@ from __future__ import absolute_import import os +import six import subprocess import tempfile @@ -39,35 +40,46 @@ def git_clone_repo(git_config, entry_point, source_dir=None, dependencies=None): 3. failed to checkout the required commit ValueError: If 1. entry point specified does not exist in the repo 2. source dir specified does not exist in the repo + 3. dependencies specified do not exist in the repo + 4. git_config is in bad format Returns: dict: A dict that contains the updated values of entry_point, source_dir and dependencies """ + if entry_point is None: + raise ValueError("Please provide an entry point.") _validate_git_config(git_config) repo_dir = tempfile.mkdtemp() subprocess.check_call(["git", "clone", git_config["repo"], repo_dir]) _checkout_branch_and_commit(git_config, repo_dir) - ret = {"entry_point": entry_point, "source_dir": source_dir, "dependencies": dependencies} + updated_paths = { + "entry_point": entry_point, + "source_dir": source_dir, + "dependencies": dependencies, + } + # check if the cloned repo contains entry point, source directory and dependencies if source_dir: if not os.path.isdir(os.path.join(repo_dir, source_dir)): raise ValueError("Source directory does not exist in the repo.") if not os.path.isfile(os.path.join(repo_dir, source_dir, entry_point)): raise ValueError("Entry point does not exist in the repo.") - ret["source_dir"] = os.path.join(repo_dir, source_dir) + updated_paths["source_dir"] = os.path.join(repo_dir, source_dir) else: - if not os.path.isfile(os.path.join(repo_dir, entry_point)): + if os.path.isfile(os.path.join(repo_dir, entry_point)): + updated_paths["entry_point"] = os.path.join(repo_dir, entry_point) + else: raise ValueError("Entry point does not exist in the repo.") - ret["entry_point"] = os.path.join(repo_dir, entry_point) - ret["dependencies"] = [] + updated_paths["dependencies"] = [] for path in dependencies: - if not os.path.exists(os.path.join(repo_dir, path)): + if os.path.exists(os.path.join(repo_dir, path)): + updated_paths["dependencies"].append(os.path.join(repo_dir, path)) + else: raise ValueError("Dependency {} does not exist in the repo.".format(path)) - ret["dependencies"].append(os.path.join(repo_dir, path)) - return ret + return updated_paths def _validate_git_config(git_config): @@ -84,6 +96,13 @@ def _validate_git_config(git_config): """ if "repo" not in git_config: raise ValueError("Please provide a repo for git_config.") + allowed_keys = ["repo", "branch", "commit"] + for key in allowed_keys: + if key in git_config and not isinstance(git_config[key], six.string_types): + raise ValueError("'{}' should be a string".format(key)) + for key in git_config: + if key not in allowed_keys: + raise ValueError("Unexpected argument(s) provided for git_config!") def _checkout_branch_and_commit(git_config, repo_dir): @@ -95,8 +114,8 @@ def _checkout_branch_and_commit(git_config, repo_dir): repo_dir (str): the directory where the repo is cloned Raises: - ValueError: If 1. entry point specified does not exist in the repo - 2. source dir specified does not exist in the repo + CalledProcessError: If 1. failed to checkout the required branch + 2. failed to checkout the required commit """ if "branch" in git_config: subprocess.check_call(args=["git", "checkout", git_config["branch"]], cwd=str(repo_dir)) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 08c03b12fb..27738f1f88 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -17,7 +17,7 @@ import os import sagemaker -from sagemaker import fw_utils, local, session, utils +from sagemaker import fw_utils, local, session, utils, git_utils from sagemaker.fw_utils import UploadedCode from sagemaker.transformer import Transformer @@ -494,6 +494,7 @@ def __init__( code_location=None, sagemaker_session=None, dependencies=None, + git_config=None, **kwargs ): """Initialize a ``FrameworkModel``. @@ -504,15 +505,54 @@ def __init__( role (str): An IAM role name or ARN for SageMaker to access AWS resources on your behalf. entry_point (str): Path (absolute or relative) to the Python source file which should be executed as the entry point to model hosting. This should be compatible with either Python 2.7 or Python 3.5. + If 'git_config' is provided, 'entry_point' should be a relative location to the Python source file in + the Git repo. + Example: + + With the following GitHub repo directory structure: + + >>> |----- README.md + >>> |----- src + >>> |----- inference.py + >>> |----- test.py + + You can assign entry_point='src/inference.py'. + git_config (dict[str, str]): Git configurations used for cloning files, including 'repo', 'branch' + and 'commit' (default: None). + 'branch' and 'commit' are optional. If 'branch' is not specified, 'master' branch will be used. If + 'commit' is not specified, the latest commit in the required branch will be used. + Example: + + The following config: + + >>> git_config = {'repo': 'https://github.com/aws/sagemaker-python-sdk.git', + >>> 'branch': 'test-branch-git-config', + >>> 'commit': '329bfcf884482002c05ff7f44f62599ebc9f445a'} + + results in cloning the repo specified in 'repo', then checkout the 'master' branch, and checkout + the specified commit. source_dir (str): Path (absolute or relative) to a directory with any other training source code dependencies aside from the entry point file (default: None). Structure within this - directory will be preserved when training on SageMaker. - If the directory points to S3, no code will be uploaded and the S3 location will be used instead. + directory will be preserved when training on SageMaker. If 'git_config' is provided, + 'source_dir' should be a relative location to a directory in the Git repo. If the directory points + to S3, no code will be uploaded and the S3 location will be used instead. + Example: + + With the following GitHub repo directory structure: + + >>> |----- README.md + >>> |----- src + >>> |----- inference.py + >>> |----- test.py + + You can assign entry_point='inference.py', source_dir='src'. dependencies (list[str]): A list of paths to directories (absolute or relative) with any additional libraries that will be exported to the container (default: []). The library folders will be copied to SageMaker in the same folder where the entrypoint is copied. - If the ```source_dir``` points to S3, code will be uploaded and the S3 location will be used - instead. Example: + If 'git_config' is provided, 'dependencies' should be a list of relative locations to directories + with any additional libraries needed in the Git repo. If the ```source_dir``` points to S3, code + will be uploaded and the S3 location will be used instead. + Example: The following call >>> Estimator(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env']) @@ -554,12 +594,20 @@ def __init__( self.entry_point = entry_point self.source_dir = source_dir self.dependencies = dependencies or [] + self.git_config = git_config self.enable_cloudwatch_metrics = enable_cloudwatch_metrics self.container_log_level = container_log_level if code_location: self.bucket, self.key_prefix = fw_utils.parse_s3_url(code_location) else: self.bucket, self.key_prefix = None, None + if self.git_config: + updates = git_utils.git_clone_repo( + self.git_config, self.entry_point, self.source_dir, self.dependencies + ) + self.entry_point = updates["entry_point"] + self.source_dir = updates["source_dir"] + self.dependencies = updates["dependencies"] self.uploaded_code = None self.repacked_model_data = None diff --git a/tests/integ/test_git.py b/tests/integ/test_git.py index 7d5c1b76ae..a941269122 100644 --- a/tests/integ/test_git.py +++ b/tests/integ/test_git.py @@ -21,11 +21,14 @@ from tests.integ import lock as lock from sagemaker.mxnet.estimator import MXNet from sagemaker.pytorch.estimator import PyTorch +from sagemaker.sklearn.estimator import SKLearn +from sagemaker.mxnet.model import MXNetModel +from sagemaker.sklearn.model import SKLearnModel from tests.integ import DATA_DIR, PYTHON_VERSION GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" BRANCH = "test-branch-git-config" -COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a" +COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" # endpoint tests all use the same port, so we use this lock to prevent concurrent execution LOCK_PATH = os.path.join(tempfile.gettempdir(), "sagemaker_test_git_lock") @@ -62,15 +65,16 @@ def test_git_support_with_pytorch(sagemaker_local_session): @pytest.mark.local_mode -def test_git_support_with_mxnet(sagemaker_local_session, mxnet_full_version): +def test_git_support_with_mxnet(sagemaker_local_session): script_path = "mnist.py" data_path = os.path.join(DATA_DIR, "mxnet_mnist") git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + source_dir = "mxnet" dependencies = ["foo/bar.py"] mx = MXNet( entry_point=script_path, role="SageMakerRole", - source_dir="mxnet", + source_dir=source_dir, dependencies=dependencies, framework_version=MXNet.LATEST_VERSION, py_version=PYTHON_VERSION, @@ -94,10 +98,76 @@ def test_git_support_with_mxnet(sagemaker_local_session, mxnet_full_version): with lock.lock(LOCK_PATH): try: - predictor = mx.deploy(initial_instance_count=1, instance_type="local") + serving_script_path = "mnist_hosting_with_custom_handlers.py" + client = sagemaker_local_session.sagemaker_client + desc = client.describe_training_job(TrainingJobName=mx.latest_training_job.name) + model_data = desc["ModelArtifacts"]["S3ModelArtifacts"] + model = MXNetModel( + model_data, + "SageMakerRole", + entry_point=serving_script_path, + source_dir=source_dir, + dependencies=dependencies, + py_version=PYTHON_VERSION, + sagemaker_session=sagemaker_local_session, + framework_version=MXNet.LATEST_VERSION, + git_config=git_config, + ) + predictor = model.deploy(initial_instance_count=1, instance_type="local") data = numpy.zeros(shape=(1, 1, 28, 28)) result = predictor.predict(data) assert result is not None finally: predictor.delete_endpoint() + + +@pytest.mark.skipif(PYTHON_VERSION != "py3", reason="Scikit-learn image supports only python 3.") +@pytest.mark.local_mode +def test_git_support_with_sklearn(sagemaker_local_session, sklearn_full_version): + script_path = "mnist.py" + data_path = os.path.join(DATA_DIR, "sklearn_mnist") + git_config = { + "repo": "https://github.com/GaryTu1020/python-sdk-testing.git", + "branch": "branch1", + "commit": "aafa4e96237dd78a015d5df22bfcfef46845c3c5", + } + source_dir = "sklearn" + sklearn = SKLearn( + entry_point=script_path, + role="SageMakerRole", + source_dir=source_dir, + py_version=PYTHON_VERSION, + train_instance_count=1, + train_instance_type="local", + sagemaker_session=sagemaker_local_session, + framework_version=sklearn_full_version, + hyperparameters={"epochs": 1}, + git_config=git_config, + ) + train_input = "file://" + os.path.join(data_path, "train") + test_input = "file://" + os.path.join(data_path, "test") + sklearn.fit({"train": train_input, "test": test_input}) + + assert os.path.isdir(sklearn.source_dir) + + with lock.lock(LOCK_PATH): + try: + client = sagemaker_local_session.sagemaker_client + desc = client.describe_training_job(TrainingJobName=sklearn.latest_training_job.name) + model_data = desc["ModelArtifacts"]["S3ModelArtifacts"] + model = SKLearnModel( + model_data, + "SageMakerRole", + entry_point=script_path, + source_dir=source_dir, + sagemaker_session=sagemaker_local_session, + git_config=git_config, + ) + predictor = model.deploy(1, "local") + + data = numpy.zeros((100, 784), dtype="float32") + result = predictor.predict(data) + assert result is not None + finally: + predictor.delete_endpoint() diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index d97bebe3c0..e7bc6d2edf 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -50,7 +50,7 @@ OUTPUT_PATH = "s3://bucket/prefix" GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" BRANCH = "test-branch-git-config" -COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a" +COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}} INSTANCE_TYPE = "c4.4xlarge" @@ -898,12 +898,12 @@ def test_git_support_bad_repo_url_format(sagemaker_session): @patch( - "subprocess.check_call", + "sagemaker.git_utils.git_clone_repo", side_effect=subprocess.CalledProcessError( - returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git" + returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir" ), ) -def test_git_support_git_clone_fail(check_call, sagemaker_session): +def test_git_support_git_clone_fail(sagemaker_session): git_config = {"repo": "https://github.com/aws/no-such-repo.git", "branch": BRANCH} fw = DummyFramework( entry_point="entry_point", diff --git a/tests/unit/test_git_utils.py b/tests/unit/test_git_utils.py index 02fb2f43e1..a862e76704 100644 --- a/tests/unit/test_git_utils.py +++ b/tests/unit/test_git_utils.py @@ -21,7 +21,7 @@ REPO_DIR = "/tmp/repo_dir" GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" BRANCH = "test-branch-git-config" -COMMIT = "329bfcf884482002c05ff7f44f62599ebc9f445a" +COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" @patch("subprocess.check_call") @@ -44,6 +44,14 @@ def test_git_clone_repo_succeed(exists, isdir, isfile, mkdtemp, check_call): assert ret["dependencies"] == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"] +def test_git_clone_repo_entry_point_not_provided(): + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + source_dir = "source_dir" + with pytest.raises(ValueError) as error: + git_utils.git_clone_repo(git_config=git_config, entry_point=None, source_dir=source_dir) + assert "Please provide an entry point." in str(error) + + @patch("subprocess.check_call") @patch("tempfile.mkdtemp", return_value=REPO_DIR) @patch("os.path.isfile", return_value=True) diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 7838266516..fa6253f925 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -14,6 +14,7 @@ import copy import os +import subprocess import sagemaker from sagemaker.model import FrameworkModel, ModelPackage @@ -39,6 +40,9 @@ IMAGE_NAME = "fakeimage" REGION = "us-west-2" MODEL_NAME = "{}-{}".format(MODEL_IMAGE, TIMESTAMP) +GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" +BRANCH = "test-branch-git-config" +COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" DESCRIBE_MODEL_PACKAGE_RESPONSE = { @@ -94,6 +98,21 @@ def create_predictor(self, endpoint_name): return RealTimePredictor(endpoint_name, sagemaker_session=self.sagemaker_session) +class DummyFrameworkModelForGit(FrameworkModel): + def __init__(self, sagemaker_session, entry_point, **kwargs): + super(DummyFrameworkModelForGit, self).__init__( + MODEL_DATA, + MODEL_IMAGE, + ROLE, + entry_point=entry_point, + sagemaker_session=sagemaker_session, + **kwargs + ) + + def create_predictor(self, endpoint_name): + return RealTimePredictor(endpoint_name, sagemaker_session=self.sagemaker_session) + + @pytest.fixture() def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) @@ -506,3 +525,144 @@ def test_check_neo_region(sagemaker_session, tmpdir): assert model.check_neo_region(region_name) is True else: assert model.check_neo_region(region_name) is False + + +@patch("sagemaker.git_utils.git_clone_repo") +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") +def test_git_support_succeed(tar_and_upload_dir, git_clone_repo, sagemaker_session): + git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: { + "entry_point": "entry_point", + "source_dir": "/tmp/repo_dir/source_dir", + "dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"], + } + entry_point = "entry_point" + source_dir = "source_dir" + dependencies = ["foo", "bar"] + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, + entry_point=entry_point, + source_dir=source_dir, + dependencies=dependencies, + git_config=git_config, + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + git_clone_repo.assert_called_with(git_config, entry_point, source_dir, dependencies) + assert model.entry_point == "entry_point" + assert model.source_dir == "/tmp/repo_dir/source_dir" + assert model.dependencies == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"] + + +def test_git_support_repo_not_provided(sagemaker_session): + entry_point = "source_dir/entry_point" + git_config = {"branch": BRANCH, "commit": COMMIT} + with pytest.raises(ValueError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "Please provide a repo for git_config." in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git clone https://github.com/aws/no-such-repo.git /tmp/repo_dir" + ), +) +def test_git_support_git_clone_fail(sagemaker_session): + entry_point = "source_dir/entry_point" + git_config = {"repo": "https://github.com/aws/no-such-repo.git", "branch": BRANCH} + with pytest.raises(subprocess.CalledProcessError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "returned non-zero exit status" in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git checkout branch-that-does-not-exist" + ), +) +def test_git_support_branch_not_exist(git_clone_repo, sagemaker_session): + entry_point = "source_dir/entry_point" + git_config = {"repo": GIT_REPO, "branch": "branch-that-does-not-exist", "commit": COMMIT} + with pytest.raises(subprocess.CalledProcessError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "returned non-zero exit status" in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=subprocess.CalledProcessError( + returncode=1, cmd="git checkout commit-sha-that-does-not-exist" + ), +) +def test_git_support_commit_not_exist(git_clone_repo, sagemaker_session): + entry_point = "source_dir/entry_point" + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": "commit-sha-that-does-not-exist"} + with pytest.raises(subprocess.CalledProcessError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "returned non-zero exit status" in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=ValueError("Entry point does not exist in the repo."), +) +def test_git_support_entry_point_not_exist(sagemaker_session): + entry_point = "source_dir/entry_point" + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + with pytest.raises(ValueError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, entry_point=entry_point, git_config=git_config + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "Entry point does not exist in the repo." in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=ValueError("Source directory does not exist in the repo."), +) +def test_git_support_source_dir_not_exist(sagemaker_session): + entry_point = "entry_point" + source_dir = "source_dir_that_does_not_exist" + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + with pytest.raises(ValueError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, + entry_point=entry_point, + source_dir=source_dir, + git_config=git_config, + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "Source directory does not exist in the repo." in str(error) + + +@patch( + "sagemaker.git_utils.git_clone_repo", + side_effect=ValueError("Dependency no-such-dir does not exist in the repo."), +) +def test_git_support_dependencies_not_exist(sagemaker_session): + entry_point = "entry_point" + dependencies = ["foo", "no_such_dir"] + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} + with pytest.raises(ValueError) as error: + model = DummyFrameworkModelForGit( + sagemaker_session=sagemaker_session, + entry_point=entry_point, + dependencies=dependencies, + git_config=git_config, + ) + model.prepare_container_def(instance_type=INSTANCE_TYPE) + assert "Dependency", "does not exist in the repo." in str(error)