Skip to content

Commit 1ae7ce9

Browse files
rohangujarathiRohan Gujarathi
and
Rohan Gujarathi
authored
feature: Add integ tests for remote_function, auto_capture functionality (#3841)
Co-authored-by: Rohan Gujarathi <[email protected]>
1 parent 79795af commit 1ae7ce9

File tree

3 files changed

+158
-27
lines changed

3 files changed

+158
-27
lines changed

tests/integ/sagemaker/remote_function/conftest.py

+90-27
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import shutil
1919
import pytest
2020
import docker
21+
import re
2122

2223
from sagemaker.utils import sagemaker_timestamp, _tmpdir, sts_regional_endpoint
2324

@@ -57,25 +58,30 @@
5758
"ENV SAGEMAKER_JOB_CONDA_ENV=default_env\n"
5859
)
5960

60-
CONDA_YML_FILE_TEMPLATE = (
61-
"name: integ_test_env\n"
62-
"channels:\n"
63-
" - defaults\n"
64-
"dependencies:\n"
65-
" - scipy=1.7.3\n"
66-
" - pip:\n"
67-
" - /sagemaker-{sagemaker_version}.tar.gz\n"
68-
"prefix: /opt/conda/bin/conda\n"
61+
AUTO_CAPTURE_CLIENT_DOCKER_TEMPLATE = (
62+
"FROM public.ecr.aws/docker/library/python:{py_version}-slim\n\n"
63+
'SHELL ["/bin/bash", "-c"]\n'
64+
"RUN apt-get update -y \
65+
&& apt-get install -y unzip curl\n\n"
66+
"RUN curl -L -O 'https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh' \
67+
&& bash Mambaforge-Linux-x86_64.sh -b -p '/opt/conda' \
68+
&& /opt/conda/bin/conda init bash\n\n"
69+
"ENV PATH $PATH:/opt/conda/bin\n"
70+
"COPY {source_archive} ./\n"
71+
"RUN mamba create -n auto_capture_client python={py_version} -y \
72+
&& mamba run -n auto_capture_client pip install '{source_archive}' awscli boto3\n"
73+
"COPY test_auto_capture.py .\n"
74+
'CMD ["mamba", "run", "-n", "auto_capture_client", "python", "test_auto_capture.py"]\n'
6975
)
7076

71-
CONDA_YML_FILE_WITH_SM_FROM_INPUT_CHANNEL = (
77+
CONDA_YML_FILE_TEMPLATE = (
7278
"name: integ_test_env\n"
7379
"channels:\n"
7480
" - defaults\n"
7581
"dependencies:\n"
7682
" - scipy=1.7.3\n"
7783
" - pip:\n"
78-
" - sagemaker-2.132.1.dev0-py2.py3-none-any.whl\n"
84+
" - /sagemaker-{sagemaker_version}.tar.gz\n"
7985
"prefix: /opt/conda/bin/conda\n"
8086
)
8187

@@ -99,6 +105,12 @@ def dummy_container_with_conda(sagemaker_session):
99105
return ecr_uri
100106

101107

108+
@pytest.fixture(scope="package")
109+
def auto_capture_test_container(sagemaker_session):
110+
ecr_uri = _build_auto_capture_client_container("3.10", AUTO_CAPTURE_CLIENT_DOCKER_TEMPLATE)
111+
return ecr_uri
112+
113+
102114
@pytest.fixture(scope="package")
103115
def conda_env_yml():
104116
"""Write conda yml file needed for tests"""
@@ -116,22 +128,6 @@ def conda_env_yml():
116128
os.remove(conda_yml_file_name)
117129

118130

119-
@pytest.fixture(scope="package")
120-
def conda_yml_file_sm_from_input_channel():
121-
"""Write conda yml file needed for tests"""
122-
123-
conda_yml_file_name = "conda_env_sm_from_input_channel.yml"
124-
conda_file_path = os.path.join(os.getcwd(), conda_yml_file_name)
125-
126-
with open(conda_file_path, "w") as yml_file:
127-
yml_file.writelines(CONDA_YML_FILE_WITH_SM_FROM_INPUT_CHANNEL)
128-
yield conda_file_path
129-
130-
# cleanup
131-
if os.path.isfile(conda_yml_file_name):
132-
os.remove(conda_yml_file_name)
133-
134-
135131
def _build_container(sagemaker_session, py_version, docker_templete):
136132
"""Build a dummy test container locally and push a container to an ecr repo"""
137133

@@ -178,6 +174,25 @@ def _build_container(sagemaker_session, py_version, docker_templete):
178174
return ecr_image
179175

180176

177+
def _build_auto_capture_client_container(py_version, docker_templete):
178+
"""Build a test docker container that will act as a client for auto_capture tests"""
179+
with _tmpdir() as tmpdir:
180+
print("building docker image locally in ", tmpdir)
181+
print("building source archive...")
182+
source_archive = _generate_sdk_tar_with_public_version(tmpdir)
183+
_move_auto_capture_test_file(tmpdir)
184+
with open(os.path.join(tmpdir, "Dockerfile"), "w") as file:
185+
file.writelines(
186+
docker_templete.format(py_version=py_version, source_archive=source_archive)
187+
)
188+
189+
docker_client = docker.from_env()
190+
191+
print("building docker image...")
192+
image, build_logs = docker_client.images.build(path=tmpdir, tag=REPO_NAME, rm=True)
193+
return image.id
194+
195+
181196
def _is_repository_exists(ecr_client, repo_name):
182197
try:
183198
ecr_client.describe_repositories(repositoryNames=[repo_name])
@@ -212,3 +227,51 @@ def _generate_and_move_sagemaker_sdk_tar(destination_folder):
212227
shutil.copy2(source_path, destination_path)
213228

214229
return source_archive
230+
231+
232+
def _generate_sdk_tar_with_public_version(destination_folder):
233+
"""
234+
This function is used for auto capture integ tests. This test need the sagemaker version
235+
that is already published to PyPI. So we manipulate the current local dev version to change
236+
latest released SDK version.
237+
238+
It does the following
239+
1. Change the dev version of the SDK to the latest published version
240+
2. Generate SDK tar using that version
241+
3. Move tar file to the folder when docker file is present
242+
3. Update the version back to the dev version
243+
"""
244+
dist_folder_path = "dist"
245+
246+
with open(os.path.join(os.getcwd(), "VERSION"), "r+") as version_file:
247+
dev_sagemaker_version = version_file.readline().strip()
248+
public_sagemaker_version = re.sub("1.dev0", "0", dev_sagemaker_version)
249+
version_file.seek(0)
250+
version_file.write(public_sagemaker_version)
251+
version_file.truncate()
252+
if os.path.exists(dist_folder_path):
253+
shutil.rmtree(dist_folder_path)
254+
255+
source_archive = _generate_and_move_sagemaker_sdk_tar(destination_folder)
256+
257+
with open(os.path.join(os.getcwd(), "VERSION"), "r+") as version_file:
258+
version_file.seek(0)
259+
version_file.write(dev_sagemaker_version)
260+
version_file.truncate()
261+
if os.path.exists(dist_folder_path):
262+
shutil.rmtree(dist_folder_path)
263+
264+
return source_archive
265+
266+
267+
def _move_auto_capture_test_file(destination_folder):
268+
"""
269+
Move the test file for autocapture tests to a temp folder along with the docker file.
270+
"""
271+
272+
test_file_name = "test_auto_capture.py"
273+
source_path = os.path.join(
274+
os.getcwd(), "tests", "integ", "sagemaker", "remote_function", test_file_name
275+
)
276+
destination_path = os.path.join(destination_folder, test_file_name)
277+
shutil.copy2(source_path, destination_path)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pandas as pd
16+
17+
from sagemaker.remote_function import remote
18+
19+
20+
@remote(
21+
role="SageMakerRole",
22+
instance_type="ml.m5.xlarge",
23+
dependencies="auto_capture",
24+
)
25+
def multiply(dataframe: pd.DataFrame, factor: float):
26+
return dataframe * factor
27+
28+
29+
df = pd.DataFrame(
30+
{
31+
"A": [14, 4, 5, 4, 1],
32+
"B": [5, 2, 54, 3, 2],
33+
"C": [20, 20, 7, 3, 8],
34+
"D": [14, 3, 6, 2, 6],
35+
}
36+
)
37+
38+
if __name__ == "__main__":
39+
multiply(df, 10.0)

tests/integ/sagemaker/remote_function/test_decorator.py

+29
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import random
2121
import string
2222
import pandas as pd
23+
import subprocess
24+
import shlex
2325
from sagemaker.experiments.run import Run, load_run
2426
from tests.integ.sagemaker.experiments.helpers import cleanup_exp_resources
2527
from sagemaker.experiments.trial_component import _TrialComponent
@@ -596,3 +598,30 @@ def get_file_content(file_names):
596598
with pytest.raises(RuntimeEnvironmentError) as e:
597599
get_file_content(["test_file_1", "test_file_2", "test_file_3"])
598600
assert "line 2: bws: command not found" in str(e)
601+
602+
603+
def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container):
604+
"""
605+
This test runs a docker container. The Container invocation will execute a python script
606+
with remote function to test auto_capture scenario. The test requires conda to be
607+
installed on the client side which is not available in the code build image. Hence we need
608+
to run the test in another docker container with conda installed.
609+
610+
Any assertion is not needed because if remote function execution fails, docker run comand
611+
will throw an error thus failing this test.
612+
"""
613+
creds = sagemaker_session.boto_session.get_credentials()
614+
region = sagemaker_session.boto_session.region_name
615+
env = {
616+
"AWS_ACCESS_KEY_ID": str(creds.access_key),
617+
"AWS_SECRET_ACCESS_KEY": str(creds.secret_key),
618+
"AWS_SESSION_TOKEN": str(creds.token),
619+
}
620+
cmd = (
621+
f"docker run -e AWS_ACCESS_KEY_ID={env['AWS_ACCESS_KEY_ID']} "
622+
f"-e AWS_SECRET_ACCESS_KEY={env['AWS_SECRET_ACCESS_KEY']} "
623+
f"-e AWS_SESSION_TOKEN={env['AWS_SESSION_TOKEN']} "
624+
f"-e AWS_DEFAULT_REGION={region} "
625+
f"--rm {auto_capture_test_container}"
626+
)
627+
subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8")

0 commit comments

Comments
 (0)