Skip to content

Commit 478b950

Browse files
aoguo64Ao Guo
authored and
Namrata Madan
committed
fix session creation in load_run() (aws#884)
Co-authored-by: Ao Guo <[email protected]>
1 parent 53cef39 commit 478b950

File tree

3 files changed

+9
-26
lines changed

3 files changed

+9
-26
lines changed

src/sagemaker/experiments/run.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,6 @@ def load_run(
795795
Returns:
796796
Run: The loaded Run object.
797797
"""
798-
sagemaker_session = sagemaker_session or _utils.default_session()
799798
environment = _RunEnvironment.load()
800799

801800
verify_load_input_names(run_name=run_name, experiment_name=experiment_name)
@@ -810,13 +809,13 @@ def load_run(
810809
run_instance = Run(
811810
experiment_name=experiment_name,
812811
run_name=run_name,
813-
sagemaker_session=sagemaker_session,
812+
sagemaker_session=sagemaker_session or _utils.default_session(),
814813
)
815814
elif _RunContext.get_current_run():
816815
run_instance = _RunContext.get_current_run()
817816
elif environment:
818817
exp_config = get_tc_and_exp_config_from_job_env(
819-
environment=environment, sagemaker_session=sagemaker_session
818+
environment=environment, sagemaker_session=sagemaker_session or _utils.default_session()
820819
)
821820
run_name = Run._extract_run_name_from_tc_name(
822821
trial_component_name=exp_config[RUN_NAME],
@@ -826,7 +825,7 @@ def load_run(
826825
run_instance = Run(
827826
experiment_name=experiment_name,
828827
run_name=run_name,
829-
sagemaker_session=sagemaker_session,
828+
sagemaker_session=sagemaker_session or _utils.default_session(),
830829
)
831830
else:
832831
raise RuntimeError(

tests/integ/sagemaker/remote_function/test_decorator.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,7 @@ def test_decorator_load_run_inside_remote_function(
382382
sagemaker_session=sagemaker_session,
383383
)
384384
def train():
385-
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
386-
sagemaker_session = Session(boto_session=boto_session)
387-
with load_run(sagemaker_session=sagemaker_session) as run:
385+
with load_run() as run:
388386
run.log_parameters({"p3": 3.0, "p4": 4})
389387
run.log_metric("test-job-load-log-metric", 0.1)
390388

tests/integ/sagemaker/remote_function/test_executor.py

+5-19
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import boto3
16-
import os
17-
18-
from sagemaker import Session
1915
from sagemaker.experiments.trial_component import _TrialComponent
2016
from sagemaker.remote_function import RemoteExecutor
2117
from sagemaker.remote_function.client import get_future, list_futures
@@ -85,17 +81,13 @@ def test_executor_submit_with_run_inside(
8581
sagemaker_session, dummy_container_without_error, cpu_instance_type
8682
):
8783
def square(x):
88-
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
89-
sm_session = Session(boto_session=boto_session)
90-
with load_run(sagemaker_session=sm_session) as run:
84+
with load_run() as run:
9185
result = x * x
9286
run.log_metric("x", result)
9387
return result
9488

9589
def cube(x):
96-
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
97-
sm_session = Session(boto_session=boto_session)
98-
with load_run(sagemaker_session=sm_session) as run:
90+
with load_run() as run:
9991
result = x * x * x
10092
run.log_metric("x", result)
10193
return result
@@ -148,17 +140,13 @@ def test_executor_submit_with_run_outside(
148140
sagemaker_session, dummy_container_without_error, cpu_instance_type
149141
):
150142
def square(x):
151-
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
152-
sm_session = Session(boto_session=boto_session)
153-
with load_run(sagemaker_session=sm_session) as run:
143+
with load_run() as run:
154144
result = x * x
155145
run.log_metric("x", result)
156146
return result
157147

158148
def cube(x):
159-
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
160-
sm_session = Session(boto_session=boto_session)
161-
with load_run(sagemaker_session=sm_session) as run:
149+
with load_run() as run:
162150
result = x * x * x
163151
run.log_metric("x", result)
164152
return result
@@ -209,9 +197,7 @@ def cube(x):
209197

210198
def test_executor_map_with_run(sagemaker_session, dummy_container_without_error, cpu_instance_type):
211199
def square(x):
212-
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
213-
sm_session = Session(boto_session=boto_session)
214-
with load_run(sagemaker_session=sm_session) as run:
200+
with load_run() as run:
215201
result = x * x
216202
run.log_metric("x", result)
217203
return result

0 commit comments

Comments
 (0)