From e363696e969e96646291cfe7f5905a61e70439cf Mon Sep 17 00:00:00 2001 From: Sachin Mysore Satish Date: Thu, 4 Mar 2021 17:14:44 -0800 Subject: [PATCH] Add Type annotations for lineage --- src/sagemaker/lineage/_utils.py | 4 +- src/sagemaker/lineage/action.py | 90 +++++++++--------- src/sagemaker/lineage/artifact.py | 136 +++++++++++++++------------ src/sagemaker/lineage/association.py | 46 ++++----- src/sagemaker/lineage/context.py | 81 ++++++++-------- src/sagemaker/lineage/visualizer.py | 114 +++++++++++++--------- 6 files changed, 263 insertions(+), 208 deletions(-) diff --git a/src/sagemaker/lineage/_utils.py b/src/sagemaker/lineage/_utils.py index c71d38f209..78be1a66e9 100644 --- a/src/sagemaker/lineage/_utils.py +++ b/src/sagemaker/lineage/_utils.py @@ -23,7 +23,9 @@ def _disassociate(source_arn=None, destination_arn=None, sagemaker_session=None) destination_arn is provided. """ association_summaries = association.Association.list( - source_arn=source_arn, destination_arn=destination_arn, sagemaker_session=sagemaker_session + source_arn=source_arn, + destination_arn=destination_arn, + sagemaker_session=sagemaker_session, ) for association_summary in association_summaries: diff --git a/src/sagemaker/lineage/action.py b/src/sagemaker/lineage/action.py index a8fef655a2..67ba6d5db0 100644 --- a/src/sagemaker/lineage/action.py +++ b/src/sagemaker/lineage/action.py @@ -13,8 +13,13 @@ """This module contains code to create and manage SageMaker ``Actions``.""" from __future__ import absolute_import +from typing import Optional, Iterator +from datetime import datetime + +from sagemaker import Session from sagemaker.apiutils import _base_types from sagemaker.lineage import _api_types, _utils +from sagemaker.lineage._api_types import ActionSource, ActionSummary class Action(_base_types.Record): @@ -53,24 +58,24 @@ class Action(_base_types.Record): last_modified_by (obj): Contextual info on which account created the action. """ - action_arn = None - action_name = None - action_type = None - description = None - status = None - source = None - properties = None - properties_to_remove = None - tags = None - creation_time = None - created_by = None - last_modified_time = None - last_modified_by = None - - _boto_create_method = "create_action" - _boto_load_method = "describe_action" - _boto_update_method = "update_action" - _boto_delete_method = "delete_action" + action_arn: str = None + action_name: str = None + action_type: str = None + description: str = None + status: str = None + source: ActionSource = None + properties: dict = None + properties_to_remove: list = None + tags: list = None + creation_time: datetime = None + created_by: str = None + last_modified_time: datetime = None + last_modified_by: str = None + + _boto_create_method: str = "create_action" + _boto_load_method: str = "describe_action" + _boto_update_method: str = "update_action" + _boto_delete_method: str = "delete_action" _boto_update_members = [ "action_name", @@ -84,7 +89,7 @@ class Action(_base_types.Record): _custom_boto_types = {"source": (_api_types.ActionSource, False)} - def save(self): + def save(self) -> "Action": """Save the state of this Action to SageMaker. Returns: @@ -92,7 +97,7 @@ def save(self): """ return self._invoke_api(self._boto_update_method, self._boto_update_members) - def delete(self, disassociate=False): + def delete(self, disassociate: bool = False): """Delete the action. Args: @@ -104,13 +109,14 @@ def delete(self, disassociate=False): source_arn=self.action_arn, sagemaker_session=self.sagemaker_session ) _utils._disassociate( - destination_arn=self.action_arn, sagemaker_session=self.sagemaker_session + destination_arn=self.action_arn, + sagemaker_session=self.sagemaker_session, ) self._invoke_api(self._boto_delete_method, self._boto_delete_members) @classmethod - def load(cls, action_name, sagemaker_session=None): + def load(cls, action_name: str, sagemaker_session: Session = None) -> "Action": """Load an existing action and return an ``Action`` object representing it. Args: @@ -154,16 +160,16 @@ def set_tags(self, tags=None): @classmethod def create( cls, - action_name=None, - source_uri=None, - source_type=None, - action_type=None, - description=None, - status=None, - properties=None, - tags=None, - sagemaker_session=None, - ): + action_name: str = None, + source_uri: str = None, + source_type: str = None, + action_type: str = None, + description: str = None, + status: str = None, + properties: dict = None, + tags: dict = None, + sagemaker_session: Session = None, + ) -> "Action": """Create an action and return an ``Action`` object representing it. Args: @@ -198,16 +204,16 @@ def create( @classmethod def list( cls, - source_uri=None, - action_type=None, - created_after=None, - created_before=None, - sort_by=None, - sort_order=None, - sagemaker_session=None, - max_results=None, - next_token=None, - ): + source_uri: Optional[str] = None, + action_type: Optional[str] = None, + created_after: Optional[datetime] = None, + created_before: Optional[datetime] = None, + sort_by: Optional[str] = None, + sort_order: Optional[str] = None, + sagemaker_session: Session = None, + max_results: Optional[int] = None, + next_token: Optional[str] = None, + ) -> Iterator[ActionSummary]: """Return a list of action summaries. Args: diff --git a/src/sagemaker/lineage/artifact.py b/src/sagemaker/lineage/artifact.py index f7d5de1f0b..122d177608 100644 --- a/src/sagemaker/lineage/artifact.py +++ b/src/sagemaker/lineage/artifact.py @@ -16,8 +16,12 @@ import logging import math +from datetime import datetime +from typing import Iterator, Union, Any, Optional + from sagemaker.apiutils import _base_types, _utils from sagemaker.lineage import _api_types +from sagemaker.lineage._api_types import ArtifactSource, ArtifactSummary from sagemaker.lineage._utils import get_module, _disassociate from sagemaker.lineage.association import Association @@ -54,31 +58,38 @@ class Artifact(_base_types.Record): tags (List[dict[str, str]]): A list of tags to associate with the artifact. creation_time (datetime): When the artifact was created. created_by (obj): Contextual info on which account created the artifact. + last_modified_time (datetime): When the artifact was last modified. + last_modified_by (obj): Contextual info on which account created the artifact. """ - artifact_arn = None - artifact_name = None - artifact_type = None - source = None - properties = None - tags = None - creation_time = None - created_by = None - last_modified_time = None - last_modified_by = None - - _boto_create_method = "create_artifact" - _boto_load_method = "describe_artifact" - _boto_update_method = "update_artifact" - _boto_delete_method = "delete_artifact" - - _boto_update_members = ["artifact_arn", "artifact_name", "properties", "properties_to_remove"] + artifact_arn: str = None + artifact_name: str = None + artifact_type: str = None + source: ArtifactSource = None + properties: dict = None + tags: list = None + creation_time: datetime = None + created_by: str = None + last_modified_time: datetime = None + last_modified_by: str = None + + _boto_create_method: str = "create_artifact" + _boto_load_method: str = "describe_artifact" + _boto_update_method: str = "update_artifact" + _boto_delete_method: str = "delete_artifact" + + _boto_update_members = [ + "artifact_arn", + "artifact_name", + "properties", + "properties_to_remove", + ] _boto_delete_members = ["artifact_arn"] _custom_boto_types = {"source": (_api_types.ArtifactSource, False)} - def save(self): + def save(self) -> "Artifact": """Save the state of this Artifact to SageMaker. Note that this method must be run from a SageMaker context such as Studio or a training job @@ -89,7 +100,7 @@ def save(self): """ return self._invoke_api(self._boto_update_method, self._boto_update_members) - def delete(self, disassociate=False): + def delete(self, disassociate: bool = False): """Delete the artifact object. Args: @@ -98,12 +109,13 @@ def delete(self, disassociate=False): if disassociate: _disassociate(source_arn=self.artifact_arn, sagemaker_session=self.sagemaker_session) _disassociate( - destination_arn=self.artifact_arn, sagemaker_session=self.sagemaker_session + destination_arn=self.artifact_arn, + sagemaker_session=self.sagemaker_session, ) self._invoke_api(self._boto_delete_method, self._boto_delete_members) @classmethod - def load(cls, artifact_arn, sagemaker_session=None): + def load(cls, artifact_arn: str, sagemaker_session=None) -> "Artifact": """Load an existing artifact and return an ``Artifact`` object representing it. Args: @@ -123,7 +135,7 @@ def load(cls, artifact_arn, sagemaker_session=None): ) return artifact - def downstream_trials(self, sagemaker_session=None): + def downstream_trials(self, sagemaker_session=None) -> list: """Retrieve all trial runs which that use this artifact. Args: @@ -135,10 +147,10 @@ def downstream_trials(self, sagemaker_session=None): """ # don't specify destination type because for Trial Components it could be one of # SageMaker[TrainingJob|ProcessingJob|TransformJob|ExperimentTrialComponent] - outgoing_associations = Association.list( + outgoing_associations: Iterator = Association.list( source_arn=self.artifact_arn, sagemaker_session=sagemaker_session ) - trial_component_arns = list(map(lambda x: x.destination_arn, outgoing_associations)) + trial_component_arns: list = list(map(lambda x: x.destination_arn, outgoing_associations)) if not trial_component_arns: # no outgoing associations for this artifact @@ -147,25 +159,25 @@ def downstream_trials(self, sagemaker_session=None): get_module("smexperiments") from smexperiments import trial_component, search_expression - max_search_by_arn = 60 + max_search_by_arn: int = 60 num_search_batches = math.ceil(len(trial_component_arns) % max_search_by_arn) - trial_components = [] + trial_components: list = [] sagemaker_session = sagemaker_session or _utils.default_session() sagemaker_client = sagemaker_session.sagemaker_client for i in range(num_search_batches): - start = i * max_search_by_arn - end = start + max_search_by_arn - arn_batch = trial_component_arns[start:end] - se = self._get_search_expression(arn_batch, search_expression) - search_result = trial_component.TrialComponent.search( + start: int = i * max_search_by_arn + end: int = start + max_search_by_arn + arn_batch: list = trial_component_arns[start:end] + se: Any = self._get_search_expression(arn_batch, search_expression) + search_result: Any = trial_component.TrialComponent.search( search_expression=se, sagemaker_boto_client=sagemaker_client ) - trial_components = trial_components + list(search_result) + trial_components: list = trial_components + list(search_result) - trials = set() + trials: set = set() for tc in list(trial_components): for parent in tc.parents: @@ -173,7 +185,7 @@ def downstream_trials(self, sagemaker_session=None): return list(trials) - def _get_search_expression(self, arns, search_expression): + def _get_search_expression(self, arns: list, search_expression: object) -> object: """Convert a set of arns to a search expression. Args: @@ -183,14 +195,14 @@ def _get_search_expression(self, arns, search_expression): Returns: search_expression (obj): Arns converted to a Trial Component search expression. """ - max_arn_per_filter = 3 - num_filters = math.ceil(len(arns) / max_arn_per_filter) - filters = [] + max_arn_per_filter: int = 3 + num_filters: Union[float, int] = math.ceil(len(arns) / max_arn_per_filter) + filters: list = [] for i in range(num_filters): - start = i * max_arn_per_filter - end = i + max_arn_per_filter - batch_arns = arns[start:end] + start: int = i * max_arn_per_filter + end: int = i + max_arn_per_filter + batch_arns: list = arns[start:end] search_filter = search_expression.Filter( name="TrialComponentArn", operator=search_expression.Operator.EQUALS, @@ -230,14 +242,14 @@ def set_tags(self, tags=None): @classmethod def create( cls, - artifact_name=None, - source_uri=None, - source_types=None, - artifact_type=None, - properties=None, - tags=None, + artifact_name: Optional[str] = None, + source_uri: Optional[str] = None, + source_types: Optional[list] = None, + artifact_type: Optional[str] = None, + properties: Optional[dict] = None, + tags: Optional[dict] = None, sagemaker_session=None, - ): + ) -> "Artifact": """Create an artifact and return an ``Artifact`` object representing it. Args: @@ -268,16 +280,16 @@ def create( @classmethod def list( cls, - source_uri=None, - artifact_type=None, - created_before=None, - created_after=None, - sort_by=None, - sort_order=None, - max_results=None, - next_token=None, + source_uri: Optional[str] = None, + artifact_type: Optional[str] = None, + created_before: Optional[datetime] = None, + created_after: Optional[datetime] = None, + sort_by: Optional[str] = None, + sort_order: Optional[str] = None, + max_results: Optional[int] = None, + next_token: Optional[str] = None, sagemaker_session=None, - ): + ) -> Iterator[ArtifactSummary]: """Return a list of artifact summaries. Args: @@ -324,19 +336,19 @@ class ModelArtifact(Artifact): to otherentities. """ - def endpoints(self): + def endpoints(self) -> list: """Given a model artifact, get all associated endpoint context. Returns: [AssociationSummary]: A list of associations repesenting the endpoints using the model. """ - endpoint_development_actions = Association.list( + endpoint_development_actions: Iterator = Association.list( source_arn=self.artifact_arn, destination_type="Action", sagemaker_session=self.sagemaker_session, ) - endpoint_context_list = [ + endpoint_context_list: list = [ endpoint_context_associations for endpoint_development_action in endpoint_development_actions for endpoint_context_associations in Association.list( @@ -355,16 +367,16 @@ class DatasetArtifact(Artifact): connect to related entities. """ - def trained_models(self): + def trained_models(self) -> list: """Given a dataset artifact, get associated trained models. Returns: list(Association): List of Contexts representing model artifacts. """ - trial_components = Association.list( + trial_components: Iterator = Association.list( source_arn=self.artifact_arn, sagemaker_session=self.sagemaker_session ) - result = [] + result: list = [] for trial_component in trial_components: if "experiment-trial-component" in trial_component.destination_arn: models = Association.list( diff --git a/src/sagemaker/lineage/association.py b/src/sagemaker/lineage/association.py index 4470d056c1..a0b1ce7e85 100644 --- a/src/sagemaker/lineage/association.py +++ b/src/sagemaker/lineage/association.py @@ -13,8 +13,12 @@ """This module contains code to create and manage SageMaker ``Artifact``.""" from __future__ import absolute_import +from typing import Optional, Iterator +from datetime import datetime + from sagemaker.apiutils import _base_types from sagemaker.lineage import _api_types +from sagemaker.lineage._api_types import AssociationSummary class Association(_base_types.Record): @@ -43,13 +47,13 @@ class Association(_base_types.Record): association_type (str): the type of the association. """ - source_arn = None - destination_arn = None + source_arn: str = None + destination_arn: str = None - _boto_create_method = "add_association" - _boto_delete_method = "delete_association" + _boto_create_method: str = "add_association" + _boto_delete_method: str = "delete_association" - _custom_boto_types = {} + _custom_boto_types: dict = {} _boto_delete_members = [ "source_arn", @@ -85,11 +89,11 @@ def set_tags(self, tags=None): @classmethod def create( cls, - source_arn, - destination_arn, - association_type=None, + source_arn: str, + destination_arn: str, + association_type: str = None, sagemaker_session=None, - ): + ) -> "Association": """Add an association and return an ``Association`` object representing it. Args: @@ -116,19 +120,19 @@ def create( @classmethod def list( cls, - source_arn=None, - destination_arn=None, - source_type=None, - destination_type=None, - association_type=None, - created_after=None, - created_before=None, - sort_by=None, - sort_order=None, - max_results=None, - next_token=None, + source_arn: str = None, + destination_arn: str = None, + source_type: str = None, + destination_type: str = None, + association_type: str = None, + created_after: Optional[datetime] = None, + created_before: Optional[datetime] = None, + sort_by: Optional[str] = None, + sort_order: Optional[str] = None, + max_results: Optional[int] = None, + next_token: Optional[str] = None, sagemaker_session=None, - ): + ) -> Iterator[AssociationSummary]: """Return a list of context summaries. Args: diff --git a/src/sagemaker/lineage/context.py b/src/sagemaker/lineage/context.py index 03300224c1..f19af66a33 100644 --- a/src/sagemaker/lineage/context.py +++ b/src/sagemaker/lineage/context.py @@ -13,12 +13,16 @@ """This module contains code to create and manage SageMaker ``Context``.""" from __future__ import absolute_import +from datetime import datetime +from typing import Iterator, Optional + from sagemaker.apiutils import _base_types from sagemaker.lineage import ( _api_types, _utils, association, ) +from sagemaker.lineage._api_types import ContextSummary class Context(_base_types.Record): @@ -38,20 +42,20 @@ class Context(_base_types.Record): last_modified_by (obj): Contextual info on which account created the context. """ - context_arn = None - context_name = None - context_type = None - properties = None - tags = None - creation_time = None - created_by = None - last_modified_time = None - last_modified_by = None - - _boto_load_method = "describe_context" - _boto_create_method = "create_context" - _boto_update_method = "update_context" - _boto_delete_method = "delete_context" + context_arn: str = None + context_name: str = None + context_type: str = None + properties: dict = None + tags: list = None + creation_time: datetime = None + created_by: str = None + last_modified_time: datetime = None + last_modified_by: str = None + + _boto_load_method: str = "describe_context" + _boto_create_method: str = "create_context" + _boto_update_method: str = "update_context" + _boto_delete_method: str = "delete_context" _custom_boto_types = { "source": (_api_types.ContextSource, False), @@ -65,7 +69,7 @@ class Context(_base_types.Record): ] _boto_delete_members = ["context_name"] - def save(self): + def save(self) -> "Context": """Save the state of this Context to SageMaker. Returns: @@ -73,7 +77,7 @@ def save(self): """ return self._invoke_api(self._boto_update_method, self._boto_update_members) - def delete(self, disassociate=False): + def delete(self, disassociate: bool = False): """Delete the context object. Args: @@ -87,7 +91,8 @@ def delete(self, disassociate=False): source_arn=self.context_arn, sagemaker_session=self.sagemaker_session ) _utils._disassociate( - destination_arn=self.context_arn, sagemaker_session=self.sagemaker_session + destination_arn=self.context_arn, + sagemaker_session=self.sagemaker_session, ) return self._invoke_api(self._boto_delete_method, self._boto_delete_members) @@ -114,7 +119,7 @@ def set_tags(self, tags=None): return self._set_tags(resource_arn=self.context_arn, tags=tags) @classmethod - def load(cls, context_name, sagemaker_session=None): + def load(cls, context_name: str, sagemaker_session=None) -> "Context": """Load an existing context and return an ``Context`` object representing it. Examples: @@ -155,15 +160,15 @@ def load(cls, context_name, sagemaker_session=None): @classmethod def create( cls, - context_name=None, - source_uri=None, - source_type=None, - context_type=None, - description=None, - properties=None, - tags=None, + context_name: str = None, + source_uri: str = None, + source_type: str = None, + context_type: str = None, + description: str = None, + properties: dict = None, + tags: dict = None, sagemaker_session=None, - ): + ) -> "Context": """Create a context and return a ``Context`` object representing it. Args: @@ -196,16 +201,16 @@ def create( @classmethod def list( cls, - source_uri=None, - context_type=None, - created_after=None, - created_before=None, - sort_by=None, - sort_order=None, - max_results=None, - next_token=None, + source_uri: Optional[str] = None, + context_type: Optional[str] = None, + created_after: Optional[datetime] = None, + created_before: Optional[datetime] = None, + sort_by: Optional[str] = None, + sort_order: Optional[str] = None, + max_results: Optional[int] = None, + next_token: Optional[str] = None, sagemaker_session=None, - ): + ) -> Iterator[ContextSummary]: """Return a list of context summaries. Args: @@ -247,19 +252,19 @@ def list( class EndpointContext(Context): """An Amazon SageMaker endpoint context, which is part of a SageMaker lineage.""" - def models(self): + def models(self) -> list: """Get all models deployed by all endpoint versions of the endpoint. Returns: list of Associations: Associations that destination represents an endpoint's model. """ - endpoint_actions = association.Association.list( + endpoint_actions: Iterator = association.Association.list( sagemaker_session=self.sagemaker_session, source_arn=self.context_arn, destination_type="ModelDeployment", ) - model_list = [ + model_list: list = [ model for endpoint_action in endpoint_actions for model in association.Association.list( diff --git a/src/sagemaker/lineage/visualizer.py b/src/sagemaker/lineage/visualizer.py index e3faeaa491..c2bff2946e 100644 --- a/src/sagemaker/lineage/visualizer.py +++ b/src/sagemaker/lineage/visualizer.py @@ -12,9 +12,15 @@ # language governing permissions and limitations under the License. """This module contains functionality to display lineage data.""" from __future__ import absolute_import + import logging + +from typing import Optional, Any, Iterator + import pandas as pd +from pandas import DataFrame +from sagemaker.lineage._api_types import AssociationSummary from sagemaker.lineage.association import Association @@ -31,16 +37,16 @@ def __init__(self, sagemaker_session): def show( self, - trial_component_name=None, - training_job_name=None, - processing_job_name=None, - pipeline_execution_step=None, - model_package_arn=None, - endpoint_arn=None, - artifact_arn=None, - context_arn=None, - actions_arn=None, - ): + trial_component_name: Optional[str] = None, + training_job_name: Optional[str] = None, + processing_job_name: Optional[str] = None, + pipeline_execution_step: Optional[object] = None, + model_package_arn: Optional[str] = None, + endpoint_arn: Optional[str] = None, + artifact_arn: Optional[str] = None, + context_arn: Optional[str] = None, + actions_arn: Optional[str] = None, + ) -> DataFrame: """Generate a dataframe containing all incoming and outgoing lineage entities. Examples: @@ -65,7 +71,7 @@ def show( Returns: DataFrame: Pandas dataframe containing lineage associations. """ - start_arn = None + start_arn: str = None if trial_component_name: start_arn = self._get_start_arn_from_trial_component_name(trial_component_name) @@ -90,7 +96,7 @@ def show( return self._get_associations_dataframe(start_arn) - def _get_start_arn_from_pipeline_execution_step(self, pipeline_execution_step): + def _get_start_arn_from_pipeline_execution_step(self, pipeline_execution_step: object) -> str: """Given a pipeline exection step retrieve the arn of the lineage entity that represents it. Args: @@ -99,13 +105,14 @@ def _get_start_arn_from_pipeline_execution_step(self, pipeline_execution_step): Returns: str: The arn of the lineage entity """ - start_arn = None + start_arn: str = None if not pipeline_execution_step["Metadata"]: return None - metadata = pipeline_execution_step["Metadata"] - jobs = ["TrainingJob", "ProcessingJob", "TransformJob"] + metadata: Any = pipeline_execution_step["Metadata"] + jobs: list = ["TrainingJob", "ProcessingJob", "TransformJob"] + for job in jobs: if job in metadata and metadata[job]: job_arn = metadata[job]["Arn"] @@ -117,7 +124,7 @@ def _get_start_arn_from_pipeline_execution_step(self, pipeline_execution_step): return start_arn - def _get_start_arn_from_job_arn(self, job_arn): + def _get_start_arn_from_job_arn(self, job_arn: str) -> str: """Given a job arn return the lineage entity. Args: @@ -126,16 +133,16 @@ def _get_start_arn_from_job_arn(self, job_arn): Returns: str: The arn of the job's lineage entity. """ - start_arn = None - response = self._session.sagemaker_client.list_trial_components(SourceArn=job_arn) - trial_components = response["TrialComponentSummaries"] + start_arn: str = None + response: Any = self._session.sagemaker_client.list_trial_components(SourceArn=job_arn) + trial_components: Any = response["TrialComponentSummaries"] if trial_components: start_arn = trial_components[0]["TrialComponentArn"] else: logging.warning("No trial components found for %s", job_arn) return start_arn - def _get_associations_dataframe(self, arn): + def _get_associations_dataframe(self, arn: str) -> DataFrame: """Create a data frame containing lineage association information. Args: @@ -148,17 +155,25 @@ def _get_associations_dataframe(self, arn): # no associations return None - upstream_associations = self._get_associations(dest_arn=arn) - downstream_associations = self._get_associations(src_arn=arn) - inputs = list(map(self._convert_input_association_to_df_row, upstream_associations)) - outputs = list(map(self._convert_output_association_to_df_row, downstream_associations)) - df = pd.DataFrame( + upstream_associations: Iterator[AssociationSummary] = self._get_associations(dest_arn=arn) + downstream_associations: Iterator[AssociationSummary] = self._get_associations(src_arn=arn) + inputs: list = list(map(self._convert_input_association_to_df_row, upstream_associations)) + outputs: list = list( + map(self._convert_output_association_to_df_row, downstream_associations) + ) + df: DataFrame = pd.DataFrame( inputs + outputs, - columns=["Name/Source", "Direction", "Type", "Association Type", "Lineage Type"], + columns=[ + "Name/Source", + "Direction", + "Type", + "Association Type", + "Lineage Type", + ], ) return df - def _get_start_arn_from_trial_component_name(self, tc_name): + def _get_start_arn_from_trial_component_name(self, tc_name: str) -> str: """Given a trial component name retrieve a start arn. Args: @@ -167,13 +182,13 @@ def _get_start_arn_from_trial_component_name(self, tc_name): Returns: str: The arn of the trial component. """ - response = self._session.sagemaker_client.describe_trial_component( + response: Any = self._session.sagemaker_client.describe_trial_component( TrialComponentName=tc_name ) - tc_arn = response["TrialComponentArn"] + tc_arn: str = response["TrialComponentArn"] return tc_arn - def _get_start_arn_from_model_package_arn(self, model_package_arn): + def _get_start_arn_from_model_package_arn(self, model_package_arn: str) -> str: """Given a model package arn retrieve the arn lineage entity. Args: @@ -182,16 +197,16 @@ def _get_start_arn_from_model_package_arn(self, model_package_arn): Returns: str: The arn of the lineage entity that represents the model package. """ - response = self._session.sagemaker_client.list_artifacts(SourceUri=model_package_arn) - artifacts = response["ArtifactSummaries"] - artifact_arn = None + response: Any = self._session.sagemaker_client.list_artifacts(SourceUri=model_package_arn) + artifacts: Any = response["ArtifactSummaries"] + artifact_arn: str = None if artifacts: artifact_arn = artifacts[0]["ArtifactArn"] else: logging.debug("No artifacts found for %s.", model_package_arn) return artifact_arn - def _get_start_arn_from_endpoint_arn(self, endpoint_arn): + def _get_start_arn_from_endpoint_arn(self, endpoint_arn: str) -> str: """Given an endpoint arn retrieve the arn of the lineage entity. Args: @@ -200,16 +215,18 @@ def _get_start_arn_from_endpoint_arn(self, endpoint_arn): Returns: str: The arn of the lineage entity that represents the model package. """ - response = self._session.sagemaker_client.list_contexts(SourceUri=endpoint_arn) - contexts = response["ContextSummaries"] - context_arn = None + response: Any = self._session.sagemaker_client.list_contexts(SourceUri=endpoint_arn) + contexts: Any = response["ContextSummaries"] + context_arn: str = None if contexts: context_arn = contexts[0]["ContextArn"] else: logging.debug("No contexts found for %s.", endpoint_arn) return context_arn - def _get_associations(self, src_arn=None, dest_arn=None): + def _get_associations( + self, src_arn: Optional[str] = None, dest_arn: Optional[str] = None + ) -> Iterator[AssociationSummary]: """Given an arn retrieve all associated lineage entities. The arn must be one of: experiment, trial, trial component, artifact, action, or context. @@ -223,14 +240,16 @@ def _get_associations(self, src_arn=None, dest_arn=None): entity of interest. """ if src_arn: - associations = Association.list(source_arn=src_arn, sagemaker_session=self._session) + associations: Iterator[AssociationSummary] = Association.list( + source_arn=src_arn, sagemaker_session=self._session + ) else: - associations = Association.list( + associations: Iterator[AssociationSummary] = Association.list( destination_arn=dest_arn, sagemaker_session=self._session ) return associations - def _convert_input_association_to_df_row(self, association): + def _convert_input_association_to_df_row(self, association) -> list: """Convert an input association to a data frame row. Args: @@ -247,7 +266,7 @@ def _convert_input_association_to_df_row(self, association): association.association_type, ) - def _convert_output_association_to_df_row(self, association): + def _convert_output_association_to_df_row(self, association) -> list: """Convert an output association to a data frame row. Args: @@ -264,7 +283,14 @@ def _convert_output_association_to_df_row(self, association): association.association_type, ) - def _convert_association_to_df_row(self, arn, name, direction, src_dest_type, association_type): + def _convert_association_to_df_row( + self, + arn: str, + name: str, + direction: str, + src_dest_type: str, + association_type: type, + ) -> list: """Convert association data into a data frame row. Args: @@ -284,7 +310,7 @@ def _convert_association_to_df_row(self, arn, name, direction, src_dest_type, as name = self._get_friendly_name(name, arn, entity_type) return [name, direction, src_dest_type, association_type, entity_type] - def _get_friendly_name(self, name, arn, entity_type): + def _get_friendly_name(self, name: str, arn: str, entity_type: str) -> str: """Get a human readable name from the association. Args: