Skip to content

Commit dd9c631

Browse files
committed
feature: expose Experiment class publicly
1 parent 162e922 commit dd9c631

File tree

9 files changed

+58
-48
lines changed

9 files changed

+58
-48
lines changed

doc/experiments/sagemaker.experiments.rst

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,19 @@ Run
1111

1212
.. automethod:: sagemaker.experiments.list_runs
1313

14+
Experiment
15+
-------------
16+
17+
.. autoclass:: sagemaker.experiments.Experiment
18+
:members:
19+
20+
Other
21+
-------------
22+
1423
.. autoclass:: sagemaker.experiments.SortByType
1524
:members:
1625
:undoc-members:
1726

1827
.. autoclass:: sagemaker.experiments.SortOrderType
1928
:members:
20-
:undoc-members:
29+
:undoc-members:

src/sagemaker/experiments/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.experiments.run import Run # noqa: F401
17+
from sagemaker.experiments.experiment import Experiment # noqa: F401
1718
from sagemaker.experiments.run import load_run # noqa: F401
1819
from sagemaker.experiments.run import list_runs # noqa: F401
1920
from sagemaker.experiments.run import SortOrderType # noqa: F401

src/sagemaker/experiments/experiment.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from sagemaker.experiments.trial_component import _TrialComponent
2323

2424

25-
class _Experiment(_base_types.Record):
25+
class Experiment(_base_types.Record):
2626
"""An Amazon SageMaker experiment, which is a collection of related trials.
2727
2828
New experiments are created by calling `experiments.experiment._Experiment.create`.
@@ -157,7 +157,7 @@ def _load_or_create(
157157
experiments.experiment._Experiment: A SageMaker `_Experiment` object
158158
"""
159159
try:
160-
experiment = _Experiment.create(
160+
experiment = Experiment.create(
161161
experiment_name=experiment_name,
162162
display_name=display_name,
163163
description=description,
@@ -170,7 +170,7 @@ def _load_or_create(
170170
if not (error_code == "ValidationException" and "already exists" in error_message):
171171
raise ce
172172
# already exists
173-
experiment = _Experiment.load(experiment_name, sagemaker_session)
173+
experiment = Experiment.load(experiment_name, sagemaker_session)
174174
return experiment
175175

176176
def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None):

src/sagemaker/experiments/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
from sagemaker.experiments._environment import _RunEnvironment
3434
from sagemaker.experiments._run_context import _RunContext
35-
from sagemaker.experiments.experiment import _Experiment
35+
from sagemaker.experiments.experiment import Experiment
3636
from sagemaker.experiments._metrics import _MetricsManager
3737
from sagemaker.experiments.trial import _Trial
3838
from sagemaker.experiments.trial_component import _TrialComponent
@@ -166,7 +166,7 @@ def __init__(
166166
)
167167
self.run_group_name = Run._generate_trial_name(self.experiment_name)
168168

169-
self._experiment = _Experiment._load_or_create(
169+
self._experiment = Experiment._load_or_create(
170170
experiment_name=self.experiment_name,
171171
display_name=experiment_display_name,
172172
tags=tags,

tests/integ/sagemaker/experiments/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from contextlib import contextmanager
1616

1717
from sagemaker import utils
18-
from sagemaker.experiments.experiment import _Experiment
18+
from sagemaker.experiments.experiment import Experiment
1919

2020
EXP_INTEG_TEST_NAME_PREFIX = "experiments-integ"
2121

@@ -38,5 +38,5 @@ def cleanup_exp_resources(exp_names, sagemaker_session):
3838
yield
3939
finally:
4040
for exp_name in exp_names:
41-
exp = _Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session)
41+
exp = Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session)
4242
exp._delete_all(action="--force")

tests/unit/sagemaker/experiments/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import pytest
1919

2020
from sagemaker import Session
21-
from sagemaker.experiments.experiment import _Experiment
21+
from sagemaker.experiments.experiment import Experiment
2222
from sagemaker.experiments.run import RUN_NAME_BASE
2323
from sagemaker.experiments import Run
2424
from tests.unit.sagemaker.experiments.helpers import (

tests/unit/sagemaker/experiments/test_experiment.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def datetime_obj():
3232
def test_load(sagemaker_session):
3333
client = sagemaker_session.sagemaker_client
3434
client.describe_experiment.return_value = {"Description": "description-value"}
35-
experiment_obj = experiment._Experiment.load(
35+
experiment_obj = experiment.Experiment.load(
3636
experiment_name="name-value", sagemaker_session=sagemaker_session
3737
)
3838
assert experiment_obj.experiment_name == "name-value"
@@ -44,7 +44,7 @@ def test_load(sagemaker_session):
4444
def test_create(sagemaker_session):
4545
client = sagemaker_session.sagemaker_client
4646
client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
47-
experiment_obj = experiment._Experiment.create(
47+
experiment_obj = experiment.Experiment.create(
4848
experiment_name="name-value", sagemaker_session=sagemaker_session
4949
)
5050
assert experiment_obj.experiment_name == "name-value"
@@ -55,7 +55,7 @@ def test_create_with_tags(sagemaker_session):
5555
client = sagemaker_session.sagemaker_client
5656
client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
5757
tags = [{"Key": "foo", "Value": "bar"}]
58-
experiment_obj = experiment._Experiment.create(
58+
experiment_obj = experiment.Experiment.create(
5959
experiment_name="name-value", sagemaker_session=sagemaker_session, tags=tags
6060
)
6161
assert experiment_obj.experiment_name == "name-value"
@@ -64,22 +64,22 @@ def test_create_with_tags(sagemaker_session):
6464

6565
def test_save(sagemaker_session):
6666
client = sagemaker_session.sagemaker_client
67-
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
67+
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
6868
client.update_experiment.return_value = {}
6969
obj.save()
7070
client.update_experiment.assert_called_with(ExperimentName="foo", Description="bar")
7171

7272

7373
def test_delete(sagemaker_session):
7474
client = sagemaker_session.sagemaker_client
75-
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
75+
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
7676
client.delete_experiment.return_value = {}
7777
obj.delete()
7878
client.delete_experiment.assert_called_with(ExperimentName="foo")
7979

8080

81-
@patch("sagemaker.experiments.experiment._Experiment.load")
82-
@patch("sagemaker.experiments.experiment._Experiment.create")
81+
@patch("sagemaker.experiments.experiment.Experiment.load")
82+
@patch("sagemaker.experiments.experiment.Experiment.create")
8383
def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session):
8484
exp_name = "exp_name"
8585
exists_error = botocore.exceptions.ClientError(
@@ -92,7 +92,7 @@ def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session):
9292
operation_name="foo",
9393
)
9494
mock_create.side_effect = exists_error
95-
experiment._Experiment._load_or_create(
95+
experiment.Experiment._load_or_create(
9696
experiment_name=exp_name, sagemaker_session=sagemaker_session
9797
)
9898
mock_create.assert_called_once_with(
@@ -105,12 +105,12 @@ def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session):
105105
mock_load.assert_called_once_with(exp_name, sagemaker_session)
106106

107107

108-
@patch("sagemaker.experiments.experiment._Experiment.load")
109-
@patch("sagemaker.experiments.experiment._Experiment.create")
108+
@patch("sagemaker.experiments.experiment.Experiment.load")
109+
@patch("sagemaker.experiments.experiment.Experiment.create")
110110
def test_load_or_create_when_not_exist(mock_create, mock_load):
111111
sagemaker_session = Session()
112112
exp_name = "exp_name"
113-
experiment._Experiment._load_or_create(
113+
experiment.Experiment._load_or_create(
114114
experiment_name=exp_name, sagemaker_session=sagemaker_session
115115
)
116116
mock_create.assert_called_once_with(
@@ -125,12 +125,12 @@ def test_load_or_create_when_not_exist(mock_create, mock_load):
125125

126126
def test_list_trials_empty(sagemaker_session):
127127
sagemaker_session.sagemaker_client.list_trials.return_value = {"TrialSummaries": []}
128-
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
128+
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
129129
assert list(experiment_obj.list_trials()) == []
130130

131131

132132
def test_list_trials_single(sagemaker_session, datetime_obj):
133-
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
133+
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
134134
sagemaker_session.sagemaker_client.list_trials.return_value = {
135135
"TrialSummaries": [
136136
{"Name": "trial-foo", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
@@ -143,7 +143,7 @@ def test_list_trials_single(sagemaker_session, datetime_obj):
143143

144144

145145
def test_list_trials_two_values(sagemaker_session, datetime_obj):
146-
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
146+
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
147147
sagemaker_session.sagemaker_client.list_trials.return_value = {
148148
"TrialSummaries": [
149149
{"Name": "trial-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
@@ -162,7 +162,7 @@ def test_list_trials_two_values(sagemaker_session, datetime_obj):
162162

163163

164164
def test_next_token(sagemaker_session, datetime_obj):
165-
experiment_obj = experiment._Experiment(sagemaker_session)
165+
experiment_obj = experiment.Experiment(sagemaker_session)
166166
client = sagemaker_session.sagemaker_client
167167
client.list_trials.side_effect = [
168168
{
@@ -211,7 +211,7 @@ def test_list_trials_call_args(sagemaker_session):
211211
client = sagemaker_session.sagemaker_client
212212
created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
213213
created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
214-
experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
214+
experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session)
215215
client.list_trials.return_value = {}
216216
assert [] == list(
217217
experiment_obj.list_trials(created_after=created_after, created_before=created_before)
@@ -220,15 +220,15 @@ def test_list_trials_call_args(sagemaker_session):
220220

221221

222222
def test_delete_all_with_incorrect_action_name(sagemaker_session):
223-
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
223+
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
224224
with pytest.raises(ValueError) as err:
225225
obj._delete_all(action="abc")
226226

227227
assert "Must confirm with string '--force'" in str(err)
228228

229229

230230
def test_delete_all(sagemaker_session):
231-
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
231+
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
232232
client = sagemaker_session.sagemaker_client
233233
client.list_trials.return_value = {
234234
"TrialSummaries": [
@@ -310,7 +310,7 @@ def test_delete_all(sagemaker_session):
310310

311311

312312
def test_delete_all_fail(sagemaker_session):
313-
obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
313+
obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar")
314314
sagemaker_session.sagemaker_client.list_trials.side_effect = Exception
315315
with pytest.raises(Exception) as e:
316316
obj._delete_all(action="--force")

tests/unit/sagemaker/experiments/test_run.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
_TrialComponentStatusType,
2929
TrialComponentSearchResult,
3030
)
31-
from sagemaker.experiments.experiment import _Experiment
31+
from sagemaker.experiments.experiment import Experiment
3232
from sagemaker.experiments.run import (
3333
TRIAL_NAME_TEMPLATE,
3434
MAX_RUN_TC_ARTIFACTS_LEN,
@@ -52,8 +52,8 @@
5252

5353

5454
@patch(
55-
"sagemaker.experiments.run._Experiment._load_or_create",
56-
MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
55+
"sagemaker.experiments.run.Experiment._load_or_create",
56+
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
5757
)
5858
@patch(
5959
"sagemaker.experiments.run._Trial._load_or_create",
@@ -122,8 +122,8 @@ def test_run_init_name_length_exceed_limit(sagemaker_session):
122122

123123
@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
124124
@patch(
125-
"sagemaker.experiments.run._Experiment._load_or_create",
126-
MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
125+
"sagemaker.experiments.run.Experiment._load_or_create",
126+
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
127127
)
128128
@patch(
129129
"sagemaker.experiments.run._Trial._load_or_create",
@@ -213,8 +213,8 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak
213213

214214
@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
215215
@patch(
216-
"sagemaker.experiments.run._Experiment._load_or_create",
217-
MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
216+
"sagemaker.experiments.run.Experiment._load_or_create",
217+
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
218218
)
219219
@patch(
220220
"sagemaker.experiments.run._Trial._load_or_create",
@@ -259,8 +259,8 @@ def test_run_load_with_run_name_but_no_exp_name(sagemaker_session):
259259

260260

261261
@patch(
262-
"sagemaker.experiments.run._Experiment._load_or_create",
263-
MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
262+
"sagemaker.experiments.run.Experiment._load_or_create",
263+
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
264264
)
265265
@patch(
266266
"sagemaker.experiments.run._Trial._load_or_create",
@@ -300,8 +300,8 @@ def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
300300

301301

302302
@patch(
303-
"sagemaker.experiments.run._Experiment._load_or_create",
304-
MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
303+
"sagemaker.experiments.run.Experiment._load_or_create",
304+
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
305305
)
306306
@patch(
307307
"sagemaker.experiments.run._Trial._load_or_create",
@@ -712,8 +712,8 @@ def test_log_roc_curve_invalid_input(run_obj):
712712

713713

714714
@patch(
715-
"sagemaker.experiments.run._Experiment._load_or_create",
716-
MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
715+
"sagemaker.experiments.run.Experiment._load_or_create",
716+
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
717717
)
718718
@patch(
719719
"sagemaker.experiments.run._Trial._load_or_create",
@@ -829,8 +829,8 @@ def test_list_empty(mock_tc_list, sagemaker_session):
829829

830830

831831
@patch(
832-
"sagemaker.experiments.run._Experiment._load_or_create",
833-
MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
832+
"sagemaker.experiments.run.Experiment._load_or_create",
833+
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
834834
)
835835
@patch(
836836
"sagemaker.experiments.run._Trial._load_or_create",

tests/unit/sagemaker/experiments/test_run_context.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from sagemaker import Processor
2020
from sagemaker.estimator import Estimator, _TrainingJob
21-
from sagemaker.experiments.experiment import _Experiment
21+
from sagemaker.experiments.experiment import Experiment
2222
from sagemaker.experiments.run import _RunContext
2323
from sagemaker.experiments import load_run, Run
2424
from sagemaker.experiments.trial import _Trial
@@ -62,7 +62,7 @@ def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker
6262

6363
@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
6464
@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
65-
@patch("sagemaker.experiments.run._Experiment._load_or_create")
65+
@patch("sagemaker.experiments.run.Experiment._load_or_create")
6666
@patch("sagemaker.experiments.run._Trial._load_or_create")
6767
@patch("sagemaker.experiments.run._TrialComponent._load_or_create")
6868
@patch.object(_TrainingJob, "start_new")
@@ -249,8 +249,8 @@ def test_nested_run_init_context_on_same_run_object(run_obj, sagemaker_session):
249249

250250

251251
@patch(
252-
"sagemaker.experiments.run._Experiment._load_or_create",
253-
MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
252+
"sagemaker.experiments.run.Experiment._load_or_create",
253+
MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)),
254254
)
255255
@patch(
256256
"sagemaker.experiments.run._Trial._load_or_create",
@@ -305,7 +305,7 @@ def test_nested_run_load_context(run_obj, sagemaker_session):
305305

306306

307307
@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
308-
@patch("sagemaker.experiments.run._Experiment._load_or_create")
308+
@patch("sagemaker.experiments.run.Experiment._load_or_create")
309309
@patch("sagemaker.experiments.run._Trial._load_or_create")
310310
@patch("sagemaker.experiments.run._TrialComponent._load_or_create")
311311
def test_run_init_under_run_load_context(

0 commit comments

Comments
 (0)