Skip to content

Commit b7963ed

Browse files
qidewenwhenDewen Qi
and
Dewen Qi
committed
change: Add integ test for using Run in Transform Job (aws#749)
Co-authored-by: Dewen Qi <[email protected]>
1 parent e5edfb6 commit b7963ed

File tree

16 files changed

+318
-53
lines changed

16 files changed

+318
-53
lines changed

src/sagemaker/experiments/_environment.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN"
2525
PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json"
26+
TRANSFORM_JOB_ENV_BATCH_VAR = "SAGEMAKER_BATCH"
2627
MAX_RETRY_ATTEMPTS = 7
2728

2829
logger = logging.getLogger(__name__)
@@ -33,6 +34,7 @@ class _EnvironmentType(enum.Enum):
3334

3435
SageMakerTrainingJob = 1
3536
SageMakerProcessingJob = 2
37+
SageMakerTransformJob = 3
3638

3739

3840
class _RunEnvironment(object):
@@ -53,6 +55,7 @@ def load(
5355
cls,
5456
training_job_arn_env=TRAINING_JOB_ARN_ENV,
5557
processing_job_config_path=PROCESSING_JOB_CONFIG_PATH,
58+
transform_job_batch_var=TRANSFORM_JOB_ENV_BATCH_VAR,
5659
):
5760
"""Loads source arn of current job from environment.
5861
@@ -61,11 +64,12 @@ def load(
6164
(default: `TRAINING_JOB_ARN`).
6265
processing_job_config_path (str): The processing job config path
6366
(default: `/opt/ml/config/processingjobconfig.json`).
67+
transform_job_batch_var (str): The environment variable indicating if
68+
it is a transform job (default: `SAGEMAKER_BATCH`).
6469
6570
Returns:
6671
_RunEnvironment: Job data loaded from the environment. None if config does not exist.
6772
"""
68-
# TODO: enable to determine transform job env
6973
if training_job_arn_env in os.environ:
7074
environment_type = _EnvironmentType.SageMakerTrainingJob
7175
source_arn = os.environ.get(training_job_arn_env)
@@ -74,6 +78,13 @@ def load(
7478
environment_type = _EnvironmentType.SageMakerProcessingJob
7579
source_arn = json.loads(open(processing_job_config_path).read())["ProcessingJobArn"]
7680
return _RunEnvironment(environment_type, source_arn)
81+
if transform_job_batch_var in os.environ and os.environ[transform_job_batch_var] == "true":
82+
environment_type = _EnvironmentType.SageMakerTransformJob
83+
# TODO: need to figure out how to get source_arn from job env
84+
# with Transform team's help.
85+
source_arn = ""
86+
return _RunEnvironment(environment_type, source_arn)
87+
7788
return None
7889

7990
def get_trial_component(self, sagemaker_session):
@@ -92,7 +103,7 @@ def get_trial_component(self, sagemaker_session):
92103
def _get_trial_component():
93104
summaries = list(
94105
trial_component._TrialComponent.list(
95-
source_arn=self.source_arn, sagemaker_session=sagemaker_session
106+
source_arn=self.source_arn.lower(), sagemaker_session=sagemaker_session
96107
)
97108
)
98109
if summaries:

src/sagemaker/experiments/run.py

+45-31
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,18 @@ def init(
158158
sagemaker_session=sagemaker_session,
159159
)
160160

161-
run_tc = _TrialComponent._load_or_create(
161+
run_tc, is_existed = _TrialComponent._load_or_create(
162162
trial_component_name=trial_component_name,
163163
display_name=run_display_name,
164164
tags=Run._append_run_tc_label_to_tags(tags),
165165
sagemaker_session=sagemaker_session,
166166
)
167+
if is_existed:
168+
logger.warning(
169+
"The Run (%s) under experiment (%s) already exists. Loading it.",
170+
run_name,
171+
experiment_name,
172+
)
167173

168174
trial.add_trial_component(run_tc)
169175

@@ -184,12 +190,12 @@ def load(
184190
experiment_name: Optional[str] = None,
185191
sagemaker_session: Optional["Session"] = None,
186192
):
187-
"""Load a Run Trial Component by the run name or from the job environment.
193+
"""Load a Run by the run name or from the job environment.
188194
189195
Args:
190196
run_name (str): The name of the Run to be loaded (default: None).
191197
If it is None, the `RunName` in the `ExperimentConfig` of the job will be
192-
fetched to load the Run Trial Component.
198+
fetched to load the Run.
193199
experiment_name (str): The name of the Experiment that the to be loaded Run
194200
is associated with (default: None).
195201
Note: the experiment_name must be supplied along with a valid run_name.
@@ -253,7 +259,7 @@ def _experiment_config(self):
253259

254260
@validate_invoked_inside_run_context
255261
def log_parameter(self, name, value):
256-
"""Record a single parameter value for this run trial component.
262+
"""Record a single parameter value for this run.
257263
258264
Overwrites any previous value recorded for the specified parameter name.
259265
@@ -266,7 +272,7 @@ def log_parameter(self, name, value):
266272

267273
@validate_invoked_inside_run_context
268274
def log_parameters(self, parameters):
269-
"""Record a collection of parameter values for this run trial component.
275+
"""Record a collection of parameter values for this run.
270276
271277
Args:
272278
parameters (dict[str, str or numbers.Number]): The parameters to record.
@@ -280,7 +286,7 @@ def log_parameters(self, parameters):
280286

281287
@validate_invoked_inside_run_context
282288
def log_metric(self, name, value, timestamp=None, step=None):
283-
"""Record a custom scalar metric value for this run trial component.
289+
"""Record a custom scalar metric value for this run.
284290
285291
Note:
286292
1. This method is for manual custom metrics, for automatic metrics see the
@@ -313,9 +319,9 @@ def log_precision_recall(
313319
"""Create and log a precision recall graph artifact for Studio UI to render.
314320
315321
The artifact is stored in S3 and represented as a lineage artifact
316-
with an association with the run trial component.
322+
with an association with the run.
317323
318-
You can view the artifact in the charts tab of the Trial Component UI.
324+
You can view the artifact in the UI.
319325
If your job is created by a pipeline execution you can view the artifact
320326
by selecting the corresponding step in the pipelines UI.
321327
See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_
@@ -329,7 +335,7 @@ def log_precision_recall(
329335
positive_label (str or int): Label of the positive class (default: None).
330336
title (str): Title of the graph (default: None).
331337
is_output (bool): Determines direction of association to the
332-
trial component. Defaults to True (output artifact).
338+
run. Defaults to True (output artifact).
333339
If set to False then represented as input association.
334340
no_skill (int): The precision threshold under which the classifier cannot discriminate
335341
between the classes and would predict a random class or a constant class in
@@ -378,9 +384,9 @@ def log_roc_curve(
378384
"""Create and log a receiver operating characteristic (ROC curve) artifact.
379385
380386
The artifact is stored in S3 and represented as a lineage artifact
381-
with an association with the run trial component.
387+
with an association with the run.
382388
383-
You can view the artifact in the charts tab of the Trial Component UI.
389+
You can view the artifact in the UI.
384390
If your job is created by a pipeline execution you can view the artifact
385391
by selecting the corresponding step in the pipelines UI.
386392
See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_
@@ -393,7 +399,7 @@ def log_roc_curve(
393399
y_score (list or array): Estimated/predicted probabilities.
394400
title (str): Title of the graph (default: None).
395401
is_output (bool): Determines direction of association to the
396-
trial component. Defaults to True (output artifact).
402+
run. Defaults to True (output artifact).
397403
If set to False then represented as input association.
398404
"""
399405
verify_length_of_true_and_predicted(
@@ -430,9 +436,9 @@ def log_confusion_matrix(
430436
"""Create and log a confusion matrix artifact.
431437
432438
The artifact is stored in S3 and represented as a lineage artifact
433-
with an association with the run trial component.
439+
with an association with the run.
434440
435-
You can view the artifact in the charts tab of the Trial Component UI.
441+
You can view the artifact in the UI.
436442
If your job is created by a pipeline execution you can view the
437443
artifact by selecting the corresponding step in the pipelines UI.
438444
See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_
@@ -444,7 +450,7 @@ def log_confusion_matrix(
444450
y_pred (list or array): Predicted labels.
445451
title (str): Title of the graph (default: None).
446452
is_output (bool): Determines direction of association to the
447-
trial component. Defaults to True (output artifact).
453+
run. Defaults to True (output artifact).
448454
If set to False then represented as input association.
449455
"""
450456
verify_length_of_true_and_predicted(
@@ -468,7 +474,7 @@ def log_confusion_matrix(
468474

469475
@validate_invoked_inside_run_context
470476
def log_output(self, name, value, media_type=None):
471-
"""Record a single output artifact for this run trial component.
477+
"""Record a single output artifact for this run.
472478
473479
Overwrites any previous value recorded for the specified output name.
474480
@@ -484,7 +490,7 @@ def log_output(self, name, value, media_type=None):
484490

485491
@validate_invoked_inside_run_context
486492
def log_input(self, name, value, media_type=None):
487-
"""Record a single input artifact for this run trial component.
493+
"""Record a single input artifact for this run.
488494
489495
Overwrites any previous value recorded for the specified input name.
490496
@@ -500,7 +506,7 @@ def log_input(self, name, value, media_type=None):
500506

501507
@validate_invoked_inside_run_context
502508
def log_artifact_file(self, file_path, name=None, media_type=None, is_output=True):
503-
"""Upload a file to s3 and store it as an input/output artifact in this trial component.
509+
"""Upload a file to s3 and store it as an input/output artifact in this run.
504510
505511
Args:
506512
file_path (str): The path of the local file to upload.
@@ -509,7 +515,7 @@ def log_artifact_file(self, file_path, name=None, media_type=None, is_output=Tru
509515
If not specified, this library will attempt to infer the media type
510516
from the file extension of `file_path`.
511517
is_output (bool): Determines direction of association to the
512-
trial component. Defaults to True (output artifact).
518+
run. Defaults to True (output artifact).
513519
If set to False then represented as input association.
514520
"""
515521
self._verify_trial_component_artifacts_length(is_output)
@@ -527,7 +533,7 @@ def log_artifact_file(self, file_path, name=None, media_type=None, is_output=Tru
527533

528534
@validate_invoked_inside_run_context
529535
def log_artifact_directory(self, directory, media_type=None, is_output=True):
530-
"""Upload files under directory to s3 and log as artifacts in this trial component.
536+
"""Upload files under directory to s3 and log as artifacts in this run.
531537
532538
The file name is used as the artifact name
533539
@@ -537,7 +543,7 @@ def log_artifact_directory(self, directory, media_type=None, is_output=True):
537543
If not specified, this library will attempt to infer the media type
538544
from the file extension of `file_path`.
539545
is_output (bool): Determines direction of association to the
540-
trial component. Defaults to True (output artifact).
546+
run. Defaults to True (output artifact).
541547
If set to False then represented as input association.
542548
"""
543549
for dir_file in os.listdir(directory):
@@ -549,7 +555,7 @@ def log_artifact_directory(self, directory, media_type=None, is_output=True):
549555

550556
@validate_invoked_inside_run_context
551557
def log_lineage_artifact(self, file_path, name=None, media_type=None, is_output=True):
552-
"""Upload a file to S3 and creates a lineage Artifact associated with this trial component.
558+
"""Upload a file to S3 and creates a lineage Artifact associated with this run.
553559
554560
Args:
555561
file_path (str): The path of the local file to upload.
@@ -558,7 +564,7 @@ def log_lineage_artifact(self, file_path, name=None, media_type=None, is_output=
558564
If not specified, this library will attempt to infer the media type
559565
from the file extension of `file_path`.
560566
is_output (bool): Determines direction of association to the
561-
trial component. Defaults to True (output artifact).
567+
run. Defaults to True (output artifact).
562568
If set to False then represented as input association.
563569
"""
564570
media_type = media_type or guess_media_type(file_path)
@@ -588,11 +594,11 @@ def list(
588594
"""Return a list of `Run` objects matching the given criteria.
589595
590596
Args:
591-
experiment_name (str): Only trial components related to the specified experiment
597+
experiment_name (str): Only Run objects related to the specified experiment
592598
are returned.
593-
created_before (datetime.datetime): Return trial components created before this instant
599+
created_before (datetime.datetime): Return Run objects created before this instant
594600
(default: None).
595-
created_after (datetime.datetime): Return trial components created after this instant
601+
created_after (datetime.datetime): Return Run objects created after this instant
596602
(default: None).
597603
sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime'
598604
(default: 'CreationTime').
@@ -601,7 +607,7 @@ def list(
601607
manages interactions with Amazon SageMaker APIs and any other
602608
AWS services needed. If not specified, one is created using the
603609
default AWS configuration chain.
604-
max_results (int): maximum number of trial components to retrieve (default: None).
610+
max_results (int): maximum number of Run objects to retrieve (default: None).
605611
next_token (str): token for next page of results (default: None).
606612
607613
Returns:
@@ -715,7 +721,7 @@ def _verify_trial_component_artifacts_length(self, is_output):
715721
Raises:
716722
ValueError: If the length of trial component artifacts exceeds the limit.
717723
"""
718-
err_msg_template = "Cannot add more than {} {}_artifacts under run trial_component"
724+
err_msg_template = "Cannot add more than {} {}_artifacts under run"
719725
if is_output:
720726
if len(self._trial_component.output_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN:
721727
raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "output"))
@@ -777,11 +783,19 @@ def _get_tc_and_exp_config_from_job_env(
777783
callable_func=lambda: sagemaker_session.describe_training_job(job_name),
778784
num_attempts=4,
779785
)
780-
else: # environment.environment_type == _EnvironmentType.SageMakerProcessingJob
786+
elif environment.environment_type == _EnvironmentType.SageMakerProcessingJob:
781787
job_response = retry_with_backoff(
782788
callable_func=lambda: sagemaker_session.describe_processing_job(job_name),
783789
num_attempts=4,
784790
)
791+
else: # environment.environment_type == _EnvironmentType.SageMakerTransformJob
792+
raise RuntimeError(
793+
"Failed to load the Run as loading experiment config "
794+
"from transform job environment is not currently supported. "
795+
"As a workaround, please explicitly pass in "
796+
"the experiment_name and run_name in Run.load."
797+
)
798+
785799
job_exp_config = job_response.get("ExperimentConfig", dict())
786800
if job_exp_config.get(RUN_NAME, None):
787801
# The run with RunName has been created outside of the job env.
@@ -867,7 +881,7 @@ def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) ->
867881
return tags
868882

869883
def __enter__(self):
870-
"""Updates the start time of the tracked trial component.
884+
"""Updates the start time of the run.
871885
872886
Returns:
873887
object: self.
@@ -897,7 +911,7 @@ def __enter__(self):
897911
return self
898912

899913
def __exit__(self, exc_type, exc_value, exc_traceback):
900-
"""Updates the end time of the tracked trial component.
914+
"""Updates the end time of the run.
901915
902916
Args:
903917
exc_type (str): The exception type.

src/sagemaker/experiments/trial_component.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -324,15 +324,18 @@ def _load_or_create(
324324
325325
Returns:
326326
experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
327+
bool: A boolean variable indicating whether the trail component already exists
327328
"""
328329
sagemaker_client = sagemaker_session.sagemaker_client
330+
is_existed = False
329331
try:
330332
run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
333+
is_existed = True
331334
except sagemaker_client.exceptions.ResourceNotFound:
332335
run_tc = _TrialComponent.create(
333336
trial_component_name=trial_component_name,
334337
display_name=display_name,
335338
tags=tags,
336339
sagemaker_session=sagemaker_session,
337340
)
338-
return run_tc
341+
return run_tc, is_existed

src/sagemaker/processing.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@
3333
from sagemaker.job import _Job
3434
from sagemaker.local import LocalSession
3535
from sagemaker.network import NetworkConfig
36-
from sagemaker.utils import base_name_from_image, get_config_value, name_from_base
36+
from sagemaker.utils import (
37+
base_name_from_image,
38+
get_config_value,
39+
name_from_base,
40+
check_and_get_run_experiment_config,
41+
)
3742
from sagemaker.session import Session
3843
from sagemaker.workflow import is_pipeline_variable
3944
from sagemaker.workflow.functions import Join
@@ -203,6 +208,7 @@ def run(
203208
outputs=outputs,
204209
)
205210

211+
experiment_config = check_and_get_run_experiment_config(experiment_config)
206212
self.latest_job = ProcessingJob.start_new(
207213
processor=self,
208214
inputs=normalized_inputs,
@@ -605,6 +611,7 @@ def run(
605611
kms_key=kms_key,
606612
)
607613

614+
experiment_config = check_and_get_run_experiment_config(experiment_config)
608615
self.latest_job = ProcessingJob.start_new(
609616
processor=self,
610617
inputs=normalized_inputs,

0 commit comments

Comments
 (0)