Skip to content

change: Enable Experiment integ test on beta clients #3590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from botocore.config import Config
from packaging.version import Version

from sagemaker import Session, image_uris, utils
from sagemaker import Session, image_uris, utils, get_execution_role
from sagemaker.local import LocalSession
from sagemaker.workflow.pipeline_context import PipelineSession, LocalPipelineSession

Expand Down Expand Up @@ -91,6 +91,7 @@ def pytest_addoption(parser):
parser.addoption("--sagemaker-client-config", action="store", default=None)
parser.addoption("--sagemaker-runtime-config", action="store", default=None)
parser.addoption("--boto-config", action="store", default=None)
parser.addoption("--sagemaker-metrics-config", action="store", default=None)


def pytest_configure(config):
Expand All @@ -113,6 +114,12 @@ def sagemaker_runtime_config(request):
return json.loads(config) if config else None


@pytest.fixture(scope="session")
def sagemaker_metrics_config(request):
config = request.config.getoption("--sagemaker-metrics-config")
return json.loads(config) if config else None


@pytest.fixture(scope="session")
def boto_session(request):
config = request.config.getoption("--boto-config")
Expand All @@ -133,7 +140,9 @@ def region(boto_session):


@pytest.fixture(scope="session")
def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_session):
def sagemaker_session(
sagemaker_client_config, sagemaker_runtime_config, boto_session, sagemaker_metrics_config
):
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
sagemaker_client = (
boto_session.client("sagemaker", **sagemaker_client_config)
Expand All @@ -145,11 +154,17 @@ def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_se
if sagemaker_runtime_config
else None
)
metrics_client = (
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
if sagemaker_metrics_config
else None
)

return Session(
boto_session=boto_session,
sagemaker_client=sagemaker_client,
sagemaker_runtime_client=runtime_client,
sagemaker_metrics_client=metrics_client,
)


Expand All @@ -168,6 +183,11 @@ def local_pipeline_session(boto_session):
return LocalPipelineSession(boto_session=boto_session)


@pytest.fixture(scope="session")
def execution_role(sagemaker_session):
return get_execution_role(sagemaker_session)


@pytest.fixture(scope="module")
def custom_bucket_name(boto_session):
region = boto_session.region_name
Expand Down
24 changes: 23 additions & 1 deletion tests/data/experiment/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
import json
import logging
import os
import pickle as pkl
Expand All @@ -24,11 +25,32 @@
sdk_file = f"{code_dir}/{sdk_name}"
os.system(f"pip install {sdk_file}")


def _get_client_config_in_dict(cfg_in_str) -> dict:
return json.loads(cfg_in_str) if cfg_in_str else None


from sagemaker.session import Session
from sagemaker.experiments import load_run

boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
sagemaker_session = Session(boto_session=boto_session)

sagemaker_client_config = _get_client_config_in_dict(os.environ.get("SM_CLIENT_CONFIG", None))
sagemaker_metrics_config = _get_client_config_in_dict(os.environ.get("SM_METRICS_CONFIG", None))
sagemaker_client = (
boto_session.client("sagemaker", **sagemaker_client_config) if sagemaker_client_config else None
)
metrics_client = (
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
if sagemaker_metrics_config
else None
)

sagemaker_session = Session(
boto_session=boto_session,
sagemaker_client=sagemaker_client,
sagemaker_metrics_client=metrics_client,
)


def model_fn(model_dir):
Expand Down
23 changes: 22 additions & 1 deletion tests/data/experiment/process_job_script_for_run_clz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""This script file runs on SageMaker processing job"""
from __future__ import absolute_import

import json
import logging
import os
import boto3
Expand All @@ -25,8 +26,28 @@
from sagemaker.experiments import load_run


def _get_client_config_in_dict(cfg_in_str) -> dict:
return json.loads(cfg_in_str) if cfg_in_str else None


boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
sagemaker_session = Session(boto_session=boto_session)

sagemaker_client_config = _get_client_config_in_dict(os.environ.get("SM_CLIENT_CONFIG", None))
sagemaker_metrics_config = _get_client_config_in_dict(os.environ.get("SM_METRICS_CONFIG", None))
sagemaker_client = (
boto_session.client("sagemaker", **sagemaker_client_config) if sagemaker_client_config else None
)
metrics_client = (
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
if sagemaker_metrics_config
else None
)

sagemaker_session = Session(
boto_session=boto_session,
sagemaker_client=sagemaker_client,
sagemaker_metrics_client=metrics_client,
)


with load_run(sagemaker_session=sagemaker_session) as run:
Expand Down
24 changes: 23 additions & 1 deletion tests/data/experiment/train_job_script_for_run_clz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""This script file runs on SageMaker training job"""
from __future__ import absolute_import

import json
import logging
import time
import os
Expand All @@ -24,8 +25,29 @@
from sagemaker import Session
from sagemaker.experiments import load_run, Run


def _get_client_config_in_dict(cfg_in_str) -> dict:
return json.loads(cfg_in_str) if cfg_in_str else None


boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
sagemaker_session = Session(boto_session=boto_session)

sagemaker_client_config = _get_client_config_in_dict(os.environ.get("SM_CLIENT_CONFIG", None))
sagemaker_metrics_config = _get_client_config_in_dict(os.environ.get("SM_METRICS_CONFIG", None))
sagemaker_client = (
boto_session.client("sagemaker", **sagemaker_client_config) if sagemaker_client_config else None
)
metrics_client = (
boto_session.client("sagemaker-metrics", **sagemaker_metrics_config)
if sagemaker_metrics_config
else None
)

sagemaker_session = Session(
boto_session=boto_session,
sagemaker_client=sagemaker_client,
sagemaker_metrics_client=metrics_client,
)

if os.environ["RUN_OPERATION"] == "init":
logging.info("Initializing a Run")
Expand Down
Loading