Skip to content

expose Experiment class publicly #3794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions doc/experiments/sagemaker.experiments.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ Run

.. automethod:: sagemaker.experiments.list_runs

Experiment
-------------

.. autoclass:: sagemaker.experiments.Experiment
:members:

Other
-------------

.. autoclass:: sagemaker.experiments.SortByType
:members:
:undoc-members:
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/experiments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import

from sagemaker.experiments.run import Run # noqa: F401
from sagemaker.experiments.experiment import Experiment # noqa: F401
from sagemaker.experiments.run import load_run # noqa: F401
from sagemaker.experiments.run import list_runs # noqa: F401
from sagemaker.experiments.run import SortOrderType # noqa: F401
Expand Down
20 changes: 10 additions & 10 deletions src/sagemaker/experiments/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
from sagemaker.experiments.trial_component import _TrialComponent


class _Experiment(_base_types.Record):
class Experiment(_base_types.Record):
"""An Amazon SageMaker experiment, which is a collection of related trials.

New experiments are created by calling `experiments.experiment._Experiment.create`.
Existing experiments can be reloaded by calling `experiments.experiment._Experiment.load`.
New experiments are created by calling `experiments.experiment.Experiment.create`.
Existing experiments can be reloaded by calling `experiments.experiment.Experiment.load`.

Attributes:
experiment_name (str): The name of the experiment. The name must be unique
Expand Down Expand Up @@ -73,7 +73,7 @@ def delete(self):

@classmethod
def load(cls, experiment_name, sagemaker_session=None):
"""Load an existing experiment and return an `_Experiment` object representing it.
"""Load an existing experiment and return an `Experiment` object representing it.

Args:
experiment_name: (str): Name of the experiment
Expand All @@ -83,7 +83,7 @@ def load(cls, experiment_name, sagemaker_session=None):
default AWS configuration chain.

Returns:
experiments.experiment._Experiment: A SageMaker `_Experiment` object
experiments.experiment.Experiment: A SageMaker `Experiment` object
"""
return cls._construct(
cls._boto_load_method,
Expand All @@ -100,7 +100,7 @@ def create(
tags=None,
sagemaker_session=None,
):
"""Create a new experiment in SageMaker and return an `_Experiment` object.
"""Create a new experiment in SageMaker and return an `Experiment` object.

Args:
experiment_name: (str): Name of the experiment. Must be unique. Required.
Expand All @@ -115,7 +115,7 @@ def create(
(default: None).

Returns:
experiments.experiment._Experiment: A SageMaker `_Experiment` object
experiments.experiment.Experiment: A SageMaker `Experiment` object
"""
return cls._construct(
cls._boto_create_method,
Expand Down Expand Up @@ -154,10 +154,10 @@ def _load_or_create(
exist and a new experiment has to be created.

Returns:
experiments.experiment._Experiment: A SageMaker `_Experiment` object
experiments.experiment.Experiment: A SageMaker `Experiment` object
"""
try:
experiment = _Experiment.create(
experiment = Experiment.create(
experiment_name=experiment_name,
display_name=display_name,
description=description,
Expand All @@ -170,7 +170,7 @@ def _load_or_create(
if not (error_code == "ValidationException" and "already exists" in error_message):
raise ce
# already exists
experiment = _Experiment.load(experiment_name, sagemaker_session)
experiment = Experiment.load(experiment_name, sagemaker_session)
return experiment

def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None):
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from sagemaker.experiments._environment import _RunEnvironment
from sagemaker.experiments._run_context import _RunContext
from sagemaker.experiments.experiment import _Experiment
from sagemaker.experiments.experiment import Experiment
from sagemaker.experiments._metrics import _MetricsManager
from sagemaker.experiments.trial import _Trial
from sagemaker.experiments.trial_component import _TrialComponent
Expand Down Expand Up @@ -166,7 +166,7 @@ def __init__(
)
self.run_group_name = Run._generate_trial_name(self.experiment_name)

self._experiment = _Experiment._load_or_create(
self._experiment = Experiment._load_or_create(
experiment_name=self.experiment_name,
display_name=experiment_display_name,
tags=tags,
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/sagemaker/experiments/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def run_obj(sagemaker_session):
yield run
time.sleep(0.5)
finally:
exp = experiment._Experiment.load(
exp = experiment.Experiment.load(
experiment_name=run.experiment_name, sagemaker_session=sagemaker_session
)
exp._delete_all(action="--force")
Expand All @@ -71,7 +71,7 @@ def experiment_obj(sagemaker_session):
description = "{}-{}".format("description", str(uuid.uuid4()))
boto3.set_stream_logger("", logging.INFO)
experiment_name = name()
experiment_obj = experiment._Experiment.create(
experiment_obj = experiment.Experiment.create(
experiment_name=experiment_name,
description=description,
sagemaker_session=sagemaker_session,
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/sagemaker/experiments/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from contextlib import contextmanager

from sagemaker import utils
from sagemaker.experiments.experiment import _Experiment
from sagemaker.experiments.experiment import Experiment

EXP_INTEG_TEST_NAME_PREFIX = "experiments-integ"

Expand All @@ -38,5 +38,5 @@ def cleanup_exp_resources(exp_names, sagemaker_session):
yield
finally:
for exp_name in exp_names:
exp = _Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session)
exp = Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session)
exp._delete_all(action="--force")
4 changes: 2 additions & 2 deletions tests/integ/sagemaker/experiments/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_save(experiment_obj):


def test_save_load(experiment_obj, sagemaker_session):
experiment_obj_two = experiment._Experiment.load(
experiment_obj_two = experiment.Experiment.load(
experiment_name=experiment_obj.experiment_name, sagemaker_session=sagemaker_session
)
assert experiment_obj.experiment_name == experiment_obj_two.experiment_name
Expand All @@ -49,7 +49,7 @@ def test_save_load(experiment_obj, sagemaker_session):
experiment_obj.description = name()
experiment_obj.display_name = name()
experiment_obj.save()
experiment_obj_three = experiment._Experiment.load(
experiment_obj_three = experiment.Experiment.load(
experiment_name=experiment_obj.experiment_name, sagemaker_session=sagemaker_session
)
assert experiment_obj.description == experiment_obj_three.description
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/sagemaker/experiments/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytest

from sagemaker import Session
from sagemaker.experiments.experiment import _Experiment
from sagemaker.experiments.experiment import Experiment
from sagemaker.experiments.run import RUN_NAME_BASE
from sagemaker.experiments import Run
from tests.unit.sagemaker.experiments.helpers import (
Expand Down Expand Up @@ -57,9 +57,9 @@ def run_obj(sagemaker_session):
client.update_trial_component.return_value = {}
client.associate_trial_component.return_value = {}
with patch(
"sagemaker.experiments.run._Experiment._load_or_create",
"sagemaker.experiments.run.Experiment._load_or_create",
MagicMock(
return_value=_Experiment(
return_value=Experiment(
experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session
)
),
Expand Down
38 changes: 19 additions & 19 deletions tests/unit/sagemaker/experiments/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def datetime_obj():
def test_load(sagemaker_session):
client = sagemaker_session.sagemaker_client
client.describe_experiment.return_value = {"Description": "description-value"}
experiment_obj = experiment._Experiment.load(
experiment_obj = experiment.Experiment.load(
experiment_name="name-value", sagemaker_session=sagemaker_session
)
assert experiment_obj.experiment_name == "name-value"
Expand All @@ -44,7 +44,7 @@ def test_load(sagemaker_session):
def test_create(sagemaker_session):
client = sagemaker_session.sagemaker_client
client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
experiment_obj = experiment._Experiment.create(
experiment_obj = experiment.Experiment.create(
experiment_name="name-value", sagemaker_session=sagemaker_session
)
assert experiment_obj.experiment_name == "name-value"
Expand All @@ -55,7 +55,7 @@ def test_create_with_tags(sagemaker_session):
client = sagemaker_session.sagemaker_client
client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
tags = [{"Key": "foo", "Value": "bar"}]
experiment_obj = experiment._Experiment.create(
experiment_obj = experiment.Experiment.create(
experiment_name="name-value", sagemaker_session=sagemaker_session, tags=tags
)
assert experiment_obj.experiment_name == "name-value"
Expand All @@ -64,22 +64,22 @@ def test_create_with_tags(sagemaker_session):

def test_save(sagemaker_session):
client = sagemaker_session.sagemaker_client
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
client.update_experiment.return_value = {}
obj.save()
client.update_experiment.assert_called_with(ExperimentName="foo", Description="bar")


def test_delete(sagemaker_session):
client = sagemaker_session.sagemaker_client
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
client.delete_experiment.return_value = {}
obj.delete()
client.delete_experiment.assert_called_with(ExperimentName="foo")


@patch("sagemaker.experiments.experiment._Experiment.load")
@patch("sagemaker.experiments.experiment._Experiment.create")
@patch("sagemaker.experiments.experiment.Experiment.load")
@patch("sagemaker.experiments.experiment.Experiment.create")
def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session):
exp_name = "exp_name"
exists_error = botocore.exceptions.ClientError(
Expand All @@ -92,7 +92,7 @@ def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session):
operation_name="foo",
)
mock_create.side_effect = exists_error
experiment._Experiment._load_or_create(
experiment.Experiment._load_or_create(
experiment_name=exp_name, sagemaker_session=sagemaker_session
)
mock_create.assert_called_once_with(
Expand All @@ -105,12 +105,12 @@ def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session):
mock_load.assert_called_once_with(exp_name, sagemaker_session)


@patch("sagemaker.experiments.experiment._Experiment.load")
@patch("sagemaker.experiments.experiment._Experiment.create")
@patch("sagemaker.experiments.experiment.Experiment.load")
@patch("sagemaker.experiments.experiment.Experiment.create")
def test_load_or_create_when_not_exist(mock_create, mock_load):
sagemaker_session = Session()
exp_name = "exp_name"
experiment._Experiment._load_or_create(
experiment.Experiment._load_or_create(
experiment_name=exp_name, sagemaker_session=sagemaker_session
)
mock_create.assert_called_once_with(
Expand All @@ -125,12 +125,12 @@ def test_load_or_create_when_not_exist(mock_create, mock_load):

def test_list_trials_empty(sagemaker_session):
sagemaker_session.sagemaker_client.list_trials.return_value = {"TrialSummaries": []}
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
assert list(experiment_obj.list_trials()) == []


def test_list_trials_single(sagemaker_session, datetime_obj):
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
sagemaker_session.sagemaker_client.list_trials.return_value = {
"TrialSummaries": [
{"Name": "trial-foo", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
Expand All @@ -143,7 +143,7 @@ def test_list_trials_single(sagemaker_session, datetime_obj):


def test_list_trials_two_values(sagemaker_session, datetime_obj):
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
sagemaker_session.sagemaker_client.list_trials.return_value = {
"TrialSummaries": [
{"Name": "trial-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
Expand All @@ -162,7 +162,7 @@ def test_list_trials_two_values(sagemaker_session, datetime_obj):


def test_next_token(sagemaker_session, datetime_obj):
experiment_obj = experiment._Experiment(sagemaker_session)
experiment_obj = experiment.Experiment(sagemaker_session)
client = sagemaker_session.sagemaker_client
client.list_trials.side_effect = [
{
Expand Down Expand Up @@ -211,7 +211,7 @@ def test_list_trials_call_args(sagemaker_session):
client = sagemaker_session.sagemaker_client
created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
client.list_trials.return_value = {}
assert [] == list(
experiment_obj.list_trials(created_after=created_after, created_before=created_before)
Expand All @@ -220,15 +220,15 @@ def test_list_trials_call_args(sagemaker_session):


def test_delete_all_with_incorrect_action_name(sagemaker_session):
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
with pytest.raises(ValueError) as err:
obj._delete_all(action="abc")

assert "Must confirm with string '--force'" in str(err)


def test_delete_all(sagemaker_session):
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
client = sagemaker_session.sagemaker_client
client.list_trials.return_value = {
"TrialSummaries": [
Expand Down Expand Up @@ -310,7 +310,7 @@ def test_delete_all(sagemaker_session):


def test_delete_all_fail(sagemaker_session):
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
sagemaker_session.sagemaker_client.list_trials.side_effect = Exception
with pytest.raises(Exception) as e:
obj._delete_all(action="--force")
Expand Down
Loading