Skip to content

Commit e1efdb5

Browse files
qidewenwhenDewen Qidanabens
committed
feature: Add experiment plus Run class (aws#691)
* feature: Add Experiment helper classes (aws#646) * feature: Add Experiment helper classes feature: Add helper class _RunEnvironment * change: Change sleep retry to backoff retry for get TC * minor fixes in backoff retry Co-authored-by: Dewen Qi <[email protected]> * feature: Add helper classes and methods for Run class (aws#660) * feature: Add helper classes and methods for Run class * Add Parent class to address comment * fix docstyle check * Add arg docstrings in _helper Co-authored-by: Dewen Qi <[email protected]> * feature: Add Experiment Run class (aws#651) Co-authored-by: Dewen Qi <[email protected]> * change: Add integ tests for Run (aws#673) Co-authored-by: Dewen Qi <[email protected]> * Update run log metric to use MetricsManager (aws#678) * Update run.log_metric to use _MetricsManager * fix several metrics issues * Add doc strings to metrics.py Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Dana Benson <[email protected]>
1 parent bd96ec5 commit e1efdb5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+41403
-63
lines changed

.gitignore

+4-1
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,7 @@ env/
3131
**/tmp
3232
.python-version
3333
**/_repack_model.py
34-
**/_repack_script_launcher.sh
34+
**/_repack_script_launcher.sh
35+
tests/data/experiment/docker/boto
36+
tests/data/experiment/docker/sagemaker-dev.tar.gz
37+
tests/data/**/_repack_model.py

requirements/extras/test_requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ requests==2.27.1
2020
sagemaker-experiments==0.1.35
2121
Jinja2==3.0.3
2222
pandas>=1.3.5,<1.5
23+
scikit-learn==1.0.2

src/sagemaker/apiutils/_base_types.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,10 @@ def _search(
173173
search_items = search_method_response.get("Results", [])
174174
next_token = search_method_response.get(boto_next_token_name)
175175
for item in search_items:
176-
if cls.__name__ in item:
177-
yield search_item_factory(item[cls.__name__])
176+
# _TrialComponent class in experiments module is not public currently
177+
class_name = cls.__name__.lstrip("_")
178+
if class_name in item:
179+
yield search_item_factory(item[class_name])
178180
if not next_token:
179181
break
180182
except StopIteration:

src/sagemaker/apiutils/_boto_functions.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,9 @@ def from_boto(boto_dict, boto_name_to_member_name, member_name_to_type):
6868
api_type, is_collection = member_name_to_type[member_name]
6969
if is_collection:
7070
if isinstance(boto_value, dict):
71-
member_value = api_type.from_boto(boto_value)
71+
member_value = {
72+
key: api_type.from_boto(value) for key, value in boto_value.items()
73+
}
7274
else:
7375
member_value = [api_type.from_boto(item) for item in boto_value]
7476
else:

src/sagemaker/dataset_definition/inputs.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,10 @@ class DatasetDefinition(ApiObject):
124124
"""DatasetDefinition input."""
125125

126126
_custom_boto_types = {
127-
"redshift_dataset_definition": (RedshiftDatasetDefinition, True),
128-
"athena_dataset_definition": (AthenaDatasetDefinition, True),
127+
# RedshiftDatasetDefinition and AthenaDatasetDefinition are not collections.
128+
# Instead, they are singleton objects, so the is_collection flag is set to False.
129+
"redshift_dataset_definition": (RedshiftDatasetDefinition, False),
130+
"athena_dataset_definition": (AthenaDatasetDefinition, False),
129131
}
130132

131133
def __init__(

src/sagemaker/experiments/__init__.py

Whitespace-only changes.
+226
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Contains API objects for SageMaker experiments."""
14+
from __future__ import absolute_import
15+
16+
import numbers
17+
18+
from sagemaker.apiutils import _base_types
19+
20+
21+
class TrialComponentMetricSummary(_base_types.ApiObject):
    """Summary statistics of a single metric attached to a trial component.

    Attributes:
        metric_name (str): The name of the metric.
        source_arn (str): The ARN of the source.
        time_stamp (datetime): Timestamp of the last metric update.
        max (float): The max value of the metric.
        min (float): The min value of the metric.
        last (float): The last value of the metric.
        count (float): The number of samples used to generate the metric.
        avg (float): The average value of the metric.
        std_dev (float): The standard deviation of the metric.
    """

    # Class-level defaults: every field is None until a value is supplied.
    metric_name = None
    source_arn = None
    time_stamp = None
    max = None
    min = None
    last = None
    count = None
    avg = None
    std_dev = None

    def __init__(self, metric_name=None, source_arn=None, **kwargs):
        """Initialize the summary by delegating to the ``ApiObject`` initializer.

        Args:
            metric_name (str): The name of the metric (default: None).
            source_arn (str): The ARN of the source (default: None).
            **kwargs: Any further fields, forwarded unchanged to the parent initializer.
        """
        fields = dict(metric_name=metric_name, source_arn=source_arn, **kwargs)
        super(TrialComponentMetricSummary, self).__init__(**fields)
50+
51+
52+
class TrialComponentParameters(_base_types.ApiObject):
    """A dictionary of TrialComponentParameterValues."""

    @classmethod
    def from_boto(cls, boto_dict, **kwargs):
        """Flatten a boto parameter mapping into plain Python values.

        Each entry is reduced to its ``NumberValue`` when that key is present,
        otherwise to its ``StringValue``, otherwise to None.

        Args:
            boto_dict (dict): boto response dictionary, keyed by parameter name.
            **kwargs: Arbitrary keyword arguments; accepted but not used here.

        Returns:
            dict: Dictionary of parameter values keyed by parameter name.
        """
        return {
            name: entry.get("NumberValue", entry.get("StringValue", None))
            for name, entry in boto_dict.items()
        }

    @classmethod
    def to_boto(cls, parameters):
        """Wrap plain parameter values into the boto parameter format.

        Values that are instances of ``numbers.Number`` are emitted under
        ``NumberValue``; everything else is converted with ``str`` and emitted
        under ``StringValue``.

        Args:
            parameters (TrialComponentParameters): Dictionary to convert.

        Returns:
            dict: Dictionary of trial component parameters in boto format.
        """
        return {
            name: (
                {"NumberValue": raw}
                if isinstance(raw, numbers.Number)
                else {"StringValue": str(raw)}
            )
            for name, raw in parameters.items()
        }
88+
89+
90+
class TrialComponentArtifact(_base_types.ApiObject):
    """An artifact associated with a trial component.

    Attributes:
        value (str): The artifact value.
        media_type (str): The media type.
    """

    # Class-level defaults; both fields are None until set.
    value = None
    media_type = None

    def __init__(self, value=None, media_type=None, **kwargs):
        """Initialize the artifact by delegating to the ``ApiObject`` initializer.

        Args:
            value (str): The artifact value (default: None).
            media_type (str): The media type (default: None).
            **kwargs: Any further fields, forwarded unchanged to the parent initializer.
        """
        fields = dict(value=value, media_type=media_type, **kwargs)
        super(TrialComponentArtifact, self).__init__(**fields)
103+
104+
105+
class TrialComponentStatus(_base_types.ApiObject):
    """Status of the trial component.

    Attributes:
        primary_status (str): The status of a trial component.
        message (str): Status message.
    """

    # Class-level defaults; both fields are None until set.
    primary_status = None
    message = None

    def __init__(self, primary_status=None, message=None, **kwargs):
        """Initialize the status by delegating to the ``ApiObject`` initializer.

        Args:
            primary_status (str): The status of a trial component (default: None).
            message (str): Status message (default: None).
            **kwargs: Any further fields, forwarded unchanged to the parent initializer.
        """
        fields = dict(primary_status=primary_status, message=message, **kwargs)
        super(TrialComponentStatus, self).__init__(**fields)
120+
121+
122+
class TrialComponentSummary(_base_types.ApiObject):
    """Summary model of a trial component.

    Attributes:
        trial_component_name (str): Name of trial component.
        trial_component_arn (str): ARN of the trial component.
        display_name (str): Friendly display name in UI.
        source_arn (str): ARN of the trial component source.
        status (TrialComponentStatus): Status of the trial component.
        start_time (datetime): Start time.
        end_time (datetime): End time.
        creation_time (datetime): Creation time.
        created_by (str): Created by.
        last_modified_time (datetime): Date last modified.
        last_modified_by: The user who last modified the trial component.
    """

    # ``status`` is registered as a single (non-collection) TrialComponentStatus.
    _custom_boto_types = {"status": (TrialComponentStatus, False)}

    # Class-level defaults: every field is None until a value is supplied.
    trial_component_name = None
    trial_component_arn = None
    display_name = None
    source_arn = None
    status = None
    start_time = None
    end_time = None
    creation_time = None
    created_by = None
    last_modified_time = None
    last_modified_by = None
153+
154+
155+
class TrialComponentSource(_base_types.ApiObject):
    """The source of a trial component.

    Attributes:
        source_arn (str): The ARN of the source.
    """

    # Class-level default; None until set.
    source_arn = None

    def __init__(self, source_arn=None, **kwargs):
        """Initialize the source by delegating to the ``ApiObject`` initializer.

        Args:
            source_arn (str): The ARN of the source (default: None).
            **kwargs: Any further fields, forwarded unchanged to the parent initializer.
        """
        fields = dict(source_arn=source_arn, **kwargs)
        super(TrialComponentSource, self).__init__(**fields)
166+
167+
168+
class Parent(_base_types.ApiObject):
    """The trial/experiment/run that a trial component is associated with.

    Attributes:
        trial_name (str): Name of the trial.
        experiment_name (str): Name of the experiment.
        run_name (str): Name of the run.
    """

    # Class-level defaults: every field is None until a value is supplied.
    trial_name = None
    experiment_name = None
    run_name = None
180+
181+
182+
class TrialComponentSearchResult(_base_types.ApiObject):
    """Summary model of a Trial Component search result.

    Attributes:
        trial_component_arn (str): ARN of the trial component.
        trial_component_name (str): Name of the trial component.
        display_name (str): Display name of the trial component for UI display.
        source (dict): The source of the trial component.
        status (dict): The status of the trial component.
        start_time (datetime): Start time.
        end_time (datetime): End time.
        creation_time (datetime): Creation time.
        created_by (str): Created by.
        last_modified_time (datetime): Date last modified.
        last_modified_by: The user who last modified the trial component.
        parameters (dict): The hyperparameters of the component.
        input_artifacts (dict): The input artifacts of the component.
        output_artifacts (dict): The output artifacts of the component.
        metrics (list): The metrics for the component.
        source_detail (dict): The source of the trial component.
        tags (list): The list of tags that are associated with the trial component.
        parents (list[Parent]): The parents of the trial component.
    """

    # ``parents`` is registered as a collection (list) of Parent objects.
    _custom_boto_types = {"parents": (Parent, True)}

    # Class-level defaults: every field is None until a value is supplied.
    trial_component_arn = None
    trial_component_name = None
    display_name = None
    source = None
    status = None
    start_time = None
    end_time = None
    creation_time = None
    created_by = None
    last_modified_time = None
    last_modified_by = None
    parameters = None
    input_artifacts = None
    output_artifacts = None
    metrics = None
    source_detail = None
    tags = None
    parents = None
+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Contains the _RunEnvironment class."""
14+
from __future__ import absolute_import
15+
16+
import enum
17+
import json
18+
import logging
19+
import os
20+
21+
from sagemaker.experiments import trial_component
22+
from sagemaker.utils import retry_with_backoff
23+
24+
# Default environment variable key checked by _RunEnvironment.load for the training job ARN.
TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN"
25+
# Default path of the JSON config that _RunEnvironment.load reads "ProcessingJobArn" from.
PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json"
26+
# Passed to retry_with_backoff when looking up the trial component of the current job.
MAX_RETRY_ATTEMPTS = 7
27+
28+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
29+
30+
31+
@enum.unique
class EnvironmentType(enum.Enum):
    """SageMaker job types whose data can be pulled from the environment.

    The enum is marked ``unique`` so that adding a member with an already-used
    value raises ``ValueError`` at import time instead of silently creating an
    alias of an existing member.
    """

    # Used by _RunEnvironment.load when the training job ARN environment variable is set.
    SageMakerTrainingJob = 1
    # Used by _RunEnvironment.load when the processing job config file exists.
    SageMakerProcessingJob = 2
36+
37+
38+
class _RunEnvironment(object):
    """Retrieves job specific data from the environment.

    Attributes:
        environment_type (EnvironmentType): The kind of job the code is running in.
        source_arn (str): The ARN of the current job.
    """

    def __init__(self, environment_type, source_arn):
        """Init for _RunEnvironment.

        Args:
            environment_type (EnvironmentType): The environment type.
            source_arn (str): The ARN of the current job.
        """
        self.environment_type = environment_type
        self.source_arn = source_arn

    @classmethod
    def load(
        cls,
        training_job_arn_env=TRAINING_JOB_ARN_ENV,
        processing_job_config_path=PROCESSING_JOB_CONFIG_PATH,
    ):
        """Loads source arn of current job from environment.

        The training job environment variable is checked first; the processing
        job config file is only consulted when that variable is absent.

        Args:
            training_job_arn_env (str): The environment key for training job ARN
                (default: `TRAINING_JOB_ARN`).
            processing_job_config_path (str): The processing job config path
                (default: `/opt/ml/config/processingjobconfig.json`).

        Returns:
            _RunEnvironment: Job data loaded from the environment. None if neither
                the environment variable nor the config file exists.

        Raises:
            KeyError: If the processing job config has no ``ProcessingJobArn`` key.
            ValueError: If the processing job config is not valid JSON.
        """
        if training_job_arn_env in os.environ:
            source_arn = os.environ.get(training_job_arn_env)
            # Build through ``cls`` so subclasses get instances of their own type.
            return cls(EnvironmentType.SageMakerTrainingJob, source_arn)

        if os.path.exists(processing_job_config_path):
            # A context manager guarantees the config file handle is closed,
            # even when parsing fails or the key is missing.
            with open(processing_job_config_path) as config_file:
                source_arn = json.load(config_file)["ProcessingJobArn"]
            return cls(EnvironmentType.SageMakerProcessingJob, source_arn)

        return None

    def get_trial_component(self, sagemaker_session):
        """Retrieves the trial component from the job in the environment.

        The lookup lists trial components whose source ARN equals
        ``self.source_arn`` and loads the first match. It is executed through
        ``retry_with_backoff`` with ``MAX_RETRY_ATTEMPTS``; any exception that
        escapes the retries is logged and results in None.

        Args:
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. It is forwarded to the list and load calls.

        Returns:
            _TrialComponent: The trial component created from the job. None if not found.
        """

        def _get_trial_component():
            """Return the first trial component matching the job ARN, or None."""
            summaries = list(
                trial_component._TrialComponent.list(
                    source_arn=self.source_arn, sagemaker_session=sagemaker_session
                )
            )
            if not summaries:
                return None
            return trial_component._TrialComponent.load(
                trial_component_name=summaries[0].trial_component_name,
                sagemaker_session=sagemaker_session,
            )

        job_tc = None
        try:
            job_tc = retry_with_backoff(_get_trial_component, MAX_RETRY_ATTEMPTS)
        except Exception as ex:  # pylint: disable=broad-except
            # Deliberate best-effort: a failed lookup is reported, not raised.
            logger.error(
                "Failed to get trial component in the current environment due to %s", str(ex)
            )
        return job_tc

0 commit comments

Comments
 (0)