|
13 | 13 | """Contains the SageMaker Experiment utility methods."""
|
14 | 14 | from __future__ import absolute_import
|
15 | 15 |
|
| 16 | +import logging |
16 | 17 | import os
|
17 | 18 |
|
18 | 19 | import mimetypes
|
19 | 20 | import urllib
|
20 | 21 | from functools import wraps
|
| 22 | +from typing import Optional |
21 | 23 |
|
| 24 | +from sagemaker import Session |
22 | 25 | from sagemaker.apiutils import _utils
|
| 26 | +from sagemaker.experiments._environment import _RunEnvironment, _EnvironmentType |
| 27 | +from sagemaker.experiments.trial_component import _TrialComponent |
| 28 | +from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression |
| 29 | +from sagemaker.utils import retry_with_backoff |
23 | 30 |
|
24 | 31 |
|
25 | 32 | def resolve_artifact_name(file_path):
|
@@ -93,3 +100,119 @@ def is_already_exist_error(error):
|
93 | 100 | `botocore.exceptions.ClientError`
|
94 | 101 | """
|
95 | 102 | return error["Code"] == "ValidationException" and "already exists" in error["Message"]
|
| 103 | + |
| 104 | + |
| 105 | +def get_tc_and_exp_config_from_job_env( |
| 106 | + environment: _RunEnvironment, |
| 107 | + sagemaker_session: Session, |
| 108 | +) -> dict: |
| 109 | + """Retrieve an experiment config from the job environment. |
| 110 | +
|
| 111 | + Args: |
| 112 | + environment (_RunEnvironment): The run environment object with job specific data. |
| 113 | + sagemaker_session (sagemaker.session.Session): Session object which |
| 114 | + manages interactions with Amazon SageMaker APIs and any other |
| 115 | + AWS services needed. If not specified, one is created using the |
| 116 | + default AWS configuration chain. |
| 117 | + """ |
| 118 | + job_name = environment.source_arn.split("/")[-1] |
| 119 | + if environment.environment_type == _EnvironmentType.SageMakerTrainingJob: |
| 120 | + job_response = retry_with_backoff( |
| 121 | + callable_func=lambda: sagemaker_session.describe_training_job(job_name), |
| 122 | + num_attempts=4, |
| 123 | + ) |
| 124 | + elif environment.environment_type == _EnvironmentType.SageMakerProcessingJob: |
| 125 | + job_response = retry_with_backoff( |
| 126 | + callable_func=lambda: sagemaker_session.describe_processing_job(job_name), |
| 127 | + num_attempts=4, |
| 128 | + ) |
| 129 | + else: # environment.environment_type == _EnvironmentType.SageMakerTransformJob |
| 130 | + raise RuntimeError( |
| 131 | + "Failed to load the Run as loading experiment config " |
| 132 | + "from transform job environment is not currently supported. " |
| 133 | + "As a workaround, please explicitly pass in " |
| 134 | + "the experiment_name and run_name in load_run." |
| 135 | + ) |
| 136 | + |
| 137 | + job_exp_config = job_response.get("ExperimentConfig", dict()) |
| 138 | + from sagemaker.experiments.run import RUN_NAME |
| 139 | + |
| 140 | + if job_exp_config.get(RUN_NAME, None): |
| 141 | + return job_exp_config |
| 142 | + raise RuntimeError( |
| 143 | + "Not able to fetch RunName in ExperimentConfig of the sagemaker job. " |
| 144 | + "Please make sure the ExperimentConfig is correctly set." |
| 145 | + ) |
| 146 | + |
| 147 | + |
| 148 | +def verify_load_input_names( |
| 149 | + run_name: Optional[str] = None, |
| 150 | + experiment_name: Optional[str] = None, |
| 151 | +): |
| 152 | + """Verify the run_name and the experiment_name inputs in load_run. |
| 153 | +
|
| 154 | + Args: |
| 155 | + run_name (str): The run_name supplied by the user (default: None). |
| 156 | + experiment_name (str): The experiment_name supplied by the user |
| 157 | + (default: None). |
| 158 | +
|
| 159 | + Raises: |
| 160 | + ValueError: If run_name is supplied while experiment_name is not. |
| 161 | + """ |
| 162 | + if not run_name and experiment_name: |
| 163 | + logging.warning( |
| 164 | + "No run_name is supplied. Ignoring the provided experiment_name " |
| 165 | + "since it only takes effect along with run_name. " |
| 166 | + "Will load the Run object from the job environment or current Run context." |
| 167 | + ) |
| 168 | + if run_name and not experiment_name: |
| 169 | + raise ValueError( |
| 170 | + "Invalid input: experiment_name is missing when run_name is supplied. " |
| 171 | + "Please supply a valid experiment_name when the run_name is not None." |
| 172 | + ) |
| 173 | + |
| 174 | + |
| 175 | +def is_run_trial_component(trial_component_name: str, sagemaker_session: Session) -> bool: |
| 176 | + """Check if a trial component is generated by `sagemaker.experiments.Run` |
| 177 | +
|
| 178 | + Args: |
| 179 | + trial_component_name (str): The name of the trial component. |
| 180 | + sagemaker_session (sagemaker.session.Session): Session object which |
| 181 | + manages interactions with Amazon SageMaker APIs and any other |
| 182 | + AWS services needed. If not specified, one is created using the |
| 183 | + default AWS configuration chain. |
| 184 | +
|
| 185 | + Returns: |
| 186 | + bool: Indicate whether the trial component is created by |
| 187 | + `sagemaker.experiments.Run` or not. |
| 188 | + """ |
| 189 | + search_filter = Filter( |
| 190 | + name="TrialComponentName", |
| 191 | + operator=Operator.EQUALS, |
| 192 | + value=trial_component_name, |
| 193 | + ) |
| 194 | + search_expression = SearchExpression(filters=[search_filter]) |
| 195 | + |
| 196 | + def search(): |
| 197 | + return list( |
| 198 | + _TrialComponent.search( |
| 199 | + search_expression=search_expression, |
| 200 | + max_results=1, # TrialComponentName is unique in an account |
| 201 | + sagemaker_session=sagemaker_session, |
| 202 | + ) |
| 203 | + )[0] |
| 204 | + |
| 205 | + try: |
| 206 | + tc_search_res = retry_with_backoff(search, 4) |
| 207 | + from sagemaker.experiments.run import RUN_TC_TAG |
| 208 | + |
| 209 | + if not tc_search_res.tags or RUN_TC_TAG not in tc_search_res.tags: |
| 210 | + return False |
| 211 | + return True |
| 212 | + except Exception as ex: # pylint: disable=broad-except |
| 213 | + logging.warning( |
| 214 | + "Failed to inspect the type of the trial component (%s), due to (%s)", |
| 215 | + trial_component_name, |
| 216 | + str(ex), |
| 217 | + ) |
| 218 | + return False |
0 commit comments