Skip to content

Commit 12dbd37

Browse files
committed
change: Enable Experiment integ test on beta clients
1 parent 083f95c commit 12dbd37

File tree

5 files changed

+180
-31
lines changed

5 files changed

+180
-31
lines changed

tests/conftest.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from botocore.config import Config
2323
from packaging.version import Version
2424

25-
from sagemaker import Session, image_uris, utils
25+
from sagemaker import Session, image_uris, utils, get_execution_role
2626
from sagemaker.local import LocalSession
2727
from sagemaker.workflow.pipeline_context import PipelineSession, LocalPipelineSession
2828

@@ -91,6 +91,7 @@ def pytest_addoption(parser):
9191
parser.addoption("--sagemaker-client-config", action="store", default=None)
9292
parser.addoption("--sagemaker-runtime-config", action="store", default=None)
9393
parser.addoption("--boto-config", action="store", default=None)
94+
parser.addoption("--sagemaker-metrics-config", action="store", default=None)
9495

9596

9697
def pytest_configure(config):
@@ -113,6 +114,12 @@ def sagemaker_runtime_config(request):
113114
return json.loads(config) if config else None
114115

115116

117+
@pytest.fixture(scope="session")
118+
def sagemaker_metrics_config(request):
119+
config = request.config.getoption("--sagemaker-metrics-config")
120+
return json.loads(config) if config else None
121+
122+
116123
@pytest.fixture(scope="session")
117124
def boto_session(request):
118125
config = request.config.getoption("--boto-config")
@@ -133,7 +140,9 @@ def region(boto_session):
133140

134141

135142
@pytest.fixture(scope="session")
136-
def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_session):
143+
def sagemaker_session(
144+
sagemaker_client_config, sagemaker_runtime_config, boto_session, sagemaker_metrics_config
145+
):
137146
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
138147
sagemaker_client = (
139148
boto_session.client("sagemaker", **sagemaker_client_config)
@@ -145,11 +154,17 @@ def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_se
145154
if sagemaker_runtime_config
146155
else None
147156
)
157+
metrics_client = (
158+
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
159+
if sagemaker_metrics_config
160+
else None
161+
)
148162

149163
return Session(
150164
boto_session=boto_session,
151165
sagemaker_client=sagemaker_client,
152166
sagemaker_runtime_client=runtime_client,
167+
sagemaker_metrics_client=metrics_client,
153168
)
154169

155170

@@ -168,6 +183,11 @@ def local_pipeline_session(boto_session):
168183
return LocalPipelineSession(boto_session=boto_session)
169184

170185

186+
@pytest.fixture(scope="session")
187+
def execution_role(sagemaker_session):
188+
return get_execution_role(sagemaker_session)
189+
190+
171191
@pytest.fixture(scope="module")
172192
def custom_bucket_name(boto_session):
173193
region = boto_session.region_name

tests/data/experiment/inference.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
13+
import json
1314
import logging
1415
import os
1516
import pickle as pkl
@@ -24,11 +25,32 @@
2425
sdk_file = f"{code_dir}/{sdk_name}"
2526
os.system(f"pip install {sdk_file}")
2627

28+
29+
def _get_client_config_in_dict(cfg_in_str) -> dict:
30+
return json.loads(cfg_in_str) if cfg_in_str else None
31+
32+
2733
from sagemaker.session import Session
2834
from sagemaker.experiments import load_run
2935

3036
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
31-
sagemaker_session = Session(boto_session=boto_session)
37+
38+
sagemaker_client_config = _get_client_config_in_dict(os.environ.get("SM_CLIENT_CONFIG", None))
39+
sagemaker_metrics_config = _get_client_config_in_dict(os.environ.get("SM_METRICS_CONFIG", None))
40+
sagemaker_client = (
41+
boto_session.client("sagemaker", **sagemaker_client_config) if sagemaker_client_config else None
42+
)
43+
metrics_client = (
44+
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
45+
if sagemaker_metrics_config
46+
else None
47+
)
48+
49+
sagemaker_session = Session(
50+
boto_session=boto_session,
51+
sagemaker_client=sagemaker_client,
52+
sagemaker_metrics_client=metrics_client,
53+
)
3254

3355

3456
def model_fn(model_dir):

tests/data/experiment/process_job_script_for_run_clz.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This script file runs on SageMaker processing job"""
1414
from __future__ import absolute_import
1515

16+
import json
1617
import logging
1718
import os
1819
import boto3
@@ -25,8 +26,28 @@
2526
from sagemaker.experiments import load_run
2627

2728

29+
def _get_client_config_in_dict(cfg_in_str) -> dict:
30+
return json.loads(cfg_in_str) if cfg_in_str else None
31+
32+
2833
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
29-
sagemaker_session = Session(boto_session=boto_session)
34+
35+
sagemaker_client_config = _get_client_config_in_dict(os.environ.get("SM_CLIENT_CONFIG", None))
36+
sagemaker_metrics_config = _get_client_config_in_dict(os.environ.get("SM_METRICS_CONFIG", None))
37+
sagemaker_client = (
38+
boto_session.client("sagemaker", **sagemaker_client_config) if sagemaker_client_config else None
39+
)
40+
metrics_client = (
41+
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
42+
if sagemaker_metrics_config
43+
else None
44+
)
45+
46+
sagemaker_session = Session(
47+
boto_session=boto_session,
48+
sagemaker_client=sagemaker_client,
49+
sagemaker_metrics_client=metrics_client,
50+
)
3051

3152

3253
with load_run(sagemaker_session=sagemaker_session) as run:

tests/data/experiment/train_job_script_for_run_clz.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""This script file runs on SageMaker training job"""
1414
from __future__ import absolute_import
1515

16+
import json
1617
import logging
1718
import time
1819
import os
@@ -24,8 +25,29 @@
2425
from sagemaker import Session
2526
from sagemaker.experiments import load_run, Run
2627

28+
29+
def _get_client_config_in_dict(cfg_in_str) -> dict:
30+
return json.loads(cfg_in_str) if cfg_in_str else None
31+
32+
2733
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
28-
sagemaker_session = Session(boto_session=boto_session)
34+
35+
sagemaker_client_config = _get_client_config_in_dict(os.environ.get("SM_CLIENT_CONFIG", None))
36+
sagemaker_metrics_config = _get_client_config_in_dict(os.environ.get("SM_METRICS_CONFIG", None))
37+
sagemaker_client = (
38+
boto_session.client("sagemaker", **sagemaker_client_config) if sagemaker_client_config else None
39+
)
40+
metrics_client = (
41+
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
42+
if sagemaker_metrics_config
43+
else None
44+
)
45+
46+
sagemaker_session = Session(
47+
boto_session=boto_session,
48+
sagemaker_client=sagemaker_client,
49+
sagemaker_metrics_client=metrics_client,
50+
)
2951

3052
if os.environ["RUN_OPERATION"] == "init":
3153
logging.info("Initializing a Run")

0 commit comments

Comments
 (0)