Skip to content

Commit 9125eb9

Browse files
author
Yi-Ting Lee
committed
startarn added
1 parent 60904d5 commit 9125eb9

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/sagemaker/lineage/query.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from datetime import datetime
1717
from enum import Enum
18+
from tracemalloc import start
1819
from typing import Optional, Union, List, Dict
1920

2021
from sagemaker.lineage._utils import get_resource_name_from_arn
@@ -295,6 +296,7 @@ def __init__(
295296
self,
296297
edges: List[Edge] = None,
297298
vertices: List[Vertex] = None,
299+
startarn: List[str] = None,
298300
):
299301
"""Init for LineageQueryResult.
300302
@@ -304,13 +306,17 @@ def __init__(
304306
"""
305307
self.edges = []
306308
self.vertices = []
309+
self.startarn = []
307310

308311
if edges is not None:
309312
self.edges = edges
310313

311314
if vertices is not None:
312315
self.vertices = vertices
313316

317+
if startarn is not None:
318+
self.startarn = startarn
319+
314320
def __str__(self):
315321
"""Define string representation of ``LineageQueryResult``.
316322
@@ -335,7 +341,7 @@ def __str__(self):
335341
336342
"""
337343
result_dict = vars(self)
338-
return str({k: [vars(val) for val in v] for k, v in result_dict.items()})
344+
return str({k: [str(val) for val in v] for k, v in result_dict.items()})
339345

340346
def _covert_vertices_to_tuples(self):
341347
"""Convert vertices to tuple format for visualizer."""
@@ -456,9 +462,8 @@ def _get_vertex(self, vertex):
456462
sagemaker_session=self._session,
457463
)
458464

459-
def _convert_api_response(self, response) -> LineageQueryResult:
465+
def _convert_api_response(self, response, converted) -> LineageQueryResult:
460466
"""Convert the lineage query API response to its Python representation."""
461-
converted = LineageQueryResult()
462467
converted.edges = [self._get_edge(edge) for edge in response["Edges"]]
463468
converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]]
464469

@@ -541,7 +546,9 @@ def query(
541546
Filters=query_filter._to_request_dict() if query_filter else {},
542547
MaxDepth=max_depth,
543548
)
544-
query_response = self._convert_api_response(query_response)
549+
# create query result for startarn info
550+
query_result = LineageQueryResult(startarn=start_arns)
551+
query_response = self._convert_api_response(query_response, query_result)
545552
query_response = self._collapse_cross_account_artifacts(query_response)
546553

547554
return query_response

0 commit comments

Comments
 (0)