54
54
55
55
def _load_artifact_by_source_uri (
56
56
source_uri : str ,
57
- artifact_type : str ,
58
57
sagemaker_session : Session ,
59
58
source_types_to_match : Optional [List [str ]] = None ,
59
+ artifact_type : Optional [str ] = None ,
60
60
) -> Optional [ArtifactSummary ]:
61
61
"""Load lineage artifact by source uri
62
62
63
63
Arguments:
64
64
source_uri (str): The S3 URI used for uploading transformed model artifacts.
65
- artifact_type (str): The type of the lineage artifact.
66
- source_types_to_match (Optional[List[str]]): A list of source type values to match against
67
- the artifact's source types. If provided, the artifact's source types must match this
68
- list.
69
65
sagemaker_session (Session): Session object which manages interactions
70
66
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
71
67
function creates one using the default AWS configuration chain.
68
+ source_types_to_match (Optional[List[str]]): A list of source type values to match against
69
+ the artifact's source types. If provided, the artifact's source types must match this
70
+ list.
71
+ artifact_type (Optional[str]): The type of the lineage artifact.
72
72
73
73
Returns:
74
74
ArtifactSummary: The Artifact Summary for the provided S3 URI.
75
75
"""
76
76
artifacts = Artifact .list (source_uri = source_uri , sagemaker_session = sagemaker_session )
77
77
for artifact_summary in artifacts :
78
- if artifact_summary .artifact_type == artifact_type :
78
+ if artifact_type is None or artifact_summary .artifact_type == artifact_type :
79
79
if source_types_to_match :
80
80
if artifact_summary .source .source_types is not None :
81
81
artifact_source_types = [
@@ -109,7 +109,9 @@ def _poll_lineage_artifact(
109
109
logger .info ("Polling lineage artifact for model data in %s" , s3_uri )
110
110
start_time = time .time ()
111
111
while time .time () - start_time < LINEAGE_POLLER_MAX_TIMEOUT_SECS :
112
- result = _load_artifact_by_source_uri (s3_uri , artifact_type , sagemaker_session )
112
+ result = _load_artifact_by_source_uri (
113
+ s3_uri , sagemaker_session , artifact_type = artifact_type
114
+ )
113
115
if result is not None :
114
116
return result
115
117
time .sleep (LINEAGE_POLLER_INTERVAL_SECS )
@@ -124,12 +126,12 @@ def _get_mlflow_model_path_type(mlflow_model_path: str) -> str:
124
126
Returns:
125
127
str: Description of what the input string is identified as.
126
128
"""
127
- mlflow_rub_id_pattern = MLFLOW_RUN_ID_REGEX
129
+ mlflow_run_id_pattern = MLFLOW_RUN_ID_REGEX
128
130
mlflow_registry_id_pattern = MLFLOW_REGISTRY_PATH_REGEX
129
131
sagemaker_arn_pattern = MODEL_PACKAGE_ARN_REGEX
130
132
s3_pattern = S3_PATH_REGEX
131
133
132
- if re .match (mlflow_rub_id_pattern , mlflow_model_path ):
134
+ if re .match (mlflow_run_id_pattern , mlflow_model_path ):
133
135
return MLFLOW_RUN_ID
134
136
if re .match (mlflow_registry_id_pattern , mlflow_model_path ):
135
137
return MLFLOW_REGISTRY_PATH
@@ -146,12 +148,14 @@ def _get_mlflow_model_path_type(mlflow_model_path: str) -> str:
146
148
def _create_mlflow_model_path_lineage_artifact (
147
149
mlflow_model_path : str ,
148
150
sagemaker_session : Session ,
151
+ source_types_to_match : Optional [List [str ]] = None ,
149
152
) -> Optional [Artifact ]:
150
153
"""Creates a lineage artifact for the given MLflow model path.
151
154
152
155
Args:
153
156
mlflow_model_path (str): The path to the MLflow model.
154
157
sagemaker_session (Session): The SageMaker session object.
158
+ source_types_to_match (Optional[List[str]]): Additional custom source type values to record on the created artifact.
155
159
156
160
Returns:
157
161
Optional[Artifact]: The created lineage artifact, or None if an error occurred.
@@ -161,8 +165,17 @@ def _create_mlflow_model_path_lineage_artifact(
161
165
model_builder_input_model_data_type = _artifact_name ,
162
166
)
163
167
try :
168
+ source_types = [dict (SourceIdType = "Custom" , Value = "ModelBuilderInputModelData" )]
169
+ if source_types_to_match :
170
+ source_types += [
171
+ dict (SourceIdType = "Custom" , Value = source_type )
172
+ for source_type in source_types_to_match
173
+ if source_type != "ModelBuilderInputModelData"
174
+ ]
175
+
164
176
return Artifact .create (
165
177
source_uri = mlflow_model_path ,
178
+ source_types = source_types ,
166
179
artifact_type = MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE ,
167
180
artifact_name = _artifact_name ,
168
181
properties = properties ,
@@ -178,37 +191,38 @@ def _create_mlflow_model_path_lineage_artifact(
178
191
179
192
def _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact (
180
193
mlflow_model_path : str ,
181
- tracking_server_arn : str ,
182
194
sagemaker_session : Session ,
195
+ tracking_server_arn : Optional [str ] = None ,
183
196
) -> Optional [Union [Artifact , ArtifactSummary ]]:
184
197
"""Retrieves an existing artifact for the given MLflow model path or
185
198
186
199
creates a new one if it doesn't exist.
187
200
188
201
Args:
189
202
mlflow_model_path (str): The path to the MLflow model.
190
- tracking_server_arn (Optional[str]): The MLflow tracking server ARN.
191
203
sagemaker_session (Session): Session object which manages interactions
192
204
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
193
205
function creates one using the default AWS configuration chain.
194
-
206
+ tracking_server_arn (Optional[str]): The MLflow tracking server ARN.
195
207
196
208
Returns:
197
209
Optional[Union[Artifact, ArtifactSummary]]: The existing or newly created artifact,
198
210
or None if an error occurred.
199
211
"""
200
- match = re .match (TRACKING_SERVER_ARN_REGEX , tracking_server_arn )
201
- mlflow_tracking_server_name = match .group (4 )
202
- describe_result = sagemaker_session .sagemaker_client .describe_mlflow_tracking_server (
203
- MlflowTrackingServerName = mlflow_tracking_server_name
204
- )
205
- tracking_server_creation_time = describe_result ["CreationTime" ].strftime (
206
- TRACKING_SERVER_CREATION_TIME_FORMAT
207
- )
208
- source_types_to_match = [tracking_server_arn , tracking_server_creation_time ]
212
+ source_types_to_match = ["ModelBuilderInputModelData" ]
213
+ input_type = _get_mlflow_model_path_type (mlflow_model_path )
214
+ if tracking_server_arn and input_type in [MLFLOW_RUN_ID , MLFLOW_REGISTRY_PATH ]:
215
+ match = re .match (TRACKING_SERVER_ARN_REGEX , tracking_server_arn )
216
+ mlflow_tracking_server_name = match .group (4 )
217
+ describe_result = sagemaker_session .sagemaker_client .describe_mlflow_tracking_server (
218
+ MlflowTrackingServerName = mlflow_tracking_server_name
219
+ )
220
+ tracking_server_creation_time = describe_result ["CreationTime" ].strftime (
221
+ TRACKING_SERVER_CREATION_TIME_FORMAT
222
+ )
223
+ source_types_to_match += [tracking_server_arn , tracking_server_creation_time ]
209
224
_loaded_artifact = _load_artifact_by_source_uri (
210
225
mlflow_model_path ,
211
- MODEL_BUILDER_MLFLOW_MODEL_PATH_LINEAGE_ARTIFACT_TYPE ,
212
226
sagemaker_session ,
213
227
source_types_to_match ,
214
228
)
@@ -217,6 +231,7 @@ def _retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact(
217
231
return _create_mlflow_model_path_lineage_artifact (
218
232
mlflow_model_path ,
219
233
sagemaker_session ,
234
+ source_types_to_match ,
220
235
)
221
236
222
237
@@ -261,18 +276,18 @@ def _add_association_between_artifacts(
261
276
def _maintain_lineage_tracking_for_mlflow_model (
262
277
mlflow_model_path : str ,
263
278
s3_upload_path : str ,
264
- tracking_server_arn : Optional [str ],
265
279
sagemaker_session : Session ,
280
+ tracking_server_arn : Optional [str ] = None ,
266
281
) -> None :
267
282
"""Maintains lineage tracking for an MLflow model by creating or retrieving artifacts.
268
283
269
284
Args:
270
285
mlflow_model_path (str): The path to the MLflow model.
271
286
s3_upload_path (str): The S3 path where the transformed model data is uploaded.
272
- tracking_server_arn (Optional[str]): The MLflow tracking server ARN.
273
287
sagemaker_session (Session): Session object which manages interactions
274
288
with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
275
289
function creates one using the default AWS configuration chain.
290
+ tracking_server_arn (Optional[str]): The MLflow tracking server ARN.
276
291
"""
277
292
artifact_for_transformed_model_data = _poll_lineage_artifact (
278
293
s3_uri = s3_upload_path ,
@@ -283,8 +298,8 @@ def _maintain_lineage_tracking_for_mlflow_model(
283
298
mlflow_model_artifact = (
284
299
_retrieve_and_create_if_not_exist_mlflow_model_path_lineage_artifact (
285
300
mlflow_model_path = mlflow_model_path ,
286
- tracking_server_arn = tracking_server_arn ,
287
301
sagemaker_session = sagemaker_session ,
302
+ tracking_server_arn = tracking_server_arn ,
288
303
)
289
304
)
290
305
if mlflow_model_artifact :
0 commit comments