diff --git a/git/refs/symbolic.py b/git/refs/symbolic.py index 22d9c1d51..64a6591aa 100644 --- a/git/refs/symbolic.py +++ b/git/refs/symbolic.py @@ -45,7 +45,7 @@ class SymbolicReference(object): _remote_common_path_default = "refs/remotes" _id_attribute_ = "name" - def __init__(self, repo, path): + def __init__(self, repo, path, check_path=None): self.repo = repo self.path = path diff --git a/git/remote.py b/git/remote.py index 194db9386..2eeafcc41 100644 --- a/git/remote.py +++ b/git/remote.py @@ -36,11 +36,18 @@ # typing------------------------------------------------------- -from typing import TYPE_CHECKING +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union, cast, overload + +from git.types import PathLike, Literal if TYPE_CHECKING: from git.repo.base import Repo + from git.objects.commit import Commit + from git.objects.blob import Blob + from git.objects.tree import Tree + from git.objects.tag import TagObject +flagKeyLiteral = Literal[' ', '!', '+', '-', '*', '=', 't'] # ------------------------------------------------------------- log = logging.getLogger('git.remote') @@ -52,7 +59,7 @@ #{ Utilities -def add_progress(kwargs, git, progress): +def add_progress(kwargs: Any, git: Git, progress: Optional[Callable[..., Any]]) -> Any: """Add the --progress flag to the given kwargs dict if supported by the git command. If the actual progress in the given progress instance is not given, we do not request any progress @@ -68,7 +75,23 @@ def add_progress(kwargs, git, progress): #} END utilities -def to_progress_instance(progress): +@overload +def to_progress_instance(progress: None) -> RemoteProgress: + ... + + +@overload +def to_progress_instance(progress: Callable[..., Any]) -> CallableRemoteProgress: + ... + + +@overload +def to_progress_instance(progress: RemoteProgress) -> RemoteProgress: + ... + + +def to_progress_instance(progress: Union[Callable[..., Any], RemoteProgress, None] + ) -> Union[RemoteProgress, CallableRemoteProgress]: """Given the 'progress' return a suitable object derived from RemoteProgress(). """ @@ -112,9 +135,10 @@ class PushInfo(object): '=': UP_TO_DATE, '!': ERROR} - def __init__(self, flags, local_ref, remote_ref_string, remote, old_commit=None, - summary=''): - """ Initialize a new instance """ + def __init__(self, flags: int, local_ref: Union[SymbolicReference, None], remote_ref_string: str, remote: 'Remote', + old_commit: Optional[str] = None, summary: str = '') -> None: + """ Initialize a new instance + local_ref: HEAD | Head | RemoteReference | TagReference | Reference | SymbolicReference | None """ self.flags = flags self.local_ref = local_ref self.remote_ref_string = remote_ref_string @@ -123,11 +147,11 @@ def __init__(self, flags, local_ref, remote_ref_string, remote, old_commit=None, self.summary = summary @property - def old_commit(self): + def old_commit(self) -> Union[str, SymbolicReference, 'Commit', 'TagObject', 'Blob', 'Tree', None]: return self._old_commit_sha and self._remote.repo.commit(self._old_commit_sha) or None @property - def remote_ref(self): + def remote_ref(self) -> Union[RemoteReference, TagReference]: """ :return: Remote Reference or TagReference in the local repository corresponding @@ -143,7 +167,7 @@ def remote_ref(self): # END @classmethod - def _from_line(cls, remote, line): + def _from_line(cls, remote, line: str) -> 'PushInfo': """Create a new PushInfo instance as parsed from line which is expected to be like refs/heads/master:refs/heads/master 05d2687..1d0568e as bytes""" control_character, from_to, summary = line.split('\t', 3) @@ -159,7 +183,7 @@ def _from_line(cls, remote, line): # from_to handling from_ref_string, to_ref_string = from_to.split(':') if flags & cls.DELETED: - from_ref = None + from_ref = None # type: Union[SymbolicReference, None] else: if from_ref_string == "(delete)": from_ref = None @@ -167,7 +191,7 @@ def _from_line(cls, remote, line): from_ref = Reference.from_path(remote.repo, from_ref_string) # commit handling, could be message or commit info - old_commit = None + old_commit = None # type: Optional[str] if summary.startswith('['): if "[rejected]" in summary: flags |= cls.REJECTED @@ -226,10 +250,10 @@ class FetchInfo(object): '=': HEAD_UPTODATE, ' ': FAST_FORWARD, '-': TAG_UPDATE, - } + } # type: Dict[flagKeyLiteral, int] @classmethod - def refresh(cls): + def refresh(cls) -> Literal[True]: """This gets called by the refresh function (see the top level __init__). """ @@ -252,7 +276,8 @@ def refresh(cls): return True - def __init__(self, ref, flags, note='', old_commit=None, remote_ref_path=None): + def __init__(self, ref: SymbolicReference, flags: int, note: str = '', old_commit: Optional['Commit'] = None, + remote_ref_path: Optional[PathLike] = None) -> None: """ Initialize a new instance """ @@ -262,21 +287,21 @@ def __init__(self, ref, flags, note='', old_commit=None, remote_ref_path=None): self.old_commit = old_commit self.remote_ref_path = remote_ref_path - def __str__(self): + def __str__(self) -> str: return self.name @property - def name(self): + def name(self) -> str: """:return: Name of our remote ref""" return self.ref.name @property - def commit(self): + def commit(self) -> 'Commit': """:return: Commit of our remote ref""" return self.ref.commit @classmethod - def _from_line(cls, repo, line, fetch_line): + def _from_line(cls, repo: 'Repo', line: str, fetch_line: str) -> 'FetchInfo': """Parse information from the given line as returned by git-fetch -v and return a new FetchInfo object representing this information. @@ -298,7 +323,9 @@ def _from_line(cls, repo, line, fetch_line): raise ValueError("Failed to parse line: %r" % line) # parse lines - control_character, operation, local_remote_ref, remote_local_ref, note = match.groups() + control_character, operation, local_remote_ref, remote_local_ref_str, note = match.groups() + control_character = cast(flagKeyLiteral, control_character) # can do this neater once 3.5 dropped + try: _new_hex_sha, _fetch_operation, fetch_note = fetch_line.split("\t") ref_type_name, fetch_note = fetch_note.split(' ', 1) @@ -338,7 +365,7 @@ def _from_line(cls, repo, line, fetch_line): # the fetch result is stored in FETCH_HEAD which destroys the rule we usually # have. In that case we use a symbolic reference which is detached ref_type = None - if remote_local_ref == "FETCH_HEAD": + if remote_local_ref_str == "FETCH_HEAD": ref_type = SymbolicReference elif ref_type_name == "tag" or is_tag_operation: # the ref_type_name can be branch, whereas we are still seeing a tag operation. It happens during @@ -366,21 +393,21 @@ def _from_line(cls, repo, line, fetch_line): # by the 'ref/' prefix. Otherwise even a tag could be in refs/remotes, which is when it will have the # 'tags/' subdirectory in its path. # We don't want to test for actual existence, but try to figure everything out analytically. - ref_path = None - remote_local_ref = remote_local_ref.strip() - if remote_local_ref.startswith(Reference._common_path_default + "/"): + ref_path = None # type: Optional[PathLike] + remote_local_ref_str = remote_local_ref_str.strip() + if remote_local_ref_str.startswith(Reference._common_path_default + "/"): # always use actual type if we get absolute paths # Will always be the case if something is fetched outside of refs/remotes (if its not a tag) - ref_path = remote_local_ref + ref_path = remote_local_ref_str if ref_type is not TagReference and not \ - remote_local_ref.startswith(RemoteReference._common_path_default + "/"): + remote_local_ref_str.startswith(RemoteReference._common_path_default + "/"): ref_type = Reference # END downgrade remote reference - elif ref_type is TagReference and 'tags/' in remote_local_ref: + elif ref_type is TagReference and 'tags/' in remote_local_ref_str: # even though its a tag, it is located in refs/remotes - ref_path = join_path(RemoteReference._common_path_default, remote_local_ref) + ref_path = join_path(RemoteReference._common_path_default, remote_local_ref_str) else: - ref_path = join_path(ref_type._common_path_default, remote_local_ref) + ref_path = join_path(ref_type._common_path_default, remote_local_ref_str) # END obtain refpath # even though the path could be within the git conventions, we make diff --git a/git/types.py b/git/types.py index 3e33ae0c9..40d4f7885 100644 --- a/git/types.py +++ b/git/types.py @@ -6,6 +6,11 @@ import sys from typing import Union, Any +if sys.version_info[:2] >= (3, 8): + from typing import Final, Literal # noqa: F401 +else: + from typing_extensions import Final, Literal # noqa: F401 + TBD = Any