Skip to content

Commit d524416

Browse files
author
Yi-Ting Lee
committed
startarn added to lineageQueryResult
1 parent 7359b3d commit d524416

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/sagemaker/lineage/query.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from datetime import datetime
1717
from enum import Enum
18+
from tracemalloc import start
1819
from typing import Optional, Union, List, Dict
1920

2021
from sagemaker.lineage._utils import get_resource_name_from_arn
@@ -208,6 +209,7 @@ def __init__(
208209
self,
209210
edges: List[Edge] = None,
210211
vertices: List[Vertex] = None,
212+
startarn: List[str] = None,
211213
):
212214
"""Init for LineageQueryResult.
213215
@@ -217,13 +219,17 @@ def __init__(
217219
"""
218220
self.edges = []
219221
self.vertices = []
222+
self.startarn = []
220223

221224
if edges is not None:
222225
self.edges = edges
223226

224227
if vertices is not None:
225228
self.vertices = vertices
226229

230+
if startarn is not None:
231+
self.startarn = startarn
232+
227233
def __str__(self):
228234
"""Define string representation of ``LineageQueryResult``.
229235
@@ -248,7 +254,7 @@ def __str__(self):
248254
249255
"""
250256
result_dict = vars(self)
251-
return str({k: [vars(val) for val in v] for k, v in result_dict.items()})
257+
return str({k: [str(val) for val in v] for k, v in result_dict.items()})
252258

253259
def _import_visual_modules(self):
254260
"""Import modules needed for visualization."""
@@ -417,9 +423,8 @@ def _get_vertex(self, vertex):
417423
sagemaker_session=self._session,
418424
)
419425

420-
def _convert_api_response(self, response) -> LineageQueryResult:
426+
def _convert_api_response(self, response, converted) -> LineageQueryResult:
421427
"""Convert the lineage query API response to its Python representation."""
422-
converted = LineageQueryResult()
423428
converted.edges = [self._get_edge(edge) for edge in response["Edges"]]
424429
converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]]
425430

@@ -502,7 +507,9 @@ def query(
502507
Filters=query_filter._to_request_dict() if query_filter else {},
503508
MaxDepth=max_depth,
504509
)
505-
query_response = self._convert_api_response(query_response)
510+
# create query result for startarn info
511+
query_result = LineageQueryResult(startarn=start_arns)
512+
query_response = self._convert_api_response(query_response, query_result)
506513
query_response = self._collapse_cross_account_artifacts(query_response)
507514

508515
return query_response

0 commit comments

Comments
 (0)