
Commit acafbb2

qidewenwhen authored and JoseJuan98 committed
change: Enable load_run without name args in Transform env (aws#3585)
1 parent ee77880 commit acafbb2

File tree

10 files changed, +74 -224 lines

src/sagemaker/experiments/_environment.py (+12, -20)

@@ -18,12 +18,13 @@
 import logging
 import os

+from sagemaker import Session
 from sagemaker.experiments import trial_component
 from sagemaker.utils import retry_with_backoff

 TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN"
 PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json"
-TRANSFORM_JOB_ENV_BATCH_VAR = "SAGEMAKER_BATCH"
+TRANSFORM_JOB_ARN_ENV = "TRANSFORM_JOB_ARN"
 MAX_RETRY_ATTEMPTS = 7

 logger = logging.getLogger(__name__)
@@ -40,7 +41,7 @@ class _EnvironmentType(enum.Enum):
 class _RunEnvironment(object):
     """Retrieves job specific data from the environment."""

-    def __init__(self, environment_type, source_arn):
+    def __init__(self, environment_type: _EnvironmentType, source_arn: str):
         """Init for _RunEnvironment.

         Args:
@@ -53,9 +54,9 @@ def __init__(self, environment_type, source_arn):
     @classmethod
     def load(
         cls,
-        training_job_arn_env=TRAINING_JOB_ARN_ENV,
-        processing_job_config_path=PROCESSING_JOB_CONFIG_PATH,
-        transform_job_batch_var=TRANSFORM_JOB_ENV_BATCH_VAR,
+        training_job_arn_env: str = TRAINING_JOB_ARN_ENV,
+        processing_job_config_path: str = PROCESSING_JOB_CONFIG_PATH,
+        transform_job_arn_env: str = TRANSFORM_JOB_ARN_ENV,
     ):
         """Loads source arn of current job from environment.

@@ -64,8 +65,8 @@ def load(
                 (default: `TRAINING_JOB_ARN`).
             processing_job_config_path (str): The processing job config path
                 (default: `/opt/ml/config/processingjobconfig.json`).
-            transform_job_batch_var (str): The environment variable indicating if
-                it is a transform job (default: `SAGEMAKER_BATCH`).
+            transform_job_arn_env (str): The environment key for transform job ARN
+                (default: `TRANSFORM_JOB_ARN_ENV`).

         Returns:
             _RunEnvironment: Job data loaded from the environment. None if config does not exist.
@@ -78,16 +79,15 @@ def load(
             environment_type = _EnvironmentType.SageMakerProcessingJob
             source_arn = json.loads(open(processing_job_config_path).read())["ProcessingJobArn"]
             return _RunEnvironment(environment_type, source_arn)
-        if transform_job_batch_var in os.environ and os.environ[transform_job_batch_var] == "true":
+        if transform_job_arn_env in os.environ:
             environment_type = _EnvironmentType.SageMakerTransformJob
-            # TODO: need to figure out how to get source_arn from job env
-            # with Transform team's help.
-            source_arn = ""
+            # TODO: need to update to get source_arn from config file once Transform side ready
+            source_arn = os.environ.get(transform_job_arn_env)
             return _RunEnvironment(environment_type, source_arn)

         return None

-    def get_trial_component(self, sagemaker_session):
+    def get_trial_component(self, sagemaker_session: Session):
         """Retrieves the trial component from the job in the environment.

         Args:
@@ -99,14 +99,6 @@ def get_trial_component(self, sagemaker_session):
         Returns:
             _TrialComponent: The trial component created from the job. None if not found.
         """
-        # TODO: Remove this condition check once we have a way to retrieve source ARN
-        # from transform job env
-        if self.environment_type == _EnvironmentType.SageMakerTransformJob:
-            logger.error(
-                "Currently getting the job trial component from the transform job environment "
-                "is not supported. Returning None."
-            )
-            return None

         def _get_trial_component():
             summaries = list(
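
To illustrate the new detection path, here is a minimal sketch of how `_RunEnvironment.load()` behaves after this change when the transform container exposes the job ARN under `TRANSFORM_JOB_ARN`. The ARN value is made up, and the sketch assumes no training or processing job variables are present in the environment:

    import os

    from sagemaker.experiments import _environment

    # Hypothetical ARN for illustration; in a real transform container the
    # platform is expected to set this variable.
    os.environ["TRANSFORM_JOB_ARN"] = (
        "arn:aws:sagemaker:us-west-2:123456789012:transform-job/my-transform-job"
    )

    env = _environment._RunEnvironment.load()
    assert env.environment_type == _environment._EnvironmentType.SageMakerTransformJob
    assert env.source_arn == os.environ["TRANSFORM_JOB_ARN"]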

src/sagemaker/experiments/_metrics.py (-80)

@@ -14,7 +14,6 @@
 from __future__ import absolute_import

 import datetime
-import json
 import logging
 import os
 import time
@@ -35,85 +34,6 @@
 logger = logging.getLogger(__name__)


-# TODO: remove this _SageMakerFileMetricsWriter class
-# when _MetricsManager is fully ready
-class _SageMakerFileMetricsWriter(object):
-    """Write metric data to file."""
-
-    def __init__(self, metrics_file_path=None):
-        """Construct a `_SageMakerFileMetricsWriter` object"""
-        self._metrics_file_path = metrics_file_path
-        self._file = None
-        self._closed = False
-
-    def log_metric(self, metric_name, value, timestamp=None, step=None):
-        """Write a metric to file.
-
-        Args:
-            metric_name (str): The name of the metric.
-            value (float): The value of the metric.
-            timestamp (datetime.datetime): Timestamp of the metric.
-                If not specified, the current UTC time will be used.
-            step (int): Iteration number of the metric (default: None).
-
-        Raises:
-            SageMakerMetricsWriterException: If the metrics file is closed.
-            AttributeError: If file has been initialized and the writer hasn't been closed.
-        """
-        raw_metric_data = _RawMetricData(
-            metric_name=metric_name, value=value, timestamp=timestamp, step=step
-        )
-        try:
-            logger.debug("Writing metric: %s", raw_metric_data)
-            self._file.write(json.dumps(raw_metric_data.to_record()))
-            self._file.write("\n")
-        except AttributeError as attr_err:
-            if self._closed:
-                raise SageMakerMetricsWriterException("log_metric called on a closed writer")
-            if not self._file:
-                self._file = open(self._get_metrics_file_path(), "a", buffering=1)
-                self._file.write(json.dumps(raw_metric_data.to_record()))
-                self._file.write("\n")
-            else:
-                raise attr_err
-
-    def close(self):
-        """Closes the metric file."""
-        if not self._closed and self._file:
-            self._file.close()
-            self._file = None  # invalidate reference, causing subsequent log_metric to fail.
-            self._closed = True
-
-    def __enter__(self):
-        """Return self"""
-        return self
-
-    def __exit__(self, exc_type, exc_value, exc_traceback):
-        """Execute self.close()"""
-        self.close()
-
-    def __del__(self):
-        """Execute self.close()"""
-        self.close()
-
-    def _get_metrics_file_path(self):
-        """Get file path to store metrics"""
-        pid_filename = "{}.json".format(str(os.getpid()))
-        metrics_file_path = self._metrics_file_path or os.path.join(METRICS_DIR, pid_filename)
-        logger.debug("metrics_file_path = %s", metrics_file_path)
-        return metrics_file_path
-
-
-class SageMakerMetricsWriterException(Exception):
-    """SageMakerMetricsWriterException"""
-
-    def __init__(self, message, errors=None):
-        """Construct a `SageMakerMetricsWriterException` instance"""
-        super().__init__(message)
-        if errors:
-            self.errors = errors
-
-
 class _RawMetricData(object):
     """A Raw Metric Data Object"""

src/sagemaker/experiments/_utils.py (+3, -5)

@@ -127,11 +127,9 @@ def get_tc_and_exp_config_from_job_env(
             num_attempts=4,
         )
     else:  # environment.environment_type == _EnvironmentType.SageMakerTransformJob
-        raise RuntimeError(
-            "Failed to load the Run as loading experiment config "
-            "from transform job environment is not currently supported. "
-            "As a workaround, please explicitly pass in "
-            "the experiment_name and run_name in load_run."
+        job_response = retry_with_backoff(
+            callable_func=lambda: sagemaker_session.describe_transform_job(job_name),
+            num_attempts=4,
         )

     job_exp_config = job_response.get("ExperimentConfig", dict())
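
For context, a rough sketch of the flow this hunk enables: the experiment config is now read back from the transform job description instead of raising. The helper name below is hypothetical, and the sketch assumes `describe_transform_job` returns the usual DescribeTransformJob response with an optional `ExperimentConfig` block:

    from sagemaker.session import Session
    from sagemaker.utils import retry_with_backoff


    def _exp_config_from_transform_job(sagemaker_session: Session, job_name: str) -> dict:
        """Illustrative helper: fetch the experiment config recorded on a transform job."""
        job_response = retry_with_backoff(
            callable_func=lambda: sagemaker_session.describe_transform_job(job_name),
            num_attempts=4,
        )
        # ExperimentConfig is only present when the job was launched under a Run context.
        return job_response.get("ExperimentConfig", dict())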

src/sagemaker/experiments/run.py (+5, -6)

@@ -120,19 +120,18 @@ def __init__(
                 estimator.fit(job_name="my-job")  # Create a training job

         In order to reuse an existing run to log extra data, ``load_run`` is recommended.
+        For example, instead of the ``Run`` constructor, the ``load_run`` is recommended to use
+        in a job script to load the existing run created before the job launch.
+        Otherwise, a new run may be created each time you launch a job.
+
         The code snippet below displays how to load the run initialized above
         in a custom training job script, where no ``run_name`` or ``experiment_name``
         is presented as they are automatically retrieved from the experiment config
         in the job environment.

-        Note:
-            Instead of the ``Run`` constructor, the ``load_run`` is recommended to use
-            in a job script to load the existing run created before the job launch.
-            Otherwise, a new run may be created each time you launch a job.
-
         .. code:: python

-            with load_run() as run:
+            with load_run(sagemaker_session=sagemaker_session) as run:
                 run.log_metric(...)
                 ...
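
As a usage sketch, the two ways a job script can now attach to the run created before the job launch are shown below. The experiment and run names are placeholders; loading without names relies on the experiment config that this commit makes resolvable in the transform job environment:

    from sagemaker.experiments import load_run
    from sagemaker.session import Session

    sagemaker_session = Session()

    # Explicitly naming the run, which already worked before this change.
    with load_run(
        experiment_name="my-experiment",
        run_name="my-run",
        sagemaker_session=sagemaker_session,
    ) as run:
        run.log_parameters({"p3": 3.0, "p4": 4.0})

    # Without names: the run is resolved from the job's experiment config,
    # which this commit enables for transform jobs.
    with load_run(sagemaker_session=sagemaker_session) as run:
        run.log_metric("test-job-load-log-metric", 0.1)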

tests/data/experiment/inference.py (+3)

@@ -68,6 +68,9 @@ def model_fn(model_dir):
         run.log_parameters({"p3": 3.0, "p4": 4.0})
         run.log_metric("test-job-load-log-metric", 0.1)

+    with load_run(sagemaker_session=sagemaker_session) as run:
+        run.log_parameters({"p5": 5.0, "p6": 6})
+
     model_file = "xgboost-model"
     booster = pkl.load(open(os.path.join(model_dir, model_file), "rb"))
     return booster

tests/integ/sagemaker/experiments/test_metrics.py (+1, -2)

@@ -30,8 +30,7 @@ def verify_metrics():
         sagemaker_session=sagemaker_session,
     )
     metrics = updated_tc.metrics
-    # TODO: revert to len(metrics) == 2 once backend fix reaches prod
-    assert len(metrics) > 0
+    assert len(metrics) == 2
     assert list(filter(lambda x: x.metric_name == "test-x-step", metrics))
     assert list(filter(lambda x: x.metric_name == "test-x-timestamp", metrics))

tests/integ/sagemaker/experiments/test_run.py (+7, -10)

@@ -482,10 +482,9 @@ def test_run_from_transform_job(
 ):
     # Notes:
     # 1. The 1st Run (run) created locally
-    # 2. In the inference script running in a transform job, load the 1st Run
-    #    via explicitly passing the experiment_name and run_name of the 1st Run
-    # TODO: once we're able to retrieve exp config from the transform job env,
-    # we should expand this test and add the load_run() without explicitly supplying the names
+    # 2. In the inference script running in a transform job, load the 1st Run twice and log data
+    #    1) via explicitly passing the experiment_name and run_name of the 1st Run
+    #    2) use load_run() without explicitly supplying the names
     # 3. All data are logged in the Run either locally or in the transform job
     exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
     xgb_model_data_s3 = sagemaker_session.upload_data(
@@ -537,6 +536,7 @@ def test_run_from_transform_job(
             content_type="text/libsvm",
             split_type="Line",
             wait=True,
+            logs=False,
             job_name=f"transform-job-{name()}",
         )

@@ -549,7 +549,7 @@ def test_run_from_transform_job(
         experiment_name=run.experiment_name, run_name=run.run_name
     )
     _check_run_from_job_result(
-        tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False
+        tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False, has_extra_load=True
     )

@@ -718,8 +718,7 @@ def _check_run_from_local_end_result(sagemaker_session, tc, is_complete_log=True
     assert "s3://Input" == tc.input_artifacts[artifact_name].value
     assert not tc.input_artifacts[artifact_name].media_type

-    # TODO: revert to len(tc.metrics) == 1 once backend fix reaches prod
-    assert len(tc.metrics) > 0
+    assert len(tc.metrics) == 1
     metric_summary = tc.metrics[0]
     assert metric_summary.metric_name == metric_name
     assert metric_summary.max == 9.0
@@ -733,9 +732,7 @@ def validate_tc_updated_in_init():
         assert tc.status.primary_status == _TrialComponentStatusType.Completed.value
         assert tc.parameters["p1"] == 1.0
         assert tc.parameters["p2"] == 2.0
-        # TODO: revert to assert len(tc.metrics) == 5 once
-        # backend fix hits prod
-        assert len(tc.metrics) > 0
+        assert len(tc.metrics) == 5
         for metric_summary in tc.metrics:
             # metrics deletion is not supported at this point
             # so its count would accumulate

tests/unit/sagemaker/experiments/test_environment.py (+10, -15)

@@ -21,6 +21,7 @@
 import pytest

 from sagemaker.experiments import _environment
+from sagemaker.experiments._environment import TRANSFORM_JOB_ARN_ENV, TRAINING_JOB_ARN_ENV
 from sagemaker.utils import retry_with_backoff

@@ -33,22 +34,22 @@ def tempdir():
 @pytest.fixture
 def training_job_env():
-    old_value = os.environ.get("TRAINING_JOB_ARN")
-    os.environ["TRAINING_JOB_ARN"] = "arn:1234aBcDe"
+    old_value = os.environ.get(TRAINING_JOB_ARN_ENV)
+    os.environ[TRAINING_JOB_ARN_ENV] = "arn:1234aBcDe"
     yield os.environ
-    del os.environ["TRAINING_JOB_ARN"]
+    del os.environ[TRAINING_JOB_ARN_ENV]
     if old_value:
-        os.environ["TRAINING_JOB_ARN"] = old_value
+        os.environ[TRAINING_JOB_ARN_ENV] = old_value


 @pytest.fixture
 def transform_job_env():
-    old_value = os.environ.get("SAGEMAKER_BATCH")
-    os.environ["SAGEMAKER_BATCH"] = "true"
+    old_value = os.environ.get(TRANSFORM_JOB_ARN_ENV)
+    os.environ[TRANSFORM_JOB_ARN_ENV] = "arn:1234aBcDe"
     yield os.environ
-    del os.environ["SAGEMAKER_BATCH"]
+    del os.environ[TRANSFORM_JOB_ARN_ENV]
     if old_value:
-        os.environ["SAGEMAKER_BATCH"] = old_value
+        os.environ[TRANSFORM_JOB_ARN_ENV] = old_value


 def test_processing_job_environment(tempdir):
@@ -70,8 +71,7 @@ def test_training_job_environment(training_job_env):
 def test_transform_job_environment(transform_job_env):
     environment = _environment._RunEnvironment.load()
     assert _environment._EnvironmentType.SageMakerTransformJob == environment.environment_type
-    # TODO: update if we figure out how to get source_arn from the transform job
-    assert not environment.source_arn
+    assert "arn:1234aBcDe" == environment.source_arn


 def test_no_environment():
@@ -100,8 +100,3 @@ def test_resolve_trial_component_fails(mock_retry, sagemaker_session, training_j
     client.list_trial_components.side_effect = Exception("Failed test")
     environment = _environment._RunEnvironment.load()
     assert environment.get_trial_component(sagemaker_session) is None
-
-
-def test_resolve_transform_job_trial_component_fail(transform_job_env, sagemaker_session):
-    environment = _environment._RunEnvironment.load()
-    assert environment.get_trial_component(sagemaker_session) is None
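
The updated fixtures save and restore the variables by hand. An equivalent sketch using pytest's built-in `monkeypatch` fixture, which reverts environment changes automatically, could look like the following; this is only an alternative design, not what the commit does:

    import os

    import pytest

    from sagemaker.experiments._environment import TRANSFORM_JOB_ARN_ENV


    @pytest.fixture
    def transform_job_env(monkeypatch):
        # monkeypatch.setenv is undone when the test finishes, so no
        # manual deletion or old-value restore is needed.
        monkeypatch.setenv(TRANSFORM_JOB_ARN_ENV, "arn:1234aBcDe")
        yield os.environ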
