Skip to content

Commit 807bc73

Browse files
author
Rohan Gujarathi
committed
feature: add integ tests for remote_function auto_capture functionality
1 parent 4844aa1 commit 807bc73

File tree

3 files changed

+126
-27
lines changed

3 files changed

+126
-27
lines changed

tests/integ/sagemaker/remote_function/conftest.py

+67-27
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import shutil
1919
import pytest
2020
import docker
21+
import re
2122

2223
from sagemaker.utils import sagemaker_timestamp, _tmpdir, sts_regional_endpoint
2324

@@ -57,25 +58,30 @@
5758
"ENV SAGEMAKER_JOB_CONDA_ENV=default_env\n"
5859
)
5960

60-
CONDA_YML_FILE_TEMPLATE = (
61-
"name: integ_test_env\n"
62-
"channels:\n"
63-
" - defaults\n"
64-
"dependencies:\n"
65-
" - scipy=1.7.3\n"
66-
" - pip:\n"
67-
" - /sagemaker-{sagemaker_version}.tar.gz\n"
68-
"prefix: /opt/conda/bin/conda\n"
61+
AUTO_CAPTURE_CLIENT_DOCKER_TEMPLATE = (
62+
"FROM public.ecr.aws/docker/library/python:{py_version}-slim\n\n"
63+
'SHELL ["/bin/bash", "-c"]\n'
64+
"RUN apt-get update -y \
65+
&& apt-get install -y unzip curl\n\n"
66+
"RUN curl -L -O 'https://github.com/conda-forge/miniforge/releases/latest/download/Mambaforge-Linux-x86_64.sh' \
67+
&& bash Mambaforge-Linux-x86_64.sh -b -p '/opt/conda' \
68+
&& /opt/conda/bin/conda init bash\n\n"
69+
"ENV PATH $PATH:/opt/conda/bin\n"
70+
"COPY {source_archive} ./\n"
71+
"RUN mamba create -n auto_capture_client python={py_version} -y \
72+
&& mamba run -n auto_capture_client pip install '{source_archive}' awscli boto3\n"
73+
"COPY test_auto_capture.py .\n"
74+
"CMD [\"mamba\", \"run\", \"-n\", \"auto_capture_client\", \"python\", \"test_auto_capture.py\"]\n"
6975
)
7076

71-
CONDA_YML_FILE_WITH_SM_FROM_INPUT_CHANNEL = (
77+
CONDA_YML_FILE_TEMPLATE = (
7278
"name: integ_test_env\n"
7379
"channels:\n"
7480
" - defaults\n"
7581
"dependencies:\n"
7682
" - scipy=1.7.3\n"
7783
" - pip:\n"
78-
" - sagemaker-2.132.1.dev0-py2.py3-none-any.whl\n"
84+
" - /sagemaker-{sagemaker_version}.tar.gz\n"
7985
"prefix: /opt/conda/bin/conda\n"
8086
)
8187

@@ -99,6 +105,12 @@ def dummy_container_with_conda(sagemaker_session):
99105
return ecr_uri
100106

101107

108+
@pytest.fixture(scope="package")
109+
def auto_capture_test_container(sagemaker_session):
110+
ecr_uri = _build_container_locally("3.10", AUTO_CAPTURE_CLIENT_DOCKER_TEMPLATE)
111+
return ecr_uri
112+
113+
102114
@pytest.fixture(scope="package")
103115
def conda_env_yml():
104116
"""Write conda yml file needed for tests"""
@@ -116,22 +128,6 @@ def conda_env_yml():
116128
os.remove(conda_yml_file_name)
117129

118130

119-
@pytest.fixture(scope="package")
120-
def conda_yml_file_sm_from_input_channel():
121-
"""Write conda yml file needed for tests"""
122-
123-
conda_yml_file_name = "conda_env_sm_from_input_channel.yml"
124-
conda_file_path = os.path.join(os.getcwd(), conda_yml_file_name)
125-
126-
with open(conda_file_path, "w") as yml_file:
127-
yml_file.writelines(CONDA_YML_FILE_WITH_SM_FROM_INPUT_CHANNEL)
128-
yield conda_file_path
129-
130-
# cleanup
131-
if os.path.isfile(conda_yml_file_name):
132-
os.remove(conda_yml_file_name)
133-
134-
135131
def _build_container(sagemaker_session, py_version, docker_templete):
136132
"""Build a dummy test container locally and push a container to an ecr repo"""
137133

@@ -178,6 +174,23 @@ def _build_container(sagemaker_session, py_version, docker_templete):
178174
return ecr_image
179175

180176

177+
def _build_container_locally(py_version, docker_templete):
178+
with _tmpdir() as tmpdir:
179+
print("building docker image locally in ", tmpdir)
180+
print("building source archive...")
181+
source_archive = _generate_sdk_tar_with_public_version(tmpdir)
182+
_move_auto_capture_test_file(tmpdir)
183+
with open(os.path.join(tmpdir, "Dockerfile"), "w") as file:
184+
file.writelines(
185+
docker_templete.format(py_version=py_version, source_archive=source_archive)
186+
)
187+
188+
docker_client = docker.from_env()
189+
190+
print("building docker image...")
191+
image, build_logs = docker_client.images.build(path=tmpdir, tag=REPO_NAME, rm=True)
192+
return image.id
193+
181194
def _is_repository_exists(ecr_client, repo_name):
182195
try:
183196
ecr_client.describe_repositories(repositoryNames=[repo_name])
@@ -212,3 +225,30 @@ def _generate_and_move_sagemaker_sdk_tar(destination_folder):
212225
shutil.copy2(source_path, destination_path)
213226

214227
return source_archive
228+
229+
230+
def _generate_sdk_tar_with_public_version(destination_folder):
231+
with open(os.path.join(os.getcwd(), "VERSION"), "r+") as version_file:
232+
dev_sagemaker_version = version_file.readline().strip()
233+
public_sagemaker_version = re.sub("1.dev0", "0", dev_sagemaker_version)
234+
version_file.seek(0)
235+
version_file.write(public_sagemaker_version)
236+
version_file.truncate()
237+
shutil.rmtree("dist")
238+
239+
source_archive = _generate_and_move_sagemaker_sdk_tar(destination_folder)
240+
241+
with open(os.path.join(os.getcwd(), "VERSION"), "r+") as version_file:
242+
version_file.seek(0)
243+
version_file.write(dev_sagemaker_version)
244+
version_file.truncate()
245+
shutil.rmtree("dist")
246+
247+
return source_archive
248+
249+
250+
def _move_auto_capture_test_file(destination_folder):
251+
test_file_name = "test_auto_capture.py"
252+
source_path = os.path.join(os.getcwd(), "tests", "integ", "sagemaker", "remote_function", test_file_name)
253+
destination_path = os.path.join(destination_folder, test_file_name)
254+
shutil.copy2(source_path, destination_path)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import pandas as pd
17+
18+
from sagemaker.remote_function import remote
19+
20+
21+
@remote(
22+
role="SageMakerRole",
23+
instance_type="ml.m5.xlarge",
24+
dependencies="auto_capture",
25+
)
26+
def multiply(dataframe: pd.DataFrame, factor: float):
27+
return dataframe * factor
28+
29+
30+
df = pd.DataFrame(
31+
{
32+
"A": [14, 4, 5, 4, 1],
33+
"B": [5, 2, 54, 3, 2],
34+
"C": [20, 20, 7, 3, 8],
35+
"D": [14, 3, 6, 2, 6],
36+
}
37+
)
38+
39+
if __name__ == "__main__":
40+
multiply(df, 10.0)

tests/integ/sagemaker/remote_function/test_decorator.py

+19
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import random
2121
import string
2222
import pandas as pd
23+
import subprocess
24+
import shlex
2325
from sagemaker.experiments.run import Run, load_run
2426
from tests.integ.sagemaker.experiments.helpers import cleanup_exp_resources
2527
from sagemaker.experiments.trial_component import _TrialComponent
@@ -596,3 +598,20 @@ def get_file_content(file_names):
596598
with pytest.raises(RuntimeEnvironmentError) as e:
597599
get_file_content(["test_file_1", "test_file_2", "test_file_3"])
598600
assert "line 2: bws: command not found" in str(e)
601+
602+
603+
def test_decorator_auto_capture(sagemaker_session, auto_capture_test_container):
604+
creds = sagemaker_session.boto_session.get_credentials()
605+
region = sagemaker_session.boto_session.region_name
606+
env = {
607+
"AWS_ACCESS_KEY_ID": str(creds.access_key),
608+
"AWS_SECRET_ACCESS_KEY": str(creds.secret_key),
609+
"AWS_SESSION_TOKEN": str(creds.token),
610+
}
611+
cmd = (f"docker run -e AWS_ACCESS_KEY_ID={env['AWS_ACCESS_KEY_ID']} "
612+
f"-e AWS_SECRET_ACCESS_KEY={env['AWS_SECRET_ACCESS_KEY']} "
613+
f"-e AWS_SESSION_TOKEN={env['AWS_SESSION_TOKEN']} "
614+
f"-e AWS_DEFAULT_REGION={region} "
615+
f"-it {auto_capture_test_container}")
616+
result = subprocess.check_output(shlex.split(cmd), stderr=subprocess.STDOUT).decode("utf-8")
617+
print(result)

0 commit comments

Comments
 (0)