Skip to content

Commit 9079f37

Browse files
authored
feature: Expose Experiment class publicly (aws#3794)
1 parent 38a3a2d commit 9079f37

File tree

14 files changed

+80
-72
lines changed

14 files changed

+80
-72
lines changed

doc/experiments/sagemaker.experiments.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@ Run
1111

1212
.. automethod:: sagemaker.experiments.list_runs
1313

14+
Experiment
15+
-------------
16+
17+
.. autoclass:: sagemaker.experiments.Experiment
18+
:members:
19+
20+
Other
21+
-------------
22+
1423
.. autoclass:: sagemaker.experiments.SortByType
1524
:members:
1625
:undoc-members:

src/sagemaker/experiments/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.experiments.run import Run # noqa: F401
17+
from sagemaker.experiments.experiment import Experiment # noqa: F401
1718
from sagemaker.experiments.run import load_run # noqa: F401
1819
from sagemaker.experiments.run import list_runs # noqa: F401
1920
from sagemaker.experiments.run import SortOrderType # noqa: F401

src/sagemaker/experiments/experiment.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
from sagemaker.experiments.trial_component import _TrialComponent
2323

2424

25-
class _Experiment(_base_types.Record):
25+
class Experiment(_base_types.Record):
2626
"""An Amazon SageMaker experiment, which is a collection of related trials.
2727
28-
New experiments are created by calling `experiments.experiment._Experiment.create`.
29-
Existing experiments can be reloaded by calling `experiments.experiment._Experiment.load`.
28+
New experiments are created by calling `experiments.experiment.Experiment.create`.
29+
Existing experiments can be reloaded by calling `experiments.experiment.Experiment.load`.
3030
3131
Attributes:
3232
experiment_name (str): The name of the experiment. The name must be unique
@@ -73,7 +73,7 @@ def delete(self):
7373

7474
@classmethod
7575
def load(cls, experiment_name, sagemaker_session=None):
76-
"""Load an existing experiment and return an `_Experiment` object representing it.
76+
"""Load an existing experiment and return an `Experiment` object representing it.
7777
7878
Args:
7979
experiment_name: (str): Name of the experiment
@@ -83,7 +83,7 @@ def load(cls, experiment_name, sagemaker_session=None):
8383
default AWS configuration chain.
8484
8585
Returns:
86-
experiments.experiment._Experiment: A SageMaker `_Experiment` object
86+
experiments.experiment.Experiment: A SageMaker `Experiment` object
8787
"""
8888
return cls._construct(
8989
cls._boto_load_method,
@@ -100,7 +100,7 @@ def create(
100100
tags=None,
101101
sagemaker_session=None,
102102
):
103-
"""Create a new experiment in SageMaker and return an `_Experiment` object.
103+
"""Create a new experiment in SageMaker and return an `Experiment` object.
104104
105105
Args:
106106
experiment_name: (str): Name of the experiment. Must be unique. Required.
@@ -115,7 +115,7 @@ def create(
115115
(default: None).
116116
117117
Returns:
118-
experiments.experiment._Experiment: A SageMaker `_Experiment` object
118+
experiments.experiment.Experiment: A SageMaker `Experiment` object
119119
"""
120120
return cls._construct(
121121
cls._boto_create_method,
@@ -154,10 +154,10 @@ def _load_or_create(
154154
exist and a new experiment has to be created.
155155
156156
Returns:
157-
experiments.experiment._Experiment: A SageMaker `_Experiment` object
157+
experiments.experiment.Experiment: A SageMaker `Experiment` object
158158
"""
159159
try:
160-
experiment = _Experiment.create(
160+
experiment = Experiment.create(
161161
experiment_name=experiment_name,
162162
display_name=display_name,
163163
description=description,
@@ -170,7 +170,7 @@ def _load_or_create(
170170
if not (error_code == "ValidationException" and "already exists" in error_message):
171171
raise ce
172172
# already exists
173-
experiment = _Experiment.load(experiment_name, sagemaker_session)
173+
experiment = Experiment.load(experiment_name, sagemaker_session)
174174
return experiment
175175

176176
def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None):

src/sagemaker/experiments/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
from sagemaker.experiments._environment import _RunEnvironment
3434
from sagemaker.experiments._run_context import _RunContext
35-
from sagemaker.experiments.experiment import _Experiment
35+
from sagemaker.experiments.experiment import Experiment
3636
from sagemaker.experiments._metrics import _MetricsManager
3737
from sagemaker.experiments.trial import _Trial
3838
from sagemaker.experiments.trial_component import _TrialComponent
@@ -166,7 +166,7 @@ def __init__(
166166
)
167167
self.run_group_name = Run._generate_trial_name(self.experiment_name)
168168

169-
self._experiment = _Experiment._load_or_create(
169+
self._experiment = Experiment._load_or_create(
170170
experiment_name=self.experiment_name,
171171
display_name=experiment_display_name,
172172
tags=tags,

tests/integ/sagemaker/experiments/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def run_obj(sagemaker_session):
4646
yield run
4747
time.sleep(0.5)
4848
finally:
49-
exp = experiment._Experiment.load(
49+
exp = experiment.Experiment.load(
5050
experiment_name=run.experiment_name, sagemaker_session=sagemaker_session
5151
)
5252
exp._delete_all(action="--force")
@@ -71,7 +71,7 @@ def experiment_obj(sagemaker_session):
7171
description = "{}-{}".format("description", str(uuid.uuid4()))
7272
boto3.set_stream_logger("", logging.INFO)
7373
experiment_name = name()
74-
experiment_obj = experiment._Experiment.create(
74+
experiment_obj = experiment.Experiment.create(
7575
experiment_name=experiment_name,
7676
description=description,
7777
sagemaker_session=sagemaker_session,

tests/integ/sagemaker/experiments/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from contextlib import contextmanager
1616

1717
from sagemaker import utils
18-
from sagemaker.experiments.experiment import _Experiment
18+
from sagemaker.experiments.experiment import Experiment
1919

2020
EXP_INTEG_TEST_NAME_PREFIX = "experiments-integ"
2121

@@ -38,5 +38,5 @@ def cleanup_exp_resources(exp_names, sagemaker_session):
3838
yield
3939
finally:
4040
for exp_name in exp_names:
41-
exp = _Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session)
41+
exp = Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session)
4242
exp._delete_all(action="--force")

tests/integ/sagemaker/experiments/test_experiment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_save(experiment_obj):
4040

4141

4242
def test_save_load(experiment_obj, sagemaker_session):
43-
experiment_obj_two = experiment._Experiment.load(
43+
experiment_obj_two = experiment.Experiment.load(
4444
experiment_name=experiment_obj.experiment_name, sagemaker_session=sagemaker_session
4545
)
4646
assert experiment_obj.experiment_name == experiment_obj_two.experiment_name
@@ -49,7 +49,7 @@ def test_save_load(experiment_obj, sagemaker_session):
4949
experiment_obj.description = name()
5050
experiment_obj.display_name = name()
5151
experiment_obj.save()
52-
experiment_obj_three = experiment._Experiment.load(
52+
experiment_obj_three = experiment.Experiment.load(
5353
experiment_name=experiment_obj.experiment_name, sagemaker_session=sagemaker_session
5454
)
5555
assert experiment_obj.description == experiment_obj_three.description

tests/unit/sagemaker/experiments/conftest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pytest
1919

2020
from sagemaker import Session
21-
from sagemaker.experiments.experiment import _Experiment
21+
from sagemaker.experiments.experiment import Experiment
2222
from sagemaker.experiments.run import RUN_NAME_BASE
2323
from sagemaker.experiments import Run
2424
from tests.unit.sagemaker.experiments.helpers import (
@@ -57,9 +57,9 @@ def run_obj(sagemaker_session):
5757
client.update_trial_component.return_value = {}
5858
client.associate_trial_component.return_value = {}
5959
with patch(
60-
"sagemaker.experiments.run._Experiment._load_or_create",
60+
"sagemaker.experiments.run.Experiment._load_or_create",
6161
MagicMock(
62-
return_value=_Experiment(
62+
return_value=Experiment(
6363
experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session
6464
)
6565
),

tests/unit/sagemaker/experiments/test_experiment.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def datetime_obj():
3232
def test_load(sagemaker_session):
3333
client = sagemaker_session.sagemaker_client
3434
client.describe_experiment.return_value = {"Description": "description-value"}
35-
experiment_obj = experiment._Experiment.load(
35+
experiment_obj = experiment.Experiment.load(
3636
experiment_name="name-value", sagemaker_session=sagemaker_session
3737
)
3838
assert experiment_obj.experiment_name == "name-value"
@@ -44,7 +44,7 @@ def test_load(sagemaker_session):
4444
def test_create(sagemaker_session):
4545
client = sagemaker_session.sagemaker_client
4646
client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
47-
experiment_obj = experiment._Experiment.create(
47+
experiment_obj = experiment.Experiment.create(
4848
experiment_name="name-value", sagemaker_session=sagemaker_session
4949
)
5050
assert experiment_obj.experiment_name == "name-value"
@@ -55,7 +55,7 @@ def test_create_with_tags(sagemaker_session):
5555
client = sagemaker_session.sagemaker_client
5656
client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
5757
tags = [{"Key": "foo", "Value": "bar"}]
58-
experiment_obj = experiment._Experiment.create(
58+
experiment_obj = experiment.Experiment.create(
5959
experiment_name="name-value", sagemaker_session=sagemaker_session, tags=tags
6060
)
6161
assert experiment_obj.experiment_name == "name-value"
@@ -64,22 +64,22 @@ def test_create_with_tags(sagemaker_session):
6464

6565
def test_save(sagemaker_session):
6666
client = sagemaker_session.sagemaker_client
67-
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
67+
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
6868
client.update_experiment.return_value = {}
6969
obj.save()
7070
client.update_experiment.assert_called_with(ExperimentName="foo", Description="bar")
7171

7272

7373
def test_delete(sagemaker_session):
7474
client = sagemaker_session.sagemaker_client
75-
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
75+
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
7676
client.delete_experiment.return_value = {}
7777
obj.delete()
7878
client.delete_experiment.assert_called_with(ExperimentName="foo")
7979

8080

81-
@patch("sagemaker.experiments.experiment._Experiment.load")
82-
@patch("sagemaker.experiments.experiment._Experiment.create")
81+
@patch("sagemaker.experiments.experiment.Experiment.load")
82+
@patch("sagemaker.experiments.experiment.Experiment.create")
8383
def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session):
8484
exp_name = "exp_name"
8585
exists_error = botocore.exceptions.ClientError(
@@ -92,7 +92,7 @@ def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session):
9292
operation_name="foo",
9393
)
9494
mock_create.side_effect = exists_error
95-
experiment._Experiment._load_or_create(
95+
experiment.Experiment._load_or_create(
9696
experiment_name=exp_name, sagemaker_session=sagemaker_session
9797
)
9898
mock_create.assert_called_once_with(
@@ -105,12 +105,12 @@ def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session):
105105
mock_load.assert_called_once_with(exp_name, sagemaker_session)
106106

107107

108-
@patch("sagemaker.experiments.experiment._Experiment.load")
109-
@patch("sagemaker.experiments.experiment._Experiment.create")
108+
@patch("sagemaker.experiments.experiment.Experiment.load")
109+
@patch("sagemaker.experiments.experiment.Experiment.create")
110110
def test_load_or_create_when_not_exist(mock_create, mock_load):
111111
sagemaker_session = Session()
112112
exp_name = "exp_name"
113-
experiment._Experiment._load_or_create(
113+
experiment.Experiment._load_or_create(
114114
experiment_name=exp_name, sagemaker_session=sagemaker_session
115115
)
116116
mock_create.assert_called_once_with(
@@ -125,12 +125,12 @@ def test_load_or_create_when_not_exist(mock_create, mock_load):
125125

126126
def test_list_trials_empty(sagemaker_session):
127127
sagemaker_session.sagemaker_client.list_trials.return_value = {"TrialSummaries": []}
128-
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
128+
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
129129
assert list(experiment_obj.list_trials()) == []
130130

131131

132132
def test_list_trials_single(sagemaker_session, datetime_obj):
133-
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
133+
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
134134
sagemaker_session.sagemaker_client.list_trials.return_value = {
135135
"TrialSummaries": [
136136
{"Name": "trial-foo", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
@@ -143,7 +143,7 @@ def test_list_trials_single(sagemaker_session, datetime_obj):
143143

144144

145145
def test_list_trials_two_values(sagemaker_session, datetime_obj):
146-
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
146+
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
147147
sagemaker_session.sagemaker_client.list_trials.return_value = {
148148
"TrialSummaries": [
149149
{"Name": "trial-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
@@ -162,7 +162,7 @@ def test_list_trials_two_values(sagemaker_session, datetime_obj):
162162

163163

164164
def test_next_token(sagemaker_session, datetime_obj):
165-
experiment_obj = experiment._Experiment(sagemaker_session)
165+
experiment_obj = experiment.Experiment(sagemaker_session)
166166
client = sagemaker_session.sagemaker_client
167167
client.list_trials.side_effect = [
168168
{
@@ -211,7 +211,7 @@ def test_list_trials_call_args(sagemaker_session):
211211
client = sagemaker_session.sagemaker_client
212212
created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
213213
created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
214-
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
214+
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
215215
client.list_trials.return_value = {}
216216
assert [] == list(
217217
experiment_obj.list_trials(created_after=created_after, created_before=created_before)
@@ -220,15 +220,15 @@ def test_list_trials_call_args(sagemaker_session):
220220

221221

222222
def test_delete_all_with_incorrect_action_name(sagemaker_session):
223-
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
223+
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
224224
with pytest.raises(ValueError) as err:
225225
obj._delete_all(action="abc")
226226

227227
assert "Must confirm with string '--force'" in str(err)
228228

229229

230230
def test_delete_all(sagemaker_session):
231-
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
231+
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
232232
client = sagemaker_session.sagemaker_client
233233
client.list_trials.return_value = {
234234
"TrialSummaries": [
@@ -310,7 +310,7 @@ def test_delete_all(sagemaker_session):
310310

311311

312312
def test_delete_all_fail(sagemaker_session):
313-
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
313+
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
314314
sagemaker_session.sagemaker_client.list_trials.side_effect = Exception
315315
with pytest.raises(Exception) as e:
316316
obj._delete_all(action="--force")

0 commit comments

Comments
 (0)