diff --git a/doc/experiments/sagemaker.experiments.rst b/doc/experiments/sagemaker.experiments.rst index 045b868f19..148dd00284 100644 --- a/doc/experiments/sagemaker.experiments.rst +++ b/doc/experiments/sagemaker.experiments.rst @@ -11,6 +11,15 @@ Run .. automethod:: sagemaker.experiments.list_runs +Experiment +------------- + +.. autoclass:: sagemaker.experiments.Experiment + :members: + +Other +------------- + .. autoclass:: sagemaker.experiments.SortByType :members: :undoc-members: diff --git a/src/sagemaker/experiments/__init__.py b/src/sagemaker/experiments/__init__.py index b87656b1ab..ae9616ae3a 100644 --- a/src/sagemaker/experiments/__init__.py +++ b/src/sagemaker/experiments/__init__.py @@ -14,6 +14,7 @@ from __future__ import absolute_import from sagemaker.experiments.run import Run # noqa: F401 +from sagemaker.experiments.experiment import Experiment # noqa: F401 from sagemaker.experiments.run import load_run # noqa: F401 from sagemaker.experiments.run import list_runs # noqa: F401 from sagemaker.experiments.run import SortOrderType # noqa: F401 diff --git a/src/sagemaker/experiments/experiment.py b/src/sagemaker/experiments/experiment.py index 824734d294..584fbed27e 100644 --- a/src/sagemaker/experiments/experiment.py +++ b/src/sagemaker/experiments/experiment.py @@ -22,11 +22,11 @@ from sagemaker.experiments.trial_component import _TrialComponent -class _Experiment(_base_types.Record): +class Experiment(_base_types.Record): """An Amazon SageMaker experiment, which is a collection of related trials. - New experiments are created by calling `experiments.experiment._Experiment.create`. - Existing experiments can be reloaded by calling `experiments.experiment._Experiment.load`. + New experiments are created by calling `experiments.experiment.Experiment.create`. + Existing experiments can be reloaded by calling `experiments.experiment.Experiment.load`. Attributes: experiment_name (str): The name of the experiment. The name must be unique @@ -73,7 +73,7 @@ def delete(self): @classmethod def load(cls, experiment_name, sagemaker_session=None): - """Load an existing experiment and return an `_Experiment` object representing it. + """Load an existing experiment and return an `Experiment` object representing it. Args: experiment_name: (str): Name of the experiment @@ -83,7 +83,7 @@ def load(cls, experiment_name, sagemaker_session=None): default AWS configuration chain. Returns: - experiments.experiment._Experiment: A SageMaker `_Experiment` object + experiments.experiment.Experiment: A SageMaker `Experiment` object """ return cls._construct( cls._boto_load_method, @@ -100,7 +100,7 @@ def create( tags=None, sagemaker_session=None, ): - """Create a new experiment in SageMaker and return an `_Experiment` object. + """Create a new experiment in SageMaker and return an `Experiment` object. Args: experiment_name: (str): Name of the experiment. Must be unique. Required. @@ -115,7 +115,7 @@ def create( (default: None). Returns: - experiments.experiment._Experiment: A SageMaker `_Experiment` object + experiments.experiment.Experiment: A SageMaker `Experiment` object """ return cls._construct( cls._boto_create_method, @@ -154,10 +154,10 @@ def _load_or_create( exist and a new experiment has to be created. Returns: - experiments.experiment._Experiment: A SageMaker `_Experiment` object + experiments.experiment.Experiment: A SageMaker `Experiment` object """ try: - experiment = _Experiment.create( + experiment = Experiment.create( experiment_name=experiment_name, display_name=display_name, description=description, @@ -170,7 +170,7 @@ def _load_or_create( if not (error_code == "ValidationException" and "already exists" in error_message): raise ce # already exists - experiment = _Experiment.load(experiment_name, sagemaker_session) + experiment = Experiment.load(experiment_name, sagemaker_session) return experiment def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None): diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py index 07b7080ea3..6202de858c 100644 --- a/src/sagemaker/experiments/run.py +++ b/src/sagemaker/experiments/run.py @@ -32,7 +32,7 @@ ) from sagemaker.experiments._environment import _RunEnvironment from sagemaker.experiments._run_context import _RunContext -from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.experiment import Experiment from sagemaker.experiments._metrics import _MetricsManager from sagemaker.experiments.trial import _Trial from sagemaker.experiments.trial_component import _TrialComponent @@ -166,7 +166,7 @@ def __init__( ) self.run_group_name = Run._generate_trial_name(self.experiment_name) - self._experiment = _Experiment._load_or_create( + self._experiment = Experiment._load_or_create( experiment_name=self.experiment_name, display_name=experiment_display_name, tags=tags, diff --git a/tests/integ/sagemaker/experiments/conftest.py b/tests/integ/sagemaker/experiments/conftest.py index ca40a3ba6d..95c0f3561d 100644 --- a/tests/integ/sagemaker/experiments/conftest.py +++ b/tests/integ/sagemaker/experiments/conftest.py @@ -46,7 +46,7 @@ def run_obj(sagemaker_session): yield run time.sleep(0.5) finally: - exp = experiment._Experiment.load( + exp = experiment.Experiment.load( experiment_name=run.experiment_name, sagemaker_session=sagemaker_session ) exp._delete_all(action="--force") @@ -71,7 +71,7 @@ def experiment_obj(sagemaker_session): description = "{}-{}".format("description", str(uuid.uuid4())) boto3.set_stream_logger("", logging.INFO) experiment_name = name() - experiment_obj = experiment._Experiment.create( + experiment_obj = experiment.Experiment.create( experiment_name=experiment_name, description=description, sagemaker_session=sagemaker_session, diff --git a/tests/integ/sagemaker/experiments/helpers.py b/tests/integ/sagemaker/experiments/helpers.py index b5e8064b08..9a22c3a30c 100644 --- a/tests/integ/sagemaker/experiments/helpers.py +++ b/tests/integ/sagemaker/experiments/helpers.py @@ -15,7 +15,7 @@ from contextlib import contextmanager from sagemaker import utils -from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.experiment import Experiment EXP_INTEG_TEST_NAME_PREFIX = "experiments-integ" @@ -38,5 +38,5 @@ def cleanup_exp_resources(exp_names, sagemaker_session): yield finally: for exp_name in exp_names: - exp = _Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session) + exp = Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session) exp._delete_all(action="--force") diff --git a/tests/integ/sagemaker/experiments/test_experiment.py b/tests/integ/sagemaker/experiments/test_experiment.py index ff7d5fac37..1a85de047f 100644 --- a/tests/integ/sagemaker/experiments/test_experiment.py +++ b/tests/integ/sagemaker/experiments/test_experiment.py @@ -40,7 +40,7 @@ def test_save(experiment_obj): def test_save_load(experiment_obj, sagemaker_session): - experiment_obj_two = experiment._Experiment.load( + experiment_obj_two = experiment.Experiment.load( experiment_name=experiment_obj.experiment_name, sagemaker_session=sagemaker_session ) assert experiment_obj.experiment_name == experiment_obj_two.experiment_name @@ -49,7 +49,7 @@ def test_save_load(experiment_obj, sagemaker_session): experiment_obj.description = name() experiment_obj.display_name = name() experiment_obj.save() - experiment_obj_three = experiment._Experiment.load( + experiment_obj_three = experiment.Experiment.load( experiment_name=experiment_obj.experiment_name, sagemaker_session=sagemaker_session ) assert experiment_obj.description == experiment_obj_three.description diff --git a/tests/unit/sagemaker/experiments/conftest.py b/tests/unit/sagemaker/experiments/conftest.py index 4d33ad759d..2fcd114a55 100644 --- a/tests/unit/sagemaker/experiments/conftest.py +++ b/tests/unit/sagemaker/experiments/conftest.py @@ -18,7 +18,7 @@ import pytest from sagemaker import Session -from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.experiment import Experiment from sagemaker.experiments.run import RUN_NAME_BASE from sagemaker.experiments import Run from tests.unit.sagemaker.experiments.helpers import ( @@ -57,9 +57,9 @@ def run_obj(sagemaker_session): client.update_trial_component.return_value = {} client.associate_trial_component.return_value = {} with patch( - "sagemaker.experiments.run._Experiment._load_or_create", + "sagemaker.experiments.run.Experiment._load_or_create", MagicMock( - return_value=_Experiment( + return_value=Experiment( experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session ) ), diff --git a/tests/unit/sagemaker/experiments/test_experiment.py b/tests/unit/sagemaker/experiments/test_experiment.py index e6cac54a92..225338f70b 100644 --- a/tests/unit/sagemaker/experiments/test_experiment.py +++ b/tests/unit/sagemaker/experiments/test_experiment.py @@ -32,7 +32,7 @@ def datetime_obj(): def test_load(sagemaker_session): client = sagemaker_session.sagemaker_client client.describe_experiment.return_value = {"Description": "description-value"} - experiment_obj = experiment._Experiment.load( + experiment_obj = experiment.Experiment.load( experiment_name="name-value", sagemaker_session=sagemaker_session ) assert experiment_obj.experiment_name == "name-value" @@ -44,7 +44,7 @@ def test_load(sagemaker_session): def test_create(sagemaker_session): client = sagemaker_session.sagemaker_client client.create_experiment.return_value = {"Arn": "arn:aws:1234"} - experiment_obj = experiment._Experiment.create( + experiment_obj = experiment.Experiment.create( experiment_name="name-value", sagemaker_session=sagemaker_session ) assert experiment_obj.experiment_name == "name-value" @@ -55,7 +55,7 @@ def test_create_with_tags(sagemaker_session): client = sagemaker_session.sagemaker_client client.create_experiment.return_value = {"Arn": "arn:aws:1234"} tags = [{"Key": "foo", "Value": "bar"}] - experiment_obj = experiment._Experiment.create( + experiment_obj = experiment.Experiment.create( experiment_name="name-value", sagemaker_session=sagemaker_session, tags=tags ) assert experiment_obj.experiment_name == "name-value" @@ -64,7 +64,7 @@ def test_create_with_tags(sagemaker_session): def test_save(sagemaker_session): client = sagemaker_session.sagemaker_client - obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar") client.update_experiment.return_value = {} obj.save() client.update_experiment.assert_called_with(ExperimentName="foo", Description="bar") @@ -72,14 +72,14 @@ def test_save(sagemaker_session): def test_delete(sagemaker_session): client = sagemaker_session.sagemaker_client - obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar") client.delete_experiment.return_value = {} obj.delete() client.delete_experiment.assert_called_with(ExperimentName="foo") -@patch("sagemaker.experiments.experiment._Experiment.load") -@patch("sagemaker.experiments.experiment._Experiment.create") +@patch("sagemaker.experiments.experiment.Experiment.load") +@patch("sagemaker.experiments.experiment.Experiment.create") def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session): exp_name = "exp_name" exists_error = botocore.exceptions.ClientError( @@ -92,7 +92,7 @@ def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session): operation_name="foo", ) mock_create.side_effect = exists_error - experiment._Experiment._load_or_create( + experiment.Experiment._load_or_create( experiment_name=exp_name, sagemaker_session=sagemaker_session ) mock_create.assert_called_once_with( @@ -105,12 +105,12 @@ def test_load_or_create_when_exist(mock_create, mock_load, sagemaker_session): mock_load.assert_called_once_with(exp_name, sagemaker_session) -@patch("sagemaker.experiments.experiment._Experiment.load") -@patch("sagemaker.experiments.experiment._Experiment.create") +@patch("sagemaker.experiments.experiment.Experiment.load") +@patch("sagemaker.experiments.experiment.Experiment.create") def test_load_or_create_when_not_exist(mock_create, mock_load): sagemaker_session = Session() exp_name = "exp_name" - experiment._Experiment._load_or_create( + experiment.Experiment._load_or_create( experiment_name=exp_name, sagemaker_session=sagemaker_session ) mock_create.assert_called_once_with( @@ -125,12 +125,12 @@ def test_load_or_create_when_not_exist(mock_create, mock_load): def test_list_trials_empty(sagemaker_session): sagemaker_session.sagemaker_client.list_trials.return_value = {"TrialSummaries": []} - experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session) assert list(experiment_obj.list_trials()) == [] def test_list_trials_single(sagemaker_session, datetime_obj): - experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session) sagemaker_session.sagemaker_client.list_trials.return_value = { "TrialSummaries": [ {"Name": "trial-foo", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj} @@ -143,7 +143,7 @@ def test_list_trials_single(sagemaker_session, datetime_obj): def test_list_trials_two_values(sagemaker_session, datetime_obj): - experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session) sagemaker_session.sagemaker_client.list_trials.return_value = { "TrialSummaries": [ {"Name": "trial-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}, @@ -162,7 +162,7 @@ def test_list_trials_two_values(sagemaker_session, datetime_obj): def test_next_token(sagemaker_session, datetime_obj): - experiment_obj = experiment._Experiment(sagemaker_session) + experiment_obj = experiment.Experiment(sagemaker_session) client = sagemaker_session.sagemaker_client client.list_trials.side_effect = [ { @@ -211,7 +211,7 @@ def test_list_trials_call_args(sagemaker_session): client = sagemaker_session.sagemaker_client created_before = datetime.datetime(1999, 10, 12, 0, 0, 0) created_after = datetime.datetime(1990, 10, 12, 0, 0, 0) - experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session) + experiment_obj = experiment.Experiment(sagemaker_session=sagemaker_session) client.list_trials.return_value = {} assert [] == list( experiment_obj.list_trials(created_after=created_after, created_before=created_before) @@ -220,7 +220,7 @@ def test_list_trials_call_args(sagemaker_session): def test_delete_all_with_incorrect_action_name(sagemaker_session): - obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar") with pytest.raises(ValueError) as err: obj._delete_all(action="abc") @@ -228,7 +228,7 @@ def test_delete_all_with_incorrect_action_name(sagemaker_session): def test_delete_all(sagemaker_session): - obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar") client = sagemaker_session.sagemaker_client client.list_trials.return_value = { "TrialSummaries": [ @@ -310,7 +310,7 @@ def test_delete_all(sagemaker_session): def test_delete_all_fail(sagemaker_session): - obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar") + obj = experiment.Experiment(sagemaker_session, experiment_name="foo", description="bar") sagemaker_session.sagemaker_client.list_trials.side_effect = Exception with pytest.raises(Exception) as e: obj._delete_all(action="--force") diff --git a/tests/unit/sagemaker/experiments/test_run.py b/tests/unit/sagemaker/experiments/test_run.py index a6495fc914..7f54fe8d6f 100644 --- a/tests/unit/sagemaker/experiments/test_run.py +++ b/tests/unit/sagemaker/experiments/test_run.py @@ -29,7 +29,7 @@ _TrialComponentStatusType, TrialComponentSearchResult, ) -from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.experiment import Experiment from sagemaker.experiments.run import ( TRIAL_NAME_TEMPLATE, MAX_RUN_TC_ARTIFACTS_LEN, @@ -55,8 +55,8 @@ @patch( - "sagemaker.experiments.run._Experiment._load_or_create", - MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), + "sagemaker.experiments.run.Experiment._load_or_create", + MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)), ) @patch( "sagemaker.experiments.run._Trial._load_or_create", @@ -125,8 +125,8 @@ def test_run_init_name_length_exceed_limit(sagemaker_session): @patch.object(_TrialComponent, "save", MagicMock(return_value=None)) @patch( - "sagemaker.experiments.run._Experiment._load_or_create", - MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), + "sagemaker.experiments.run.Experiment._load_or_create", + MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)), ) @patch( "sagemaker.experiments.run._Trial._load_or_create", @@ -216,8 +216,8 @@ def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemak @patch.object(_TrialComponent, "save", MagicMock(return_value=None)) @patch( - "sagemaker.experiments.run._Experiment._load_or_create", - MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), + "sagemaker.experiments.run.Experiment._load_or_create", + MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)), ) @patch( "sagemaker.experiments.run._Trial._load_or_create", @@ -262,8 +262,8 @@ def test_run_load_with_run_name_but_no_exp_name(sagemaker_session): @patch( - "sagemaker.experiments.run._Experiment._load_or_create", - MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), + "sagemaker.experiments.run.Experiment._load_or_create", + MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)), ) @patch( "sagemaker.experiments.run._Trial._load_or_create", @@ -303,8 +303,8 @@ def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session): @patch( - "sagemaker.experiments.run._Experiment._load_or_create", - MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), + "sagemaker.experiments.run.Experiment._load_or_create", + MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)), ) @patch( "sagemaker.experiments.run._Trial._load_or_create", @@ -344,8 +344,8 @@ def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session): @patch( - "sagemaker.experiments.run._Experiment._load_or_create", - MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), + "sagemaker.experiments.run.Experiment._load_or_create", + MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)), ) @patch( "sagemaker.experiments.run._Trial._load_or_create", @@ -743,8 +743,8 @@ def test_log_roc_curve_invalid_input(run_obj): @patch( - "sagemaker.experiments.run._Experiment._load_or_create", - MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), + "sagemaker.experiments.run.Experiment._load_or_create", + MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)), ) @patch( "sagemaker.experiments.run._Trial._load_or_create", @@ -860,8 +860,8 @@ def test_list_empty(mock_tc_list, sagemaker_session): @patch( - "sagemaker.experiments.run._Experiment._load_or_create", - MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), + "sagemaker.experiments.run.Experiment._load_or_create", + MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)), ) @patch( "sagemaker.experiments.run._Trial._load_or_create", diff --git a/tests/unit/sagemaker/experiments/test_run_context.py b/tests/unit/sagemaker/experiments/test_run_context.py index e63a1256a5..7026c48f41 100644 --- a/tests/unit/sagemaker/experiments/test_run_context.py +++ b/tests/unit/sagemaker/experiments/test_run_context.py @@ -18,7 +18,7 @@ from sagemaker import Processor from sagemaker.estimator import Estimator, _TrainingJob -from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.experiment import Experiment from sagemaker.experiments.run import _RunContext from sagemaker.experiments import load_run, Run from sagemaker.experiments.trial import _Trial @@ -62,7 +62,7 @@ def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker @patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) @patch.object(_TrialComponent, "save", MagicMock(return_value=None)) -@patch("sagemaker.experiments.run._Experiment._load_or_create") +@patch("sagemaker.experiments.run.Experiment._load_or_create") @patch("sagemaker.experiments.run._Trial._load_or_create") @patch("sagemaker.experiments.run._TrialComponent._load_or_create") @patch.object(_TrainingJob, "start_new") @@ -249,8 +249,8 @@ def test_nested_run_init_context_on_same_run_object(run_obj, sagemaker_session): @patch( - "sagemaker.experiments.run._Experiment._load_or_create", - MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), + "sagemaker.experiments.run.Experiment._load_or_create", + MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)), ) @patch( "sagemaker.experiments.run._Trial._load_or_create", @@ -305,7 +305,7 @@ def test_nested_run_load_context(run_obj, sagemaker_session): @patch.object(_Trial, "add_trial_component", MagicMock(return_value=None)) -@patch("sagemaker.experiments.run._Experiment._load_or_create") +@patch("sagemaker.experiments.run.Experiment._load_or_create") @patch("sagemaker.experiments.run._Trial._load_or_create") @patch("sagemaker.experiments.run._TrialComponent._load_or_create") def test_run_init_under_run_load_context( diff --git a/tests/unit/sagemaker/remote_function/core/test_serialization.py b/tests/unit/sagemaker/remote_function/core/test_serialization.py index e48137e30a..eb06cf5cc4 100644 --- a/tests/unit/sagemaker/remote_function/core/test_serialization.py +++ b/tests/unit/sagemaker/remote_function/core/test_serialization.py @@ -89,7 +89,7 @@ def test_serialize_deserialize_lambda(): @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @patch("sagemaker.s3.S3Downloader.read_bytes", new=read) -@patch("sagemaker.experiments.run._Experiment") +@patch("sagemaker.experiments.run.Experiment") @patch("sagemaker.experiments.run._Trial") @patch("sagemaker.experiments.run._TrialComponent._load_or_create", return_value=(Mock(), False)) @patch("sagemaker.experiments.run._MetricsManager") @@ -221,7 +221,7 @@ def test_serialize_deserialize_none(): @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload) @patch("sagemaker.s3.S3Downloader.read_bytes", new=read) -@patch("sagemaker.experiments.run._Experiment") +@patch("sagemaker.experiments.run.Experiment") @patch("sagemaker.experiments.run._Trial") @patch("sagemaker.experiments.run._TrialComponent._load_or_create", return_value=(Mock(), False)) @patch("sagemaker.experiments.run._MetricsManager") diff --git a/tests/unit/sagemaker/remote_function/core/test_stored_function.py b/tests/unit/sagemaker/remote_function/core/test_stored_function.py index 759f06f7cf..0b4008ef41 100644 --- a/tests/unit/sagemaker/remote_function/core/test_stored_function.py +++ b/tests/unit/sagemaker/remote_function/core/test_stored_function.py @@ -16,7 +16,7 @@ import random import string from mock import MagicMock, Mock, patch -from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.experiment import Experiment from sagemaker.experiments.run import Run from sagemaker.experiments.trial import _Trial from sagemaker.experiments.trial_component import _TrialComponent @@ -86,8 +86,8 @@ def test_save_and_load(s3_source_dir_download, s3_source_dir_upload, args, kwarg @patch( - "sagemaker.experiments.run._Experiment._load_or_create", - MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)), + "sagemaker.experiments.run.Experiment._load_or_create", + MagicMock(return_value=Experiment(experiment_name=TEST_EXP_NAME)), ) @patch( "sagemaker.experiments.run._Trial._load_or_create", diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py index 6d4541a0df..b5ccfc7568 100644 --- a/tests/unit/sagemaker/remote_function/test_client.py +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -21,7 +21,7 @@ from botocore.exceptions import ClientError from sagemaker import Session -from sagemaker.experiments.experiment import _Experiment +from sagemaker.experiments.experiment import Experiment from sagemaker.experiments.run import Run from sagemaker.remote_function.client import ( remote, @@ -98,11 +98,9 @@ def run_obj(sagemaker_session): client.update_trial_component.return_value = {} client.associate_trial_component.return_value = {} with patch( - "sagemaker.experiments.run._Experiment._load_or_create", + "sagemaker.experiments.run.Experiment._load_or_create", MagicMock( - return_value=_Experiment( - experiment_name="test-exp", sagemaker_session=sagemaker_session - ) + return_value=Experiment(experiment_name="test-exp", sagemaker_session=sagemaker_session) ), ): with patch(