Skip to content

Commit 02f0c3c

Browse files
qidewenwhenDewen Qidanabensyzhu0
authored andcommitted
feature: Add SageMaker Experiment (#3536)
* feature: Add experiment plus Run class (#691) * feature: Add Experiment helper classes (#646) * feature: Add Experiment helper classes feature: Add helper class _RunEnvironment * change: Change sleep retry to backoff retry for get TC * minor fixes in backoff retry Co-authored-by: Dewen Qi <[email protected]> * feature: Add helper classes and methods for Run class (#660) * feature: Add helper classes and methods for Run class * Add Parent class to address comment * fix docstyle check * Add arg docstrings in _helper Co-authored-by: Dewen Qi <[email protected]> * feature: Add Experiment Run class (#651) Co-authored-by: Dewen Qi <[email protected]> * change: Add integ tests for Run (#673) Co-authored-by: Dewen Qi <[email protected]> * Update run log metric to use MetricsManager (#678) * Update run.log_metric to use _MetricsManager * fix several metrics issues * Add doc strings to metrics.py Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Dana Benson <[email protected]> * change: Simplify exp plus integ test configuration (#694) Co-authored-by: Dewen Qi <[email protected]> * feature: add RunName to expeirment_config (#696) * change: Update Run init and add Run load and _RunContext (#707) * change: Update Run init and add Run load Add exp name and run group name to load and address comments * Address nit comments Co-authored-by: Dewen Qi <[email protected]> * fix: Fix run name uniqueness issue (#730) Co-authored-by: Dewen Qi <[email protected]> * change: Update integ tests for Exp Plus M1 changes (#741) Co-authored-by: Dewen Qi <[email protected]> * add metrics client to session object (#745) Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: qidewenwhen <[email protected]> * change: Add integ test for using Run in Transform Job (#749) Co-authored-by: Dewen Qi <[email protected]> * Add async metrics sink (#739) Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: qidewenwhen <[email protected]> * use metrics client provided by session (#754) * fix flaky metrics test (#753) * change: Change Run.init and Run.load to constructor and module method respectively (#752) Co-authored-by: Dewen Qi <[email protected]> * feature: Add latest metric service model (#757) Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: qidewenwhen <[email protected]> * fix: lowercase run name (#767) * Change: Minimize use of lower case tc name (#769) * change: Clean up test resources to remove model files (#756) * change: Clean up test resources to remove model files * fix: Change experiment enums to upper case * change: Upgrade boto3 and update test to validate mixed case name * fix: Update as per latest botocore release and backend change Co-authored-by: Dewen Qi <[email protected]> * lowercase trial component name (#776) * change: Expose sagemaker experiment doc strings * fix: Fix exp name mixed case in issue Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Yifei Zhu <[email protected]>
1 parent 6ae3ddb commit 02f0c3c

File tree

82 files changed

+7894
-263
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

82 files changed

+7894
-263
lines changed

.gitignore

+3-2
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,6 @@ env/
3030
.vscode/
3131
**/tmp
3232
.python-version
33-
**/_repack_model.py
34-
**/_repack_script_launcher.sh
33+
**/_repack_script_launcher.sh
34+
tests/data/**/_repack_model.py
35+
tests/data/experiment/sagemaker-dev-1.0.tar.gz

doc/experiments/index.rst

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
############################
2+
Amazon SageMaker Experiments
3+
############################
4+
5+
The SageMaker Python SDK supports to track and organize your machine learning workflow across SageMaker with jobs, such as Processing, Training and Transform, or locally.
6+
7+
.. toctree::
8+
:maxdepth: 2
9+
10+
sagemaker.experiments
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
Experiments
2+
============
3+
4+
Run
5+
-------------
6+
7+
.. autoclass:: sagemaker.experiments.Run
8+
:members:
9+
10+
.. automethod:: sagemaker.experiments.load_run
11+
12+
.. automethod:: sagemaker.experiments.list_runs
13+
14+
.. autoclass:: sagemaker.experiments.SortByType
15+
:members:
16+
:undoc-members:
17+
18+
.. autoclass:: sagemaker.experiments.SortOrderType
19+
:members:
20+
:undoc-members:

doc/index.rst

+10
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ Orchestrate your SageMaker training and inference workflows with Airflow and Kub
6060
workflows/index
6161

6262

63+
****************************
64+
Amazon SageMaker Experiments
65+
****************************
66+
You can use Amazon SageMaker Experiments to track machine learning experiments.
67+
68+
.. toctree::
69+
:maxdepth: 2
70+
71+
experiments/index
72+
6373
*************************
6474
Amazon SageMaker Debugger
6575
*************************

requirements/extras/test_requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ requests==2.27.1
2020
sagemaker-experiments==0.1.35
2121
Jinja2==3.0.3
2222
pandas>=1.3.5,<1.5
23+
scikit-learn==1.0.2

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def read_requirements(filename):
4848
# Declare minimal set for installation
4949
required_packages = [
5050
"attrs>=20.3.0,<23",
51-
"boto3>=1.26.20,<2.0",
51+
"boto3>=1.26.28,<2.0",
5252
"google-pasta",
5353
"numpy>=1.9.0,<2.0",
5454
"protobuf>=3.1,<4.0",

src/sagemaker/amazon/amazon_estimator.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from sagemaker.deprecations import renamed_warning
2828
from sagemaker.estimator import EstimatorBase, _TrainingJob
2929
from sagemaker.inputs import FileSystemInput, TrainingInput
30-
from sagemaker.utils import sagemaker_timestamp
30+
from sagemaker.utils import sagemaker_timestamp, check_and_get_run_experiment_config
3131
from sagemaker.workflow.entities import PipelineVariable
3232
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
3333
from sagemaker.workflow import is_pipeline_variable
@@ -242,8 +242,8 @@ def fit(
242242
generates a default job name, based on the training image name
243243
and current timestamp.
244244
experiment_config (dict[str, str]): Experiment management configuration.
245-
Optionally, the dict can contain three keys:
246-
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
245+
Optionally, the dict can contain four keys:
246+
'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
247247
The behavior of setting these keys is as follows:
248248
* If `ExperimentName` is supplied but `TrialName` is not a Trial will be
249249
automatically created and the job's Trial Component associated with the Trial.
@@ -255,6 +255,7 @@ def fit(
255255
"""
256256
self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
257257

258+
experiment_config = check_and_get_run_experiment_config(experiment_config)
258259
self.latest_training_job = _TrainingJob.start_new(
259260
self, records, experiment_config=experiment_config
260261
)

src/sagemaker/apiutils/_base_types.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,10 @@ def _search(
173173
search_items = search_method_response.get("Results", [])
174174
next_token = search_method_response.get(boto_next_token_name)
175175
for item in search_items:
176-
if cls.__name__ in item:
177-
yield search_item_factory(item[cls.__name__])
176+
# _TrialComponent class in experiments module is not public currently
177+
class_name = cls.__name__.lstrip("_")
178+
if class_name in item:
179+
yield search_item_factory(item[class_name])
178180
if not next_token:
179181
break
180182
except StopIteration:

src/sagemaker/apiutils/_boto_functions.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ def from_boto(boto_dict, boto_name_to_member_name, member_name_to_type):
6868
api_type, is_collection = member_name_to_type[member_name]
6969
if is_collection:
7070
if isinstance(boto_value, dict):
71-
member_value = api_type.from_boto(boto_value)
71+
member_value = {
72+
key: api_type.from_boto(value) for key, value in boto_value.items()
73+
}
7274
else:
7375
member_value = [api_type.from_boto(item) for item in boto_value]
7476
else:

src/sagemaker/dataset_definition/inputs.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,10 @@ class DatasetDefinition(ApiObject):
124124
"""DatasetDefinition input."""
125125

126126
_custom_boto_types = {
127-
"redshift_dataset_definition": (RedshiftDatasetDefinition, True),
128-
"athena_dataset_definition": (AthenaDatasetDefinition, True),
127+
# RedshiftDatasetDefinition and AthenaDatasetDefinition are not collection
128+
# Instead they are singleton objects. Thus, set the is_collection flag to False.
129+
"redshift_dataset_definition": (RedshiftDatasetDefinition, False),
130+
"athena_dataset_definition": (AthenaDatasetDefinition, False),
129131
}
130132

131133
def __init__(

src/sagemaker/estimator.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
get_config_value,
8080
name_from_base,
8181
to_string,
82+
check_and_get_run_experiment_config,
8283
)
8384
from sagemaker.workflow import is_pipeline_variable
8485
from sagemaker.workflow.entities import PipelineVariable
@@ -1103,8 +1104,8 @@ def fit(
11031104
job_name (str): Training job name. If not specified, the estimator generates
11041105
a default job name based on the training image name and current timestamp.
11051106
experiment_config (dict[str, str]): Experiment management configuration.
1106-
Optionally, the dict can contain three keys:
1107-
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
1107+
Optionally, the dict can contain four keys:
1108+
'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'..
11081109
The behavior of setting these keys is as follows:
11091110
* If `ExperimentName` is supplied but `TrialName` is not a Trial will be
11101111
automatically created and the job's Trial Component associated with the Trial.
@@ -1122,6 +1123,7 @@ def fit(
11221123
"""
11231124
self._prepare_for_training(job_name=job_name)
11241125

1126+
experiment_config = check_and_get_run_experiment_config(experiment_config)
11251127
self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
11261128
self.jobs.append(self.latest_training_job)
11271129
if wait:
@@ -2023,8 +2025,8 @@ def start_new(cls, estimator, inputs, experiment_config):
20232025
inputs (str): Parameters used when called
20242026
:meth:`~sagemaker.estimator.EstimatorBase.fit`.
20252027
experiment_config (dict[str, str]): Experiment management configuration.
2026-
Optionally, the dict can contain three keys:
2027-
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
2028+
Optionally, the dict can contain four keys:
2029+
'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
20282030
The behavior of setting these keys is as follows:
20292031
* If `ExperimentName` is supplied but `TrialName` is not a Trial will be
20302032
automatically created and the job's Trial Component associated with the Trial.
@@ -2033,6 +2035,7 @@ def start_new(cls, estimator, inputs, experiment_config):
20332035
* If both `ExperimentName` and `TrialName` are not supplied the trial component
20342036
will be unassociated.
20352037
* `TrialComponentDisplayName` is used for display in Studio.
2038+
* `RunName` is used to record an experiment run.
20362039
Returns:
20372040
sagemaker.estimator._TrainingJob: Constructed object that captures
20382041
all information about the started training job.
@@ -2053,8 +2056,8 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
20532056
inputs (str): Parameters used when called
20542057
:meth:`~sagemaker.estimator.EstimatorBase.fit`.
20552058
experiment_config (dict[str, str]): Experiment management configuration.
2056-
Optionally, the dict can contain three keys:
2057-
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
2059+
Optionally, the dict can contain four keys:
2060+
'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
20582061
The behavior of setting these keys is as follows:
20592062
* If `ExperimentName` is supplied but `TrialName` is not a Trial will be
20602063
automatically created and the job's Trial Component associated with the Trial.
@@ -2063,6 +2066,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
20632066
* If both `ExperimentName` and `TrialName` are not supplied the trial component
20642067
will be unassociated.
20652068
* `TrialComponentDisplayName` is used for display in Studio.
2069+
* `RunName` is used to record an experiment run.
20662070
20672071
Returns:
20682072
Dict: dict for `sagemaker.session.Session.train` method

src/sagemaker/experiments/__init__.py

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Sagemaker Experiment Module"""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.experiments.run import Run # noqa: F401
17+
from sagemaker.experiments.run import load_run # noqa: F401
18+
from sagemaker.experiments.run import list_runs # noqa: F401
19+
from sagemaker.experiments.run import SortOrderType # noqa: F401
20+
from sagemaker.experiments.run import SortByType # noqa: F401

0 commit comments

Comments
 (0)