|
28 | 28 | SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
|
29 | 29 | IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
|
30 | 30 | MODEL_DATA = "s3://someprefix2/models/model.tar.gz"
|
| 31 | +GIT_REPO = "https://github.com/aws/sagemaker-python-sdk.git" |
| 32 | +BRANCH = "test-branch-git-config" |
| 33 | +COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73" |
31 | 34 |
|
32 | 35 |
|
33 | 36 | class DummyFrameworkModel(FrameworkModel):
|
@@ -97,3 +100,30 @@ def test_script_mode_model_same_calls_as_framework(repack_model, sagemaker_sessi
|
97 | 100 | == sagemaker_session.endpoint_from_production_variants.call_args_list
|
98 | 101 | )
|
99 | 102 | assert generic_model_repack_model_args == repack_model.call_args_list
|
| 103 | + |
| 104 | + |
| 105 | +@patch("sagemaker.git_utils.git_clone_repo") |
| 106 | +@patch("sagemaker.model.fw_utils.tar_and_upload_dir") |
| 107 | +def test_git_support_succeed_model_class(tar_and_upload_dir, git_clone_repo, sagemaker_session): |
| 108 | + git_clone_repo.side_effect = lambda gitconfig, entrypoint, sourcedir, dependency: { |
| 109 | + "entry_point": "entry_point", |
| 110 | + "source_dir": "/tmp/repo_dir/source_dir", |
| 111 | + "dependencies": ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"], |
| 112 | + } |
| 113 | + entry_point = "entry_point" |
| 114 | + source_dir = "source_dir" |
| 115 | + dependencies = ["foo", "bar"] |
| 116 | + git_config = {"repo": GIT_REPO, "branch": BRANCH, "commit": COMMIT} |
| 117 | + model = Model( |
| 118 | + sagemaker_session=sagemaker_session, |
| 119 | + entry_point=entry_point, |
| 120 | + source_dir=source_dir, |
| 121 | + dependencies=dependencies, |
| 122 | + git_config=git_config, |
| 123 | + image_uri=IMAGE_URI, |
| 124 | + ) |
| 125 | + model.prepare_container_def(instance_type=INSTANCE_TYPE) |
| 126 | + git_clone_repo.assert_called_with(git_config, entry_point, source_dir, dependencies) |
| 127 | + assert model.entry_point == "entry_point" |
| 128 | + assert model.source_dir == "/tmp/repo_dir/source_dir" |
| 129 | + assert model.dependencies == ["/tmp/repo_dir/foo", "/tmp/repo_dir/bar"] |
0 commit comments