18
18
import logging
19
19
import os
20
20
21
+ from sagemaker import Session
21
22
from sagemaker .experiments import trial_component
22
23
from sagemaker .utils import retry_with_backoff
23
24
24
25
TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN"
25
26
PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json"
26
- TRANSFORM_JOB_ENV_BATCH_VAR = "SAGEMAKER_BATCH "
27
+ TRANSFORM_JOB_ARN_ENV = "TRANSFORM_JOB_ARN "
27
28
MAX_RETRY_ATTEMPTS = 7
28
29
29
30
logger = logging .getLogger (__name__ )
@@ -40,7 +41,7 @@ class _EnvironmentType(enum.Enum):
40
41
class _RunEnvironment (object ):
41
42
"""Retrieves job specific data from the environment."""
42
43
43
- def __init__ (self , environment_type , source_arn ):
44
+ def __init__ (self , environment_type : _EnvironmentType , source_arn : str ):
44
45
"""Init for _RunEnvironment.
45
46
46
47
Args:
@@ -53,9 +54,9 @@ def __init__(self, environment_type, source_arn):
53
54
@classmethod
54
55
def load (
55
56
cls ,
56
- training_job_arn_env = TRAINING_JOB_ARN_ENV ,
57
- processing_job_config_path = PROCESSING_JOB_CONFIG_PATH ,
58
- transform_job_batch_var = TRANSFORM_JOB_ENV_BATCH_VAR ,
57
+ training_job_arn_env : str = TRAINING_JOB_ARN_ENV ,
58
+ processing_job_config_path : str = PROCESSING_JOB_CONFIG_PATH ,
59
+ transform_job_arn_env : str = TRANSFORM_JOB_ARN_ENV ,
59
60
):
60
61
"""Loads source arn of current job from environment.
61
62
@@ -64,8 +65,8 @@ def load(
64
65
(default: `TRAINING_JOB_ARN`).
65
66
processing_job_config_path (str): The processing job config path
66
67
(default: `/opt/ml/config/processingjobconfig.json`).
67
- transform_job_batch_var (str): The environment variable indicating if
68
- it is a transform job (default: `SAGEMAKER_BATCH `).
68
+ transform_job_arn_env (str): The environment key for transform job ARN
69
+ (default: `TRANSFORM_JOB_ARN_ENV `).
69
70
70
71
Returns:
71
72
_RunEnvironment: Job data loaded from the environment. None if config does not exist.
@@ -78,16 +79,15 @@ def load(
78
79
environment_type = _EnvironmentType .SageMakerProcessingJob
79
80
source_arn = json .loads (open (processing_job_config_path ).read ())["ProcessingJobArn" ]
80
81
return _RunEnvironment (environment_type , source_arn )
81
- if transform_job_batch_var in os .environ and os . environ [ transform_job_batch_var ] == "true" :
82
+ if transform_job_arn_env in os .environ :
82
83
environment_type = _EnvironmentType .SageMakerTransformJob
83
- # TODO: need to figure out how to get source_arn from job env
84
- # with Transform team's help.
85
- source_arn = ""
84
+ # TODO: need to update to get source_arn from config file once Transform side ready
85
+ source_arn = os .environ .get (transform_job_arn_env )
86
86
return _RunEnvironment (environment_type , source_arn )
87
87
88
88
return None
89
89
90
- def get_trial_component (self , sagemaker_session ):
90
+ def get_trial_component (self , sagemaker_session : Session ):
91
91
"""Retrieves the trial component from the job in the environment.
92
92
93
93
Args:
@@ -99,14 +99,6 @@ def get_trial_component(self, sagemaker_session):
99
99
Returns:
100
100
_TrialComponent: The trial component created from the job. None if not found.
101
101
"""
102
- # TODO: Remove this condition check once we have a way to retrieve source ARN
103
- # from transform job env
104
- if self .environment_type == _EnvironmentType .SageMakerTransformJob :
105
- logger .error (
106
- "Currently getting the job trial component from the transform job environment "
107
- "is not supported. Returning None."
108
- )
109
- return None
110
102
111
103
def _get_trial_component ():
112
104
summaries = list (
0 commit comments