Skip to content

Commit c8d1ef3

Browse files
qidewenwhen (Dewen Qi) and Dewen Qi committed
change: Update integ tests for Exp Plus M1 changes (aws#741)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 0cf6d1d commit c8d1ef3

16 files changed

+763
-641
lines changed

src/sagemaker/experiments/_api_types.py

+9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Contains API objects for SageMaker experiments."""
1414
from __future__ import absolute_import
1515

16+
import enum
1617
import numbers
1718

1819
from sagemaker.apiutils import _base_types
@@ -102,6 +103,14 @@ def __init__(self, value=None, media_type=None, **kwargs):
102103
super(TrialComponentArtifact, self).__init__(value=value, media_type=media_type, **kwargs)
103104

104105

106+
@enum.unique
class _TrialComponentStatusType(enum.Enum):
    """The type of trial component status.

    Every member's value is the same string as its name. ``enum.unique``
    makes the class definition fail with ``ValueError`` if two members are
    ever given the same value, instead of silently turning one into an alias.
    """

    InProgress = "InProgress"
    Completed = "Completed"
    Failed = "Failed"
112+
113+
105114
class TrialComponentStatus(_base_types.ApiObject):
106115
"""Status of the trial component.
107116

src/sagemaker/experiments/_environment.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
logger = logging.getLogger(__name__)
2929

3030

31-
class EnvironmentType(enum.Enum):
31+
class _EnvironmentType(enum.Enum):
3232
"""SageMaker jobs which data can be pulled from the environment."""
3333

3434
SageMakerTrainingJob = 1
@@ -42,7 +42,7 @@ def __init__(self, environment_type, source_arn):
4242
"""Init for _RunEnvironment.
4343
4444
Args:
45-
environment_type (EnvironmentType): The environment type.
45+
environment_type (_EnvironmentType): The environment type.
4646
source_arn (str): The ARN of the current job.
4747
"""
4848
self.environment_type = environment_type
@@ -65,12 +65,13 @@ def load(
6565
Returns:
6666
_RunEnvironment: Job data loaded from the environment. None if config does not exist.
6767
"""
68+
# TODO: add support for detecting the transform job environment
6869
if training_job_arn_env in os.environ:
69-
environment_type = EnvironmentType.SageMakerTrainingJob
70+
environment_type = _EnvironmentType.SageMakerTrainingJob
7071
source_arn = os.environ.get(training_job_arn_env)
7172
return _RunEnvironment(environment_type, source_arn)
7273
if os.path.exists(processing_job_config_path):
73-
environment_type = EnvironmentType.SageMakerProcessingJob
74+
environment_type = _EnvironmentType.SageMakerProcessingJob
7475
source_arn = json.loads(open(processing_job_config_path).read())["ProcessingJobArn"]
7576
return _RunEnvironment(environment_type, source_arn)
7677
return None

src/sagemaker/experiments/_helper.py

+27-93
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
import botocore
2121

22+
from sagemaker.experiments._utils import is_already_exist_error
23+
2224
logger = logging.getLogger(__name__)
2325

2426

@@ -171,12 +173,20 @@ def create_artifact(self, sagemaker_session):
171173
if self.etag:
172174
source_ids.append({"SourceIdType": "S3ETag", "Value": self.etag})
173175

174-
response = sagemaker_session.sagemaker_client.create_artifact(
175-
ArtifactName=self.name,
176-
ArtifactType=self.artifact_type,
177-
Source={"SourceUri": self.source_uri, "SourceTypes": source_ids},
178-
)
179-
self.artifact_arn = response["ArtifactArn"]
176+
try:
177+
response = sagemaker_session.sagemaker_client.create_artifact(
178+
ArtifactName=self.name,
179+
ArtifactType=self.artifact_type,
180+
Source={"SourceUri": self.source_uri, "SourceTypes": source_ids},
181+
)
182+
self.artifact_arn = response["ArtifactArn"]
183+
except botocore.exceptions.ClientError as err:
184+
err_info = err.response["Error"]
185+
if not is_already_exist_error(err_info):
186+
raise
187+
logger.warning(
188+
"Skip creating the artifact since it already exists: %s", err_info["Message"]
189+
)
180190

181191
def add_association(self, sagemaker_session):
182192
"""Associate the artifact with a source/destination ARN (e.g. trial component arn)
@@ -191,9 +201,17 @@ def add_association(self, sagemaker_session):
191201
# if the trial component (job) is the source then it produced the artifact,
192202
# otherwise the artifact contributed to the trial component (job)
193203
association_edge_type = "Produced" if self.source_arn else "ContributedTo"
194-
sagemaker_session.sagemaker_client.add_association(
195-
SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_edge_type
196-
)
204+
try:
205+
sagemaker_session.sagemaker_client.add_association(
206+
SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_edge_type
207+
)
208+
except botocore.exceptions.ClientError as err:
209+
err_info = err.response["Error"]
210+
if not is_already_exist_error(err_info):
211+
raise
212+
logger.warning(
213+
"Skip associating since the association already exists: %s", err_info["Message"]
214+
)
197215

198216

199217
class _LineageArtifactTracker(object):
@@ -246,87 +264,3 @@ def save(self):
246264
for artifact in self.artifacts:
247265
artifact.create_artifact(self.sagemaker_session)
248266
artifact.add_association(self.sagemaker_session)
249-
250-
251-
class _ArtifactConverter(object):
    """Converts data into formats that are easily consumed by Studio."""

    @classmethod
    def convert_dict_to_fields(cls, values):
        """Converts a dictionary to a list of field descriptors.

        Args:
            values (dict): The dictionary whose keys become field names.

        Returns:
            list: One ``{"name": key, "type": "string"}`` dict per key, in
                the iteration order of ``values``.
        """
        return [{"name": key, "type": "string"} for key in values]

    @classmethod
    def convert_data_frame_to_values(cls, data_frame):
        """Converts a pandas data frame to a dictionary in the table artifact format.

        Args:
            data_frame (DataFrame): The pandas data frame to convert.

        Returns:
            dict: Maps each column name to the list of that column's values.
        """
        # ``to_dict()`` yields {column: {row_label: value}}; drop the row labels
        # and keep only the values, in row order.
        return {
            column: list(rows.values()) for column, rows in data_frame.to_dict().items()
        }

    @classmethod
    def convert_data_frame_to_fields(cls, data_frame):
        """Converts a dataframe to a list describing the type of each column.

        Args:
            data_frame (DataFrame): The data frame to convert.

        Returns:
            list: One ``{"name": column, "type": simple_type}`` dict per column.
        """
        return [
            {
                "name": column,
                "type": cls.convert_df_type_to_simple_type(data_frame.dtypes[column]),
            }
            for column in data_frame
        ]

    @classmethod
    def convert_df_type_to_simple_type(cls, data_frame_type):
        """Converts a dataframe type to a type for rendering a table in Studio.

        Args:
            data_frame_type (str): The pandas type. Anything whose ``str()``
                is the dtype name is accepted.

        Returns:
            str: The type of the table field: "datetime", "number", "boolean",
                or "string" when no known prefix matches.
        """
        # (lowercased dtype-name prefix, table field type). The boolean prefix
        # is "bool" rather than "boolean": the default boolean dtype is named
        # "bool", and "bool" is also a prefix of the nullable "boolean" dtype.
        type_pairs = [
            ("datetime", "datetime"),
            ("float", "number"),
            ("int", "number"),
            ("uint", "number"),
            ("bool", "boolean"),
        ]
        type_name = str(data_frame_type).lower()
        for prefix, simple_type in type_pairs:
            if type_name.startswith(prefix):
                return simple_type
        return "string"

src/sagemaker/experiments/_utils.py

+10
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,13 @@ def wrapper(*args, **kwargs):
8383
return func(*args, **kwargs)
8484

8585
return wrapper
86+
87+
88+
def is_already_exist_error(error):
    """Check if the error indicates that the resource already exists.

    Args:
        error (dict): The "Error" field in the response of the
            `botocore.exceptions.ClientError`

    Returns:
        bool: True if the error code is "ValidationException" and the error
            message contains "already exists"; False otherwise, including
            when either field is missing.
    """
    # Tolerant lookups: callers invoke this inside an ``except ClientError``
    # handler, so a KeyError here would mask the error they want to re-raise.
    if error.get("Code") != "ValidationException":
        return False
    return "already exists" in (error.get("Message") or "")

0 commit comments

Comments
 (0)