Skip to content

Commit 0c94a85

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
committed
change: Change Run.init and Run.load to constructor and module method respectively (aws#752)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 50b9f1c commit 0c94a85

File tree

17 files changed

+648
-664
lines changed

17 files changed

+648
-664
lines changed

src/sagemaker/experiments/__init__.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Sagemaker Experiment Module"""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.experiments.run import Run, load_run, list_runs, SortOrderType # noqa: F401

src/sagemaker/experiments/_environment.py

+8
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,14 @@ def get_trial_component(self, sagemaker_session):
9999
Returns:
100100
_TrialComponent: The trial component created from the job. None if not found.
101101
"""
102+
# TODO: Remove this condition check once we have a way to retrieve source ARN
103+
# from transform job env
104+
if self.environment_type == _EnvironmentType.SageMakerTransformJob:
105+
logger.error(
106+
"Currently getting the job trial component from the transform job environment "
107+
"is not supported. Returning None."
108+
)
109+
return None
102110

103111
def _get_trial_component():
104112
summaries = list(

src/sagemaker/experiments/_run_context.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import TYPE_CHECKING
1717

1818
if TYPE_CHECKING:
19-
from sagemaker.experiments.run import Run
19+
from sagemaker.experiments import Run
2020

2121

2222
class _RunContext:

src/sagemaker/experiments/_utils.py

+123
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,20 @@
1313
"""Contains the SageMaker Experiment utility methods."""
1414
from __future__ import absolute_import
1515

16+
import logging
1617
import os
1718

1819
import mimetypes
1920
import urllib
2021
from functools import wraps
22+
from typing import Optional
2123

24+
from sagemaker import Session
2225
from sagemaker.apiutils import _utils
26+
from sagemaker.experiments._environment import _RunEnvironment, _EnvironmentType
27+
from sagemaker.experiments.trial_component import _TrialComponent
28+
from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression
29+
from sagemaker.utils import retry_with_backoff
2330

2431

2532
def resolve_artifact_name(file_path):
@@ -93,3 +100,119 @@ def is_already_exist_error(error):
93100
`botocore.exceptions.ClientError`
94101
"""
95102
return error["Code"] == "ValidationException" and "already exists" in error["Message"]
103+
104+
105+
def get_tc_and_exp_config_from_job_env(
106+
environment: _RunEnvironment,
107+
sagemaker_session: Session,
108+
) -> dict:
109+
"""Retrieve an experiment config from the job environment.
110+
111+
Args:
112+
environment (_RunEnvironment): The run environment object with job specific data.
113+
sagemaker_session (sagemaker.session.Session): Session object which
114+
manages interactions with Amazon SageMaker APIs and any other
115+
AWS services needed. If not specified, one is created using the
116+
default AWS configuration chain.
117+
"""
118+
job_name = environment.source_arn.split("/")[-1]
119+
if environment.environment_type == _EnvironmentType.SageMakerTrainingJob:
120+
job_response = retry_with_backoff(
121+
callable_func=lambda: sagemaker_session.describe_training_job(job_name),
122+
num_attempts=4,
123+
)
124+
elif environment.environment_type == _EnvironmentType.SageMakerProcessingJob:
125+
job_response = retry_with_backoff(
126+
callable_func=lambda: sagemaker_session.describe_processing_job(job_name),
127+
num_attempts=4,
128+
)
129+
else: # environment.environment_type == _EnvironmentType.SageMakerTransformJob
130+
raise RuntimeError(
131+
"Failed to load the Run as loading experiment config "
132+
"from transform job environment is not currently supported. "
133+
"As a workaround, please explicitly pass in "
134+
"the experiment_name and run_name in load_run."
135+
)
136+
137+
job_exp_config = job_response.get("ExperimentConfig", dict())
138+
from sagemaker.experiments.run import RUN_NAME
139+
140+
if job_exp_config.get(RUN_NAME, None):
141+
return job_exp_config
142+
raise RuntimeError(
143+
"Not able to fetch RunName in ExperimentConfig of the sagemaker job. "
144+
"Please make sure the ExperimentConfig is correctly set."
145+
)
146+
147+
148+
def verify_load_input_names(
149+
run_name: Optional[str] = None,
150+
experiment_name: Optional[str] = None,
151+
):
152+
"""Verify the run_name and the experiment_name inputs in load_run.
153+
154+
Args:
155+
run_name (str): The run_name supplied by the user (default: None).
156+
experiment_name (str): The experiment_name supplied by the user
157+
(default: None).
158+
159+
Raises:
160+
ValueError: If run_name is supplied while experiment_name is not.
161+
"""
162+
if not run_name and experiment_name:
163+
logging.warning(
164+
"No run_name is supplied. Ignoring the provided experiment_name "
165+
"since it only takes effect along with run_name. "
166+
"Will load the Run object from the job environment or current Run context."
167+
)
168+
if run_name and not experiment_name:
169+
raise ValueError(
170+
"Invalid input: experiment_name is missing when run_name is supplied. "
171+
"Please supply a valid experiment_name when the run_name is not None."
172+
)
173+
174+
175+
def is_run_trial_component(trial_component_name: str, sagemaker_session: Session) -> bool:
176+
"""Check if a trial component is generated by `sagemaker.experiments.Run`
177+
178+
Args:
179+
trial_component_name (str): The name of the trial component.
180+
sagemaker_session (sagemaker.session.Session): Session object which
181+
manages interactions with Amazon SageMaker APIs and any other
182+
AWS services needed. If not specified, one is created using the
183+
default AWS configuration chain.
184+
185+
Returns:
186+
bool: Indicate whether the trial component is created by
187+
`sagemaker.experiments.Run` or not.
188+
"""
189+
search_filter = Filter(
190+
name="TrialComponentName",
191+
operator=Operator.EQUALS,
192+
value=trial_component_name,
193+
)
194+
search_expression = SearchExpression(filters=[search_filter])
195+
196+
def search():
197+
return list(
198+
_TrialComponent.search(
199+
search_expression=search_expression,
200+
max_results=1, # TrialComponentName is unique in an account
201+
sagemaker_session=sagemaker_session,
202+
)
203+
)[0]
204+
205+
try:
206+
tc_search_res = retry_with_backoff(search, 4)
207+
from sagemaker.experiments.run import RUN_TC_TAG
208+
209+
if not tc_search_res.tags or RUN_TC_TAG not in tc_search_res.tags:
210+
return False
211+
return True
212+
except Exception as ex: # pylint: disable=broad-except
213+
logging.warning(
214+
"Failed to inspect the type of the trial component (%s), due to (%s)",
215+
trial_component_name,
216+
str(ex),
217+
)
218+
return False

0 commit comments

Comments
 (0)