Skip to content

Commit 01ea241

Browse files
qidewenwhenNamrata Madan
authored and
Namrata Madan
committed
change: Enable Experiment integ test on beta clients (aws#3590)
1 parent 842a7d9 commit 01ea241

File tree

5 files changed

+219
-40
lines changed

5 files changed

+219
-40
lines changed

tests/conftest.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from botocore.config import Config
2323
from packaging.version import Version
2424

25-
from sagemaker import Session, image_uris, utils
25+
from sagemaker import Session, image_uris, utils, get_execution_role
2626
from sagemaker.local import LocalSession
2727
from sagemaker.workflow.pipeline_context import PipelineSession, LocalPipelineSession
2828

@@ -93,6 +93,7 @@ def pytest_addoption(parser):
9393
parser.addoption("--sagemaker-client-config", action="store", default=None)
9494
parser.addoption("--sagemaker-runtime-config", action="store", default=None)
9595
parser.addoption("--boto-config", action="store", default=None)
96+
parser.addoption("--sagemaker-metrics-config", action="store", default=None)
9697

9798

9899
def pytest_configure(config):
@@ -115,6 +116,12 @@ def sagemaker_runtime_config(request):
115116
return json.loads(config) if config else None
116117

117118

119+
@pytest.fixture(scope="session")
120+
def sagemaker_metrics_config(request):
121+
config = request.config.getoption("--sagemaker-metrics-config")
122+
return json.loads(config) if config else None
123+
124+
118125
@pytest.fixture(scope="session")
119126
def boto_session(request):
120127
config = request.config.getoption("--boto-config")
@@ -135,7 +142,9 @@ def region(boto_session):
135142

136143

137144
@pytest.fixture(scope="session")
138-
def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_session):
145+
def sagemaker_session(
146+
sagemaker_client_config, sagemaker_runtime_config, boto_session, sagemaker_metrics_config
147+
):
139148
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
140149
sagemaker_client = (
141150
boto_session.client("sagemaker", **sagemaker_client_config)
@@ -147,11 +156,17 @@ def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_se
147156
if sagemaker_runtime_config
148157
else None
149158
)
159+
metrics_client = (
160+
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
161+
if sagemaker_metrics_config
162+
else None
163+
)
150164

151165
return Session(
152166
boto_session=boto_session,
153167
sagemaker_client=sagemaker_client,
154168
sagemaker_runtime_client=runtime_client,
169+
sagemaker_metrics_client=metrics_client,
155170
)
156171

157172

@@ -170,6 +185,11 @@ def local_pipeline_session(boto_session):
170185
return LocalPipelineSession(boto_session=boto_session)
171186

172187

188+
@pytest.fixture(scope="session")
189+
def execution_role(sagemaker_session):
190+
return get_execution_role(sagemaker_session)
191+
192+
173193
@pytest.fixture(scope="module")
174194
def custom_bucket_name(boto_session):
175195
region = boto_session.region_name

tests/data/experiment/inference.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
13+
import json
1314
import logging
1415
import os
1516
import pickle as pkl
@@ -24,11 +25,32 @@
2425
sdk_file = f"{code_dir}/{sdk_name}"
2526
os.system(f"pip install {sdk_file}")
2627

28+
29+
def _get_client_config_in_dict(cfg_in_str) -> dict:
30+
return json.loads(cfg_in_str) if cfg_in_str else None
31+
32+
2733
from sagemaker.session import Session
2834
from sagemaker.experiments import load_run
2935

3036
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
31-
sagemaker_session = Session(boto_session=boto_session)
37+
38+
sagemaker_client_config = _get_client_config_in_dict(os.environ.get("SM_CLIENT_CONFIG", None))
39+
sagemaker_metrics_config = _get_client_config_in_dict(os.environ.get("SM_METRICS_CONFIG", None))
40+
sagemaker_client = (
41+
boto_session.client("sagemaker", **sagemaker_client_config) if sagemaker_client_config else None
42+
)
43+
metrics_client = (
44+
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
45+
if sagemaker_metrics_config
46+
else None
47+
)
48+
49+
sagemaker_session = Session(
50+
boto_session=boto_session,
51+
sagemaker_client=sagemaker_client,
52+
sagemaker_metrics_client=metrics_client,
53+
)
3254

3355

3456
def model_fn(model_dir):

tests/data/experiment/process_job_script_for_run_clz.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This script file runs on SageMaker processing job"""
1414
from __future__ import absolute_import
1515

16+
import json
1617
import logging
1718
import os
1819
import boto3
@@ -25,8 +26,28 @@
2526
from sagemaker.experiments import load_run
2627

2728

29+
def _get_client_config_in_dict(cfg_in_str) -> dict:
30+
return json.loads(cfg_in_str) if cfg_in_str else None
31+
32+
2833
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
29-
sagemaker_session = Session(boto_session=boto_session)
34+
35+
sagemaker_client_config = _get_client_config_in_dict(os.environ.get("SM_CLIENT_CONFIG", None))
36+
sagemaker_metrics_config = _get_client_config_in_dict(os.environ.get("SM_METRICS_CONFIG", None))
37+
sagemaker_client = (
38+
boto_session.client("sagemaker", **sagemaker_client_config) if sagemaker_client_config else None
39+
)
40+
metrics_client = (
41+
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
42+
if sagemaker_metrics_config
43+
else None
44+
)
45+
46+
sagemaker_session = Session(
47+
boto_session=boto_session,
48+
sagemaker_client=sagemaker_client,
49+
sagemaker_metrics_client=metrics_client,
50+
)
3051

3152

3253
with load_run(sagemaker_session=sagemaker_session) as run:

tests/data/experiment/train_job_script_for_run_clz.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This script file runs on SageMaker training job"""
1414
from __future__ import absolute_import
1515

16+
import json
1617
import logging
1718
import time
1819
import os
@@ -24,8 +25,29 @@
2425
from sagemaker import Session
2526
from sagemaker.experiments import load_run, Run
2627

28+
29+
def _get_client_config_in_dict(cfg_in_str) -> dict:
30+
return json.loads(cfg_in_str) if cfg_in_str else None
31+
32+
2733
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
28-
sagemaker_session = Session(boto_session=boto_session)
34+
35+
sagemaker_client_config = _get_client_config_in_dict(os.environ.get("SM_CLIENT_CONFIG", None))
36+
sagemaker_metrics_config = _get_client_config_in_dict(os.environ.get("SM_METRICS_CONFIG", None))
37+
sagemaker_client = (
38+
boto_session.client("sagemaker", **sagemaker_client_config) if sagemaker_client_config else None
39+
)
40+
metrics_client = (
41+
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
42+
if sagemaker_metrics_config
43+
else None
44+
)
45+
46+
sagemaker_session = Session(
47+
boto_session=boto_session,
48+
sagemaker_client=sagemaker_client,
49+
sagemaker_metrics_client=metrics_client,
50+
)
2951

3052
if os.environ["RUN_OPERATION"] == "init":
3153
logging.info("Initializing a Run")

0 commit comments

Comments
 (0)