Skip to content

Commit d378771

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
committed
change: Update Run init and add Run load and _RunContext (aws#707)
* change: Update Run init and add Run load Add exp name and run group name to load and address comments * Address nit comments Co-authored-by: Dewen Qi <[email protected]>
1 parent d52d42a commit d378771

17 files changed

+1549
-776
lines changed

src/sagemaker/amazon/amazon_estimator.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from sagemaker.deprecations import renamed_warning
2828
from sagemaker.estimator import EstimatorBase, _TrainingJob
2929
from sagemaker.inputs import FileSystemInput, TrainingInput
30-
from sagemaker.utils import sagemaker_timestamp
30+
from sagemaker.utils import sagemaker_timestamp, check_and_get_run_experiment_config
3131
from sagemaker.workflow.entities import PipelineVariable
3232
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
3333
from sagemaker.workflow import is_pipeline_variable
@@ -255,6 +255,7 @@ def fit(
255255
"""
256256
self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
257257

258+
experiment_config = check_and_get_run_experiment_config(experiment_config)
258259
self.latest_training_job = _TrainingJob.start_new(
259260
self, records, experiment_config=experiment_config
260261
)

src/sagemaker/estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
get_config_value,
8080
name_from_base,
8181
to_string,
82+
check_and_get_run_experiment_config,
8283
)
8384
from sagemaker.workflow import is_pipeline_variable
8485
from sagemaker.workflow.entities import PipelineVariable
@@ -1122,6 +1123,7 @@ def fit(
11221123
"""
11231124
self._prepare_for_training(job_name=job_name)
11241125

1126+
experiment_config = check_and_get_run_experiment_config(experiment_config)
11251127
self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
11261128
self.jobs.append(self.latest_training_job)
11271129
if wait:
+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Contains the SageMaker Experiment _RunContext class."""
14+
from __future__ import absolute_import
15+
16+
from typing import TYPE_CHECKING
17+
18+
if TYPE_CHECKING:
19+
from sagemaker.experiments.run import Run
20+
21+
22+
class _RunContext:
23+
"""A static context variable to keep track of the current Run object"""
24+
25+
_context_run = None
26+
27+
@classmethod
28+
def add_run_object(cls, run: "Run"):
29+
"""Keep track of the current executing Run object
30+
31+
by adding it to a class static variable.
32+
33+
Args:
34+
run (Run): The current Run object to be tracked.
35+
"""
36+
cls._context_run = run
37+
38+
@classmethod
39+
def drop_current_run(cls) -> "Run":
40+
"""Drop the Run object tracked in the global static variable
41+
42+
as its execution finishes (its "with" block ends).
43+
44+
Return:
45+
Run: the dropped Run object.
46+
"""
47+
current_run = cls._context_run
48+
cls._context_run = None
49+
return current_run
50+
51+
@classmethod
52+
def get_current_run(cls) -> "Run":
53+
"""Return the current Run object without dropping it.
54+
55+
Return:
56+
Run: the current Run object to be returned.
57+
"""
58+
return cls._context_run

src/sagemaker/experiments/_utils.py

+15
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import mimetypes
1919
import urllib
20+
from functools import wraps
21+
2022
from sagemaker.apiutils import _utils
2123

2224

@@ -68,3 +70,16 @@ def verify_length_of_true_and_predicted(true_labels, predicted_attrs, predicted_
6870
"Lengths mismatch between true labels and {}: "
6971
"({} vs {}).".format(predicted_attrs_name, len(true_labels), len(predicted_attrs))
7072
)
73+
74+
75+
def validate_invoked_inside_run_context(func):
76+
"""A Decorator to force the decorated method called under Run context."""
77+
78+
@wraps(func)
79+
def wrapper(*args, **kwargs):
80+
self_instance = args[0]
81+
if not self_instance._inside_load_context and not self_instance._inside_init_context:
82+
raise RuntimeError("This method should be called inside context of 'with' statement.")
83+
return func(*args, **kwargs)
84+
85+
return wrapper

0 commit comments

Comments
 (0)