Skip to content

Commit c814b7f

Browse files
ZhankuilNamrata Madan
authored and
Namrata Madan
committed
pathway: Update remote function test setup
* Update to remote function test image setup. * Add basic test scenarios. Happy cases are enabled and can run successfully
1 parent 07a68e1 commit c814b7f

File tree

7 files changed

+175
-188
lines changed

7 files changed

+175
-188
lines changed

tests/data/remote_function/containers/without_error/Dockerfile

-22
This file was deleted.

tests/data/remote_function/containers/without_error/serve

-29
This file was deleted.

tests/data/remote_function/containers/without_error/train

-20
This file was deleted.

tests/integ/sagemaker/remote_function/conftest.py

+48-23
Original file line numberDiff line numberDiff line change
@@ -19,55 +19,78 @@
1919
import pytest
2020
import docker
2121

22-
from tests.integ import DATA_DIR
2322
from sagemaker import utils
24-
from sagemaker.utils import sagemaker_timestamp
23+
from sagemaker.utils import sagemaker_timestamp, _tmpdir
2524

2625
REPO_ACCOUNT_ID = "033110030271"
2726

27+
REPO_NAME = "remote-function-dummy-container"
2828

29-
@pytest.fixture(scope="module")
29+
DOCKERFILE_TEMPLATE = (
30+
"FROM public.ecr.aws/docker/library/python:{py_version}-slim\n\n"
31+
"WORKDIR /opt/ml/remote_function/\n"
32+
"COPY {source_archive} ./\n"
33+
"RUN pip3 install '{source_archive}[remote_function]'\n"
34+
"RUN rm {source_archive}\n"
35+
)
36+
37+
38+
@pytest.fixture(scope="package")
3039
def dummy_container_without_error(sagemaker_session):
31-
repository_name = "remote-function-dummy-container"
32-
docker_file_path = os.path.join(DATA_DIR, "remote_function", "containers", "without_error")
33-
ecr_uri = _build_container(sagemaker_session, repository_name, docker_file_path)
40+
# TODO: the python version should be dynamically specified instead of hardcoding
41+
ecr_uri = _build_container(sagemaker_session, "3.10")
3442
return ecr_uri
3543

3644

37-
def _build_container(sagemaker_session, repository_name, docker_file_path):
38-
"""Build a dummy test container locally and push a container to an ecr repo"""
45+
@pytest.fixture(scope="package")
46+
def dummy_container_incompatible_python_runtime(sagemaker_session):
47+
ecr_uri = _build_container(sagemaker_session, "3.7")
48+
return ecr_uri
49+
3950

40-
_generate_and_move_sagemaker_sdk_tar(docker_file_path)
51+
def _build_container(sagemaker_session, py_version):
52+
"""Build a dummy test container locally and push a container to an ecr repo"""
4153

42-
image_tag = sagemaker_timestamp()
54+
region = sagemaker_session.boto_region_name
55+
image_tag = f"{py_version.replace('.', '-')}-{sagemaker_timestamp()}"
4356
ecr_client = sagemaker_session.boto_session.client("ecr")
4457
username, password = _ecr_login(ecr_client)
4558

46-
docker_client = docker.from_env()
47-
# build docker locally
48-
image, _ = docker_client.images.build(path=docker_file_path, tag=repository_name, rm=True)
59+
with _tmpdir() as tmpdir:
60+
print("building docker image locally in ", tmpdir)
61+
print("building source archive...")
62+
source_archive = _generate_and_move_sagemaker_sdk_tar(tmpdir)
63+
with open(os.path.join(tmpdir, "Dockerfile"), "w") as file:
64+
file.writelines(
65+
DOCKERFILE_TEMPLATE.format(py_version=py_version, source_archive=source_archive)
66+
)
4967

50-
region = sagemaker_session.boto_region_name
51-
ecr_image = ""
52-
if _is_repository_exists(ecr_client, repository_name):
68+
docker_client = docker.from_env()
69+
70+
print("building docker image...")
71+
image, build_logs = docker_client.images.build(path=tmpdir, tag=REPO_NAME, rm=True)
72+
73+
if _is_repository_exists(ecr_client, REPO_NAME):
5374
sts_client = sagemaker_session.boto_session.client(
5475
"sts", region_name=region, endpoint_url=utils.sts_regional_endpoint(region)
5576
)
5677
account_id = sts_client.get_caller_identity()["Account"]
57-
# When the test is run locally, repo will exists in same account whose credentials are used to run the test
78+
# When the test is run locally, repo will exist in same account whose credentials are used to run the test
5879
ecr_image = _ecr_image_uri(
59-
account_id, sagemaker_session.boto_region_name, repository_name, image_tag
80+
account_id, sagemaker_session.boto_region_name, REPO_NAME, image_tag
6081
)
6182
else:
6283
ecr_image = _ecr_image_uri(
6384
REPO_ACCOUNT_ID,
6485
sagemaker_session.boto_region_name,
65-
repository_name,
86+
REPO_NAME,
6687
image_tag,
6788
)
6889

90+
print("pushing image...")
6991
image.tag(ecr_image, tag=image_tag)
7092
docker_client.images.push(ecr_image, auth_config={"username": username, "password": password})
93+
7194
return ecr_image
7295

7396

@@ -99,7 +122,9 @@ def _generate_and_move_sagemaker_sdk_tar(destination_folder):
99122
"""
100123
subprocess.run("python3 setup.py sdist", shell=True)
101124
dist_dir = "dist"
102-
for item in os.listdir(dist_dir):
103-
source = os.path.join(dist_dir, item)
104-
destination = os.path.join(destination_folder, item)
105-
shutil.copy2(source, destination)
125+
source_archive = os.listdir(dist_dir)[0]
126+
source_path = os.path.join(dist_dir, source_archive)
127+
destination_path = os.path.join(destination_folder, source_archive)
128+
shutil.copy2(source_path, destination_path)
129+
130+
return source_archive

tests/integ/sagemaker/remote_function/test_decorator.py

+116-18
Original file line numberDiff line numberDiff line change
@@ -15,38 +15,136 @@
1515

1616
import pytest
1717

18-
import sagemaker.exceptions
19-
from sagemaker import image_uris
2018
from sagemaker.remote_function import remote
2119

20+
from tests.integ.kms_utils import get_or_create_kms_key
21+
2222
ROLE = "SageMakerRole"
2323

2424

2525
@pytest.fixture(scope="module")
26-
def image_uri(
27-
sklearn_latest_version,
28-
sklearn_latest_py_version,
29-
cpu_instance_type,
30-
sagemaker_session,
26+
def s3_kms_key(sagemaker_session):
27+
return get_or_create_kms_key(sagemaker_session=sagemaker_session)
28+
29+
30+
def test_decorator(sagemaker_session, dummy_container_without_error, cpu_instance_type):
31+
@remote(
32+
role=ROLE,
33+
image_uri=dummy_container_without_error,
34+
instance_type=cpu_instance_type,
35+
sagemaker_session=sagemaker_session,
36+
keep_alive_period_in_seconds=30,
37+
)
38+
def divide(x, y):
39+
return x / y
40+
41+
assert divide(10, 2) == 5
42+
assert divide(20, 2) == 10
43+
44+
45+
@pytest.mark.skip
46+
def test_decorated_function_raises_exception(
47+
sagemaker_session, dummy_container_without_error, cpu_instance_type
48+
):
49+
@remote(
50+
role=ROLE,
51+
image_uri=dummy_container_without_error,
52+
instance_type=cpu_instance_type,
53+
sagemaker_session=sagemaker_session,
54+
)
55+
def divide(x, y):
56+
return x / y
57+
58+
with pytest.raises(ZeroDivisionError):
59+
divide(10, 0)
60+
61+
62+
@pytest.mark.skip
63+
def test_remote_python_runtime_is_incompatible(
64+
sagemaker_session, dummy_container_incompatible_python_runtime, cpu_instance_type
65+
):
66+
@remote(
67+
role=ROLE,
68+
image_uri=dummy_container_incompatible_python_runtime,
69+
instance_type=cpu_instance_type,
70+
sagemaker_session=sagemaker_session,
71+
)
72+
def divide(x, y):
73+
return x / y
74+
75+
# TODO: should raise serialization error
76+
with pytest.raises(RuntimeError):
77+
divide(10, 2)
78+
79+
80+
@pytest.mark.skip
81+
def test_advanced_job_setting(
82+
sagemaker_session, dummy_container_without_error, cpu_instance_type, s3_kms_key
83+
):
84+
@remote(
85+
role=ROLE,
86+
image_uri=dummy_container_without_error,
87+
instance_type=cpu_instance_type,
88+
# TODO: add VPC settings
89+
s3_kms_key=s3_kms_key,
90+
sagemaker_session=sagemaker_session,
91+
)
92+
def divide(x, y):
93+
return x / y
94+
95+
assert divide(10, 2) == 5
96+
97+
98+
@pytest.mark.skip
99+
def test_with_additional_dependencies(
100+
sagemaker_session, dummy_container_without_error, cpu_instance_type
101+
):
102+
@remote(
103+
role=ROLE,
104+
image_uri=dummy_container_without_error,
105+
dependencies="./requirements.txt",
106+
instance_type=cpu_instance_type,
107+
sagemaker_session=sagemaker_session,
108+
)
109+
def divide(x, y):
110+
return x / y
111+
112+
assert divide(10, 2) == 5
113+
114+
115+
@pytest.mark.skip
116+
def test_with_non_existent_dependencies(
117+
sagemaker_session, dummy_container_without_error, cpu_instance_type
31118
):
32-
return image_uris.retrieve(
33-
"sklearn",
34-
sagemaker_session.boto_region_name,
35-
version=sklearn_latest_version,
36-
py_version=sklearn_latest_py_version,
119+
@remote(
120+
role=ROLE,
121+
image_uri=dummy_container_without_error,
122+
dependencies="./requirements.txt",
37123
instance_type=cpu_instance_type,
124+
sagemaker_session=sagemaker_session,
38125
)
126+
def divide(x, y):
127+
return x / y
128+
129+
# TODO: this should raise RuntimeEnvironmentError
130+
with pytest.raises(RuntimeError):
131+
divide(10, 2)
39132

40133

41-
def test_decorator(sagemaker_session, image_uri, cpu_instance_type):
134+
@pytest.mark.skip
135+
def test_with_incompatible_dependencies(
136+
sagemaker_session, dummy_container_without_error, cpu_instance_type
137+
):
42138
@remote(
43139
role=ROLE,
44-
image_uri=image_uri,
140+
image_uri=dummy_container_without_error,
141+
dependencies="./requirements.txt",
45142
instance_type=cpu_instance_type,
46143
sagemaker_session=sagemaker_session,
47144
)
48-
def square(x):
49-
return x * x
145+
def divide(x, y):
146+
return x / y
50147

51-
with pytest.raises(sagemaker.exceptions.UnexpectedStatusException):
52-
square(10)
148+
# TODO: this should raise DeserializationError
149+
with pytest.raises(RuntimeError):
150+
divide(10, 2)

tests/integ/sagemaker/remote_function/test_executor.py

+11-24
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,29 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import pytest
1615

17-
import sagemaker.exceptions
18-
from sagemaker import image_uris
1916
from sagemaker.remote_function import RemoteExecutor
2017

2118
ROLE = "SageMakerRole"
2219

2320

24-
@pytest.fixture(scope="module")
25-
def image_uri(
26-
sklearn_latest_version,
27-
sklearn_latest_py_version,
28-
cpu_instance_type,
29-
sagemaker_session,
30-
):
31-
return image_uris.retrieve(
32-
"sklearn",
33-
sagemaker_session.boto_region_name,
34-
version=sklearn_latest_version,
35-
py_version=sklearn_latest_py_version,
36-
instance_type=cpu_instance_type,
37-
)
38-
39-
40-
def test_executor(sagemaker_session, image_uri, cpu_instance_type):
21+
def test_executor(sagemaker_session, dummy_container_without_error, cpu_instance_type):
4122
def square(x):
4223
return x * x
4324

25+
def cube(x):
26+
return x * x * x
27+
4428
with RemoteExecutor(
29+
max_parallel_job=1,
4530
role=ROLE,
46-
image_uri=image_uri,
31+
image_uri=dummy_container_without_error,
4732
instance_type=cpu_instance_type,
4833
sagemaker_session=sagemaker_session,
34+
keep_alive_period_in_seconds=30,
4935
) as e:
50-
future = e.submit(square, 10)
36+
future_1 = e.submit(square, 10)
37+
future_2 = e.submit(cube, 10)
5138

52-
with pytest.raises(sagemaker.exceptions.UnexpectedStatusException):
53-
future.result()
39+
assert future_1.result() == 100
40+
assert future_2.result() == 1000

0 commit comments

Comments
 (0)