|
26 | 26 | from sagemaker.sklearn.model import SKLearnModel
|
27 | 27 | from sagemaker.tensorflow.model import TensorFlowModel
|
28 | 28 | from sagemaker.xgboost.model import XGBoostModel
|
| 29 | +from sagemaker.workflow.properties import Properties |
| 30 | + |
29 | 31 |
|
30 | 32 | MODEL_DATA = "s3://bucket/model.tar.gz"
|
31 | 33 | MODEL_IMAGE = "mi"
|
|
42 | 44 | BRANCH = "test-branch-git-config"
|
43 | 45 | COMMIT = "ae15c9d7d5b97ea95ea451e4662ee43da3401d73"
|
44 | 46 | ENTRY_POINT_INFERENCE = "inference.py"
|
45 |
| - |
46 | 47 | SCRIPT_URI = "s3://codebucket/someprefix/sourcedir.tar.gz"
|
47 | 48 | IMAGE_URI = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.9.0-gpu-py38"
|
48 | 49 |
|
@@ -71,6 +72,23 @@ def sagemaker_session():
|
71 | 72 | return sms
|
72 | 73 |
|
73 | 74 |
|
| 75 | +@patch("shutil.rmtree", MagicMock()) |
| 76 | +@patch("tarfile.open", MagicMock()) |
| 77 | +@patch("os.listdir", MagicMock(return_value=[ENTRY_POINT_INFERENCE])) |
| 78 | +def test_prepare_container_def_with_model_src_s3_returns_correct_url(sagemaker_session): |
| 79 | + model = Model( |
| 80 | + entry_point=ENTRY_POINT_INFERENCE, |
| 81 | + role=ROLE, |
| 82 | + sagemaker_session=sagemaker_session, |
| 83 | + source_dir=SCRIPT_URI, |
| 84 | + image_uri=MODEL_IMAGE, |
| 85 | + model_data=Properties("Steps.MyStep"), |
| 86 | + ) |
| 87 | + container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium") |
| 88 | + |
| 89 | + assert container_def["Environment"]["SAGEMAKER_SUBMIT_DIRECTORY"] == SCRIPT_URI |
| 90 | + |
| 91 | + |
74 | 92 | def test_prepare_container_def_with_model_data():
|
75 | 93 | model = Model(MODEL_IMAGE)
|
76 | 94 | container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium")
|
|
0 commit comments