Skip to content

Commit 509652c

Browse files
grenmester and Jacky Lee
authored and committed
fix: lineage tracking bug (aws#1447)
* fix: lineage bug
* fix: lineage
* fix: add validation for tracking ARN input with MLflow input type
* fix: bug
* fix: unit tests
* fix: mock
* fix: args

---------

Co-authored-by: Jacky Lee <[email protected]>
1 parent 58d912b commit 509652c

File tree

3 files changed

+56
-37
lines changed

3 files changed

+56
-37
lines changed

src/sagemaker/serve/builder/model_builder.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -508,8 +508,8 @@ def _model_builder_register_wrapper(self, *args, **kwargs):
508508
_maintain_lineage_tracking_for_mlflow_model(
509509
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
510510
s3_upload_path=self.s3_upload_path,
511-
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
512511
sagemaker_session=self.sagemaker_session,
512+
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
513513
)
514514
return new_model_package
515515

@@ -579,8 +579,8 @@ def _model_builder_deploy_wrapper(
579579
_maintain_lineage_tracking_for_mlflow_model(
580580
mlflow_model_path=self.model_metadata[MLFLOW_MODEL_PATH],
581581
s3_upload_path=self.s3_upload_path,
582-
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
583582
sagemaker_session=self.sagemaker_session,
583+
tracking_server_arn=self.model_metadata.get(MLFLOW_TRACKING_ARN),
584584
)
585585
return predictor
586586

src/sagemaker/serve/utils/lineage_utils.py

+40-25
Original file line number | Diff line number | Diff line change
@@ -54,28 +54,28 @@
5454

5555
def _load_artifact_by_source_uri(
5656
source_uri: str,
57-
artifact_type: str,
5857
sagemaker_session: Session,
5958
source_types_to_match: Optional[List[str]] = None,
59+
artifact_type: Optional[str] = None,
6060
) -> Optional[ArtifactSummary]:
6161
"""Load lineage artifact by source uri
6262
6363
Arguments:
6464
source_uri (str): The s3 uri used for uploading transfomred model artifacts.
65-
artifact_type (str): The type of the lineage artifact.
66-
source_types_to_match (Optional[List[str]]): A list of source type values to match against
67-
the artifact's source types. If provided, the artifact's source types must match this
68-
list.
6965
sagemaker_session (Session): Session object which manages interactions
7066
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
7167
function creates one using the default AWS configuration chain.
68+
source_types_to_match (Optional[List[str]]): A list of source type values to match against
69+
the artifact's source types. If provided, the artifact's source types must match this
70+
list.
71+
artifact_type (Optional[str]): The type of the lineage artifact.
7272
7373
Returns:
7474
ArtifactSummary: The Artifact Summary for the provided S3 URI.
7575
"""
7676
artifacts = Artifact.list(source_uri=source_uri, sagemaker_session=sagemaker_session)
7777
for artifact_summary in artifacts:
78-
if artifact_summary.artifact_type == artifact_type:
78+
if artifact_type is None or artifact_summary.artifact_type == artifact_type:
7979
if source_types_to_match:
8080
if artifact_summary.source.source_types is not None:
8181
artifact_source_types = [
@@ -109,7 +109,9 @@ def _poll_lineage_artifact(
109109
logger.info("Polling lineage artifact for model data in %s", s3_uri)
110110
start_time = time.time()
111111
while time.time() - start_time < LINEAGE_POLLER_MAX_TIMEOUT_SECS:
112-
result = _load_artifact_by_source_uri(s3_uri, artifact_type, sagemaker_session)
112+
result = _load_artifact_by_source_uri(
113+
s3_uri, sagemaker_session, artifact_type=artifact_type
114+
)
113115
if result is not None:
114116
return result
115117
time.sleep(LINEAGE_POLLER_INTERVAL_SECS)
@@ -124,12 +126,12 @@ def _get_mlflow_model_path_type(mlflow_model_path: str) -> str:
124126
Returns:
125127
str: Description of what the input string is identified as.
126128
"""
127-
mlflow_rub_id_pattern = MLFLOW_RUN_ID_REGEX
129+
mlflow_run_id_pattern = MLFLOW_RUN_ID_REGEX
128130
mlflow_registry_id_pattern = MLFLOW_REGISTRY_PATH_REGEX
129131
sagemaker_arn_pattern = MODEL_PACKAGE_ARN_REGEX
130132
s3_pattern = S3_PATH_REGEX
131133

132-
if re.match(mlflow_rub_id_pattern, mlflow_model_path):
134+
if re.match(mlflow_run_id_pattern, mlflow_model_path):
133135
return MLFLOW_RUN_ID
134136
if re.match(mlflow_registry_id_pattern, mlflow_model_path):
135137
return MLFLOW_REGISTRY_PATH
@@ -146,12 +148,14 @@ def _get_mlflow_model_path_type(mlflow_model_path: str) -> str:
146148
def _create_mlflow_model_path_lineage_artifact(
147149
mlflow_model_path: str,
148150
sagemaker_session: Session,
151+
source_types_to_match: Optional[List[str]] = None,
149152
) -> Optional[Artifact]:
150153
"""Creates a lineage artifact for the given MLflow model path.
151154
152155
Args:
153156
mlflow_model_path (str): The path to the MLflow model.
154157
sagemaker_session (Session): The SageMaker session object.
158+
source_types_to_match (Optional[List[str]]): Artifact source types.
155159
156160
Returns:
157161
Optional[Artifact]: The created lineage artifact, or None if an error occurred.
@@ -161,8 +165,17 @@ def _create_mlflow_model_path_lineage_artifact(
161165
model_builder_input_model_data_type=_artifact_name,
162166
)
163167
try:
168+
source_types = [dict(SourceIdType="Custom", Value="ModelBuilderInputModelData")]
169+
if source_types_to_match:
170+
source_types += [
171+
dict(SourceIdType="Custom", Value=source_type)
172+
for source_type in source_types_to_match
173+
if source_type != "ModelBuilderInputModelData"
174+
]
175+
164176
return Artifact.create(
165177
source_uri=mlflow_model_path,
178+
source_types=source_types,
166179
artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
167180
artifact_name=_artifact_name,
168181
properties=properties,
@@ -178,37 +191,38 @@ def _create_mlflow_model_path_lineage_artifact(
178191

179192
def _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
180193
mlflow_model_path: str,
181-
tracking_server_arn: str,
182194
sagemaker_session: Session,
195+
tracking_server_arn: Optional[str] = None,
183196
) -> Optional[Union[Artifact, ArtifactSummary]]:
184197
"""Retrieves an existing artifact for the given MLflow model path or
185198
186199
creates a new one if it doesn't exist.
187200
188201
Args:
189202
mlflow_model_path (str): The path to the MLflow model.
190-
tracking_server_arn (Optional[str]): The MLflow tracking server ARN.
191203
sagemaker_session (Session): Session object which manages interactions
192204
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
193205
function creates one using the default AWS configuration chain.
194-
206+
tracking_server_arn (Optional[str]): The MLflow tracking server ARN.
195207
196208
Returns:
197209
Optional[Union[Artifact, ArtifactSummary]]: The existing or newly created artifact,
198210
or None if an error occurred.
199211
"""
200-
match = re.match(TRACKING_SERVER_ARN_REGEX, tracking_server_arn)
201-
mlflow_tracking_server_name = match.group(4)
202-
describe_result = sagemaker_session.sagemaker_client.describe_mlflow_tracking_server(
203-
MlflowTrackingServerName=mlflow_tracking_server_name
204-
)
205-
tracking_server_creation_time = describe_result["CreationTime"].strftime(
206-
TRACKING_SERVER_CREATION_TIME_FORMAT
207-
)
208-
source_types_to_match = [tracking_server_arn, tracking_server_creation_time]
212+
source_types_to_match = ["ModelBuilderInputModelData"]
213+
input_type = _get_mlflow_model_path_type(mlflow_model_path)
214+
if tracking_server_arn and input_type in [MLFLOW_RUN_ID, MLFLOW_REGISTRY_PATH]:
215+
match = re.match(TRACKING_SERVER_ARN_REGEX, tracking_server_arn)
216+
mlflow_tracking_server_name = match.group(4)
217+
describe_result = sagemaker_session.sagemaker_client.describe_mlflow_tracking_server(
218+
MlflowTrackingServerName=mlflow_tracking_server_name
219+
)
220+
tracking_server_creation_time = describe_result["CreationTime"].strftime(
221+
TRACKING_SERVER_CREATION_TIME_FORMAT
222+
)
223+
source_types_to_match += [tracking_server_arn, tracking_server_creation_time]
209224
_loaded_artifact = _load_artifact_by_source_uri(
210225
mlflow_model_path,
211-
MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
212226
sagemaker_session,
213227
source_types_to_match,
214228
)
@@ -217,6 +231,7 @@ def _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
217231
return _create_mlflow_model_path_lineage_artifact(
218232
mlflow_model_path,
219233
sagemaker_session,
234+
source_types_to_match,
220235
)
221236

222237

@@ -261,18 +276,18 @@ def _add_association_between_artifacts(
261276
def _maintain_lineage_tracking_for_mlflow_model(
262277
mlflow_model_path: str,
263278
s3_upload_path: str,
264-
tracking_server_arn: Optional[str],
265279
sagemaker_session: Session,
280+
tracking_server_arn: Optional[str] = None,
266281
) -> None:
267282
"""Maintains lineage tracking for an MLflow model by creating or retrieving artifacts.
268283
269284
Args:
270285
mlflow_model_path (str): The path to the MLflow model.
271286
s3_upload_path (str): The S3 path where the transformed model data is uploaded.
272-
tracking_server_arn (Optional[str]): The MLflow tracking server ARN.
273287
sagemaker_session (Session): Session object which manages interactions
274288
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
275289
function creates one using the default AWS configuration chain.
290+
tracking_server_arn (Optional[str]): The MLflow tracking server ARN.
276291
"""
277292
artifact_for_transformed_model_data = _poll_lineage_artifact(
278293
s3_uri=s3_upload_path,
@@ -283,8 +298,8 @@ def _maintain_lineage_tracking_for_mlflow_model(
283298
mlflow_model_artifact = (
284299
_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
285300
mlflow_model_path=mlflow_model_path,
286-
tracking_server_arn=tracking_server_arn,
287301
sagemaker_session=sagemaker_session,
302+
tracking_server_arn=tracking_server_arn,
288303
)
289304
)
290305
if mlflow_model_artifact:

tests/unit/sagemaker/serve/utils/test_lineage_utils.py

+14-10
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ def test_load_artifact_by_source_uri(mock_artifact_list):
5757
mock_artifact_list.return_value = mock_artifacts
5858

5959
result = _load_artifact_by_source_uri(
60-
source_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session
60+
source_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value
6161
)
6262

6363
mock_artifact_list.assert_called_once_with(
@@ -79,7 +79,7 @@ def test_load_artifact_by_source_uri_no_match(mock_artifact_list):
7979
mock_artifact_list.return_value = mock_artifacts
8080

8181
result = _load_artifact_by_source_uri(
82-
source_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session
82+
source_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value
8383
)
8484

8585
mock_artifact_list.assert_called_once_with(
@@ -106,7 +106,7 @@ def test_poll_lineage_artifact_found(mock_load_artifact):
106106
assert result == mock_artifact
107107
mock_load_artifact.assert_has_calls(
108108
[
109-
call(s3_uri, LineageSourceEnum.MODEL_DATA.value, sagemaker_session),
109+
call(s3_uri, sagemaker_session, artifact_type=LineageSourceEnum.MODEL_DATA.value),
110110
]
111111
)
112112

@@ -166,6 +166,7 @@ def test_create_mlflow_model_path_lineage_artifact_success(
166166
mock_artifact_create, mock_get_mlflow_path_type
167167
):
168168
mlflow_model_path = "runs:/Ab12Cd34/my-model"
169+
mock_source_types = [dict(SourceIdType="Custom", Value="ModelBuilderInputModelData")]
169170
sagemaker_session = Mock(spec=Session)
170171
mock_artifact = Mock(spec=Artifact)
171172
mock_get_mlflow_path_type.return_value = "mlflow_run_id"
@@ -177,6 +178,7 @@ def test_create_mlflow_model_path_lineage_artifact_success(
177178
mock_get_mlflow_path_type.assert_called_once_with(mlflow_model_path)
178179
mock_artifact_create.assert_called_once_with(
179180
source_uri=mlflow_model_path,
181+
source_types=mock_source_types,
180182
artifact_type=MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
181183
artifact_name="mlflow_run_id",
182184
properties={"model_builder_input_model_data_type": "mlflow_run_id"},
@@ -233,20 +235,20 @@ def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_exi
233235
mock_sagemaker_client.describe_mlflow_tracking_server.return_value = mock_describe_response
234236
sagemaker_session.sagemaker_client = mock_sagemaker_client
235237
mock_source_types_to_match = [
238+
"ModelBuilderInputModelData",
236239
mock_tracking_server_arn,
237240
mock_creation_time.strftime(TRACKING_SERVER_CREATION_TIME_FORMAT),
238241
]
239242
mock_artifact_summary = Mock(spec=ArtifactSummary)
240243
mock_load_artifact.return_value = mock_artifact_summary
241244

242245
result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
243-
mlflow_model_path, mock_tracking_server_arn, sagemaker_session
246+
mlflow_model_path, sagemaker_session, mock_tracking_server_arn
244247
)
245248

246249
assert result == mock_artifact_summary
247250
mock_load_artifact.assert_called_once_with(
248251
mlflow_model_path,
249-
MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
250252
sagemaker_session,
251253
mock_source_types_to_match,
252254
)
@@ -269,6 +271,7 @@ def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_cre
269271
mock_sagemaker_client.describe_mlflow_tracking_server.return_value = mock_describe_response
270272
sagemaker_session.sagemaker_client = mock_sagemaker_client
271273
mock_source_types_to_match = [
274+
"ModelBuilderInputModelData",
272275
mock_tracking_server_arn,
273276
mock_creation_time.strftime(TRACKING_SERVER_CREATION_TIME_FORMAT),
274277
]
@@ -277,17 +280,18 @@ def test_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact_cre
277280
mock_create_artifact.return_value = mock_artifact
278281

279282
result = _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
280-
mlflow_model_path, mock_tracking_server_arn, sagemaker_session
283+
mlflow_model_path, sagemaker_session, mock_tracking_server_arn
281284
)
282285

283286
assert result == mock_artifact
284287
mock_load_artifact.assert_called_once_with(
285288
mlflow_model_path,
286-
MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE,
287289
sagemaker_session,
288290
mock_source_types_to_match,
289291
)
290-
mock_create_artifact.assert_called_once_with(mlflow_model_path, sagemaker_session)
292+
mock_create_artifact.assert_called_once_with(
293+
mlflow_model_path, sagemaker_session, mock_source_types_to_match
294+
)
291295

292296

293297
@patch("sagemaker.lineage.association.Association.create")
@@ -364,7 +368,7 @@ def test_maintain_lineage_tracking_for_mlflow_model_success(
364368
mock_retrieve_create_artifact.return_value = mock_mlflow_model_artifact
365369

366370
_maintain_lineage_tracking_for_mlflow_model(
367-
mlflow_model_path, s3_upload_path, mock_tracking_server_arn, sagemaker_session
371+
mlflow_model_path, s3_upload_path, sagemaker_session, mock_tracking_server_arn
368372
)
369373

370374
mock_poll_artifact.assert_called_once_with(
@@ -402,7 +406,7 @@ def test_maintain_lineage_tracking_for_mlflow_model_no_model_data_artifact(
402406
mock_retrieve_create_artifact.return_value = None
403407

404408
_maintain_lineage_tracking_for_mlflow_model(
405-
mlflow_model_path, s3_upload_path, mock_tracking_server_arn, sagemaker_session
409+
mlflow_model_path, s3_upload_path, sagemaker_session, mock_tracking_server_arn
406410
)
407411

408412
mock_poll_artifact.assert_called_once_with(

0 commit comments

Comments
 (0)