diff --git a/git/index/base.py b/git/index/base.py index 044240602..e2b3f8fa4 100644 --- a/git/index/base.py +++ b/git/index/base.py @@ -568,7 +568,7 @@ def write_tree(self) -> Tree: # note: additional deserialization could be saved if write_tree_from_cache # would return sorted tree entries root_tree = Tree(self.repo, binsha, path='') - root_tree._cache = tree_items + root_tree._cache = tree_items # type: ignore return root_tree def _process_diff_args(self, args: List[Union[str, diff.Diffable, object]] diff --git a/git/index/fun.py b/git/index/fun.py index 3fded3473..ffd109b1f 100644 --- a/git/index/fun.py +++ b/git/index/fun.py @@ -53,7 +53,7 @@ from typing import (Dict, IO, List, Sequence, TYPE_CHECKING, Tuple, Type, Union, cast) -from git.types import PathLike +from git.types import PathLike, TypeGuard if TYPE_CHECKING: from .base import IndexFile @@ -185,11 +185,17 @@ def read_header(stream: IO[bytes]) -> Tuple[int, int]: def entry_key(*entry: Union[BaseIndexEntry, PathLike, int]) -> Tuple[PathLike, int]: """:return: Key suitable to be used for the index.entries dictionary :param entry: One instance of type BaseIndexEntry or the path and the stage""" + + def is_entry_tuple(entry: Tuple) -> TypeGuard[Tuple[PathLike, int]]: + return isinstance(entry, tuple) and len(entry) == 2 + if len(entry) == 1: - entry_first = cast(BaseIndexEntry, entry[0]) # type: BaseIndexEntry + entry_first = entry[0] + assert isinstance(entry_first, BaseIndexEntry) return (entry_first.path, entry_first.stage) else: - entry = cast(Tuple[PathLike, int], tuple(entry)) + # entry = tuple(entry) + assert is_entry_tuple(entry) return entry # END handle entry @@ -293,7 +299,7 @@ def write_tree_from_cache(entries: List[IndexEntry], odb, sl: slice, si: int = 0 # finally create the tree sio = BytesIO() tree_to_stream(tree_items, sio.write) # converts bytes of each item[0] to str - tree_items_stringified = cast(List[Tuple[str, int, str]], tree_items) # type: List[Tuple[str, int, str]] + tree_items_stringified = cast(List[Tuple[str, int, str]], tree_items) sio.seek(0) istream = odb.store(IStream(str_tree_type, len(sio.getvalue()), sio)) diff --git a/git/objects/commit.py b/git/objects/commit.py index 26db6e36d..0b707450c 100644 --- a/git/objects/commit.py +++ b/git/objects/commit.py @@ -8,7 +8,7 @@ from git.util import ( hex_to_bin, Actor, - Iterable, + IterableObj, Stats, finalize_process ) @@ -47,7 +47,7 @@ __all__ = ('Commit', ) -class Commit(base.Object, Iterable, Diffable, Traversable, Serializable): +class Commit(base.Object, IterableObj, Diffable, Traversable, Serializable): """Wraps a git Commit object. diff --git a/git/objects/fun.py b/git/objects/fun.py index 9b36712e1..339a53b8c 100644 --- a/git/objects/fun.py +++ b/git/objects/fun.py @@ -1,10 +1,19 @@ """Module with functions which are supposed to be as fast as possible""" from stat import S_ISDIR + from git.compat import ( safe_decode, defenc ) +# typing ---------------------------------------------- + +from typing import List, Tuple + + +# --------------------------------------------------- + + __all__ = ('tree_to_stream', 'tree_entries_from_data', 'traverse_trees_recursive', 'traverse_tree_recursive') @@ -38,7 +47,7 @@ def tree_to_stream(entries, write): # END for each item -def tree_entries_from_data(data): +def tree_entries_from_data(data: bytes) -> List[Tuple[bytes, int, str]]: """Reads the binary representation of a tree and returns tuples of Tree items :param data: data block with tree data (as bytes) :return: list(tuple(binsha, mode, tree_relative_path), ...)""" @@ -72,8 +81,8 @@ def tree_entries_from_data(data): # default encoding for strings in git is utf8 # Only use the respective unicode object if the byte stream was encoded - name = data[ns:i] - name = safe_decode(name) + name_bytes = data[ns:i] + name = safe_decode(name_bytes) # byte is NULL, get next 20 i += 1 diff --git a/git/objects/submodule/base.py b/git/objects/submodule/base.py index 8cf4dd1eb..cbf6cd0db 100644 --- a/git/objects/submodule/base.py +++ b/git/objects/submodule/base.py @@ -3,7 +3,6 @@ import logging import os import stat -from typing import List from unittest import SkipTest import uuid @@ -27,12 +26,13 @@ from git.objects.base import IndexObject, Object from git.objects.util import Traversable from git.util import ( - Iterable, + IterableObj, join_path_native, to_native_path_linux, RemoteProgress, rmtree, - unbare_repo + unbare_repo, + IterableList ) from git.util import HIDE_WINDOWS_KNOWN_ERRORS @@ -47,6 +47,11 @@ ) +# typing ---------------------------------------------------------------------- + + +# ----------------------------------------------------------------------------- + __all__ = ["Submodule", "UpdateProgress"] @@ -74,7 +79,7 @@ class UpdateProgress(RemoteProgress): # IndexObject comes via util module, its a 'hacky' fix thanks to pythons import # mechanism which cause plenty of trouble of the only reason for packages and # modules is refactoring - subpackages shouldn't depend on parent packages -class Submodule(IndexObject, Iterable, Traversable): +class Submodule(IndexObject, IterableObj, Traversable): """Implements access to a git submodule. They are special in that their sha represents a commit in the submodule's repository which is to be checked out @@ -136,12 +141,12 @@ def _set_cache_(self, attr): # END handle attribute name @classmethod - def _get_intermediate_items(cls, item: 'Submodule') -> List['Submodule']: # type: ignore + def _get_intermediate_items(cls, item: 'Submodule') -> IterableList['Submodule']: """:return: all the submodules of our module repository""" try: return cls.list_items(item.module()) except InvalidGitRepositoryError: - return [] + return IterableList('') # END handle intermediate items @classmethod @@ -1153,7 +1158,7 @@ def name(self): """ return self._name - def config_reader(self): + def config_reader(self) -> SectionConstraint: """ :return: ConfigReader instance which allows you to qurey the configuration values of this submodule, as provided by the .gitmodules file @@ -1163,7 +1168,7 @@ def config_reader(self): :raise IOError: If the .gitmodules file/blob could not be read""" return self._config_parser_constrained(read_only=True) - def children(self): + def children(self) -> IterableList['Submodule']: """ :return: IterableList(Submodule, ...) an iterable list of submodules instances which are children of this submodule or 0 if the submodule is not checked out""" diff --git a/git/objects/submodule/util.py b/git/objects/submodule/util.py index 0b4ce3c53..b4796b300 100644 --- a/git/objects/submodule/util.py +++ b/git/objects/submodule/util.py @@ -4,6 +4,11 @@ from io import BytesIO import weakref +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .base import Submodule + __all__ = ('sm_section', 'sm_name', 'mkhead', 'find_first_remote_branch', 'SubmoduleConfigParser') @@ -60,12 +65,12 @@ def __init__(self, *args, **kwargs): super(SubmoduleConfigParser, self).__init__(*args, **kwargs) #{ Interface - def set_submodule(self, submodule): + def set_submodule(self, submodule: 'Submodule') -> None: """Set this instance's submodule. It must be called before the first write operation begins""" self._smref = weakref.ref(submodule) - def flush_to_index(self): + def flush_to_index(self) -> None: """Flush changes in our configuration file to the index""" assert self._smref is not None # should always have a file here diff --git a/git/objects/tree.py b/git/objects/tree.py index 29b2a6846..191fe27c3 100644 --- a/git/objects/tree.py +++ b/git/objects/tree.py @@ -20,21 +20,27 @@ # typing ------------------------------------------------- -from typing import Iterable, Iterator, Tuple, Union, cast, TYPE_CHECKING +from typing import Callable, Dict, Generic, Iterable, Iterator, List, Tuple, Type, TypeVar, Union, cast, TYPE_CHECKING + +from git.types import PathLike if TYPE_CHECKING: + from git.repo import Repo from io import BytesIO #-------------------------------------------------------- -cmp = lambda a, b: (a > b) - (a < b) +cmp: Callable[[str, str], int] = lambda a, b: (a > b) - (a < b) __all__ = ("TreeModifier", "Tree") +T_Tree_cache = TypeVar('T_Tree_cache', bound=Union[Tuple[bytes, int, str]]) + -def git_cmp(t1, t2): +def git_cmp(t1: T_Tree_cache, t2: T_Tree_cache) -> int: a, b = t1[2], t2[2] + assert isinstance(a, str) and isinstance(b, str) # Need as mypy 9.0 cannot unpack TypeVar properly len_a, len_b = len(a), len(b) min_len = min(len_a, len_b) min_cmp = cmp(a[:min_len], b[:min_len]) @@ -45,9 +51,10 @@ def git_cmp(t1, t2): return len_a - len_b -def merge_sort(a, cmp): +def merge_sort(a: List[T_Tree_cache], + cmp: Callable[[T_Tree_cache, T_Tree_cache], int]) -> None: if len(a) < 2: - return + return None mid = len(a) // 2 lefthalf = a[:mid] @@ -80,7 +87,7 @@ def merge_sort(a, cmp): k = k + 1 -class TreeModifier(object): +class TreeModifier(Generic[T_Tree_cache], object): """A utility class providing methods to alter the underlying cache in a list-like fashion. @@ -88,10 +95,10 @@ class TreeModifier(object): the cache of a tree, will be sorted. Assuring it will be in a serializable state""" __slots__ = '_cache' - def __init__(self, cache): + def __init__(self, cache: List[T_Tree_cache]) -> None: self._cache = cache - def _index_by_name(self, name): + def _index_by_name(self, name: str) -> int: """:return: index of an item with name, or -1 if not found""" for i, t in enumerate(self._cache): if t[2] == name: @@ -101,7 +108,7 @@ def _index_by_name(self, name): return -1 #{ Interface - def set_done(self): + def set_done(self) -> 'TreeModifier': """Call this method once you are done modifying the tree information. It may be called several times, but be aware that each call will cause a sort operation @@ -111,7 +118,7 @@ def set_done(self): #} END interface #{ Mutators - def add(self, sha, mode, name, force=False): + def add(self, sha: bytes, mode: int, name: str, force: bool = False) -> 'TreeModifier': """Add the given item to the tree. If an item with the given name already exists, nothing will be done, but a ValueError will be raised if the sha and mode of the existing item do not match the one you add, unless @@ -129,7 +136,9 @@ def add(self, sha, mode, name, force=False): sha = to_bin_sha(sha) index = self._index_by_name(name) - item = (sha, mode, name) + + assert isinstance(sha, bytes) and isinstance(mode, int) and isinstance(name, str) + item = cast(T_Tree_cache, (sha, mode, name)) # use Typeguard from typing-extensions 3.10.0 if index == -1: self._cache.append(item) else: @@ -144,14 +153,17 @@ def add(self, sha, mode, name, force=False): # END handle name exists return self - def add_unchecked(self, binsha, mode, name): + def add_unchecked(self, binsha: bytes, mode: int, name: str) -> None: """Add the given item to the tree, its correctness is assumed, which puts the caller into responsibility to assure the input is correct. For more information on the parameters, see ``add`` :param binsha: 20 byte binary sha""" - self._cache.append((binsha, mode, name)) + assert isinstance(binsha, bytes) and isinstance(mode, int) and isinstance(name, str) + tree_cache = cast(T_Tree_cache, (binsha, mode, name)) - def __delitem__(self, name): + self._cache.append(tree_cache) + + def __delitem__(self, name: str) -> None: """Deletes an item with the given name if it exists""" index = self._index_by_name(name) if index > -1: @@ -182,29 +194,29 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): symlink_id = 0o12 tree_id = 0o04 - _map_id_to_type = { + _map_id_to_type: Dict[int, Union[Type[Submodule], Type[Blob], Type['Tree']]] = { commit_id: Submodule, blob_id: Blob, symlink_id: Blob # tree id added once Tree is defined } - def __init__(self, repo, binsha, mode=tree_id << 12, path=None): + def __init__(self, repo: 'Repo', binsha: bytes, mode: int = tree_id << 12, path: Union[PathLike, None] = None): super(Tree, self).__init__(repo, binsha, mode, path) - @classmethod + @ classmethod def _get_intermediate_items(cls, index_object: 'Tree', # type: ignore - ) -> Tuple['Tree', ...]: + ) -> Union[Tuple['Tree', ...], Tuple[()]]: if index_object.type == "tree": index_object = cast('Tree', index_object) return tuple(index_object._iter_convert_to_object(index_object._cache)) return () - def _set_cache_(self, attr): + def _set_cache_(self, attr: str) -> None: if attr == "_cache": # Set the data when we need it ostream = self.repo.odb.stream(self.binsha) - self._cache = tree_entries_from_data(ostream.read()) + self._cache: List[Tuple[bytes, int, str]] = tree_entries_from_data(ostream.read()) else: super(Tree, self)._set_cache_(attr) # END handle attribute @@ -221,7 +233,7 @@ def _iter_convert_to_object(self, iterable: Iterable[Tuple[bytes, int, str]] raise TypeError("Unknown mode %o found in tree data for path '%s'" % (mode, path)) from e # END for each item - def join(self, file): + def join(self, file: str) -> Union[Blob, 'Tree', Submodule]: """Find the named object in this tree's contents :return: ``git.Blob`` or ``git.Tree`` or ``git.Submodule`` @@ -254,26 +266,22 @@ def join(self, file): raise KeyError(msg % file) # END handle long paths - def __div__(self, file): - """For PY2 only""" - return self.join(file) - - def __truediv__(self, file): + def __truediv__(self, file: str) -> Union['Tree', Blob, Submodule]: """For PY3 only""" return self.join(file) - @property - def trees(self): + @ property + def trees(self) -> List['Tree']: """:return: list(Tree, ...) list of trees directly below this tree""" return [i for i in self if i.type == "tree"] - @property - def blobs(self): + @ property + def blobs(self) -> List['Blob']: """:return: list(Blob, ...) list of blobs directly below this tree""" return [i for i in self if i.type == "blob"] - @property - def cache(self): + @ property + def cache(self) -> TreeModifier: """ :return: An object allowing to modify the internal cache. This can be used to change the tree's contents. When done, make sure you call ``set_done`` @@ -289,16 +297,16 @@ def traverse(self, predicate=lambda i, d: True, return super(Tree, self).traverse(predicate, prune, depth, branch_first, visit_once, ignore_self) # List protocol - def __getslice__(self, i, j): + def __getslice__(self, i: int, j: int) -> List[Union[Blob, 'Tree', Submodule]]: return list(self._iter_convert_to_object(self._cache[i:j])) - def __iter__(self): + def __iter__(self) -> Iterator[Union[Blob, 'Tree', Submodule]]: return self._iter_convert_to_object(self._cache) - def __len__(self): + def __len__(self) -> int: return len(self._cache) - def __getitem__(self, item): + def __getitem__(self, item: Union[str, int, slice]) -> Union[Blob, 'Tree', Submodule]: if isinstance(item, int): info = self._cache[item] return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join_path(self.path, info[2])) @@ -310,7 +318,7 @@ def __getitem__(self, item): raise TypeError("Invalid index type: %r" % item) - def __contains__(self, item): + def __contains__(self, item: Union[IndexObject, PathLike]) -> bool: if isinstance(item, IndexObject): for info in self._cache: if item.binsha == info[0]: @@ -321,10 +329,11 @@ def __contains__(self, item): # compatibility # treat item as repo-relative path - path = self.path - for info in self._cache: - if item == join_path(path, info[2]): - return True + else: + path = self.path + for info in self._cache: + if item == join_path(path, info[2]): + return True # END for each item return False diff --git a/git/objects/util.py b/git/objects/util.py index 087f0166b..8b8148a9f 100644 --- a/git/objects/util.py +++ b/git/objects/util.py @@ -19,16 +19,18 @@ from datetime import datetime, timedelta, tzinfo # typing ------------------------------------------------------------ -from typing import (Any, Callable, Deque, Iterator, Sequence, TYPE_CHECKING, Tuple, Type, Union, cast, overload) +from typing import (Any, Callable, Deque, Iterator, TypeVar, TYPE_CHECKING, Tuple, Type, Union, cast) if TYPE_CHECKING: from io import BytesIO, StringIO - from .submodule.base import Submodule + from .submodule.base import Submodule # noqa: F401 from .commit import Commit from .blob import Blob from .tag import TagObject from .tree import Tree from subprocess import Popen + +T_Iterableobj = TypeVar('T_Iterableobj') # -------------------------------------------------------------------- @@ -284,29 +286,8 @@ class Traversable(object): """ __slots__ = () - @overload @classmethod - def _get_intermediate_items(cls, item: 'Commit') -> Tuple['Commit', ...]: - ... - - @overload - @classmethod - def _get_intermediate_items(cls, item: 'Submodule') -> Tuple['Submodule', ...]: - ... - - @overload - @classmethod - def _get_intermediate_items(cls, item: 'Tree') -> Tuple['Tree', ...]: - ... - - @overload - @classmethod - def _get_intermediate_items(cls, item: 'Traversable') -> Tuple['Traversable', ...]: - ... - - @classmethod - def _get_intermediate_items(cls, item: 'Traversable' - ) -> Sequence['Traversable']: + def _get_intermediate_items(cls, item): """ Returns: Tuple of items connected to the given item. @@ -322,7 +303,7 @@ def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList: """ :return: IterableList with the results of the traversal as produced by traverse()""" - out = IterableList(self._id_attribute_) # type: ignore[attr-defined] # defined in sublcasses + out: IterableList = IterableList(self._id_attribute_) # type: ignore[attr-defined] # defined in sublcasses out.extend(self.traverse(*args, **kwargs)) return out diff --git a/git/refs/log.py b/git/refs/log.py index 363c3c5d5..f850ba24c 100644 --- a/git/refs/log.py +++ b/git/refs/log.py @@ -82,23 +82,23 @@ def new(cls, oldhexsha, newhexsha, actor, time, tz_offset, message): # skipcq: return RefLogEntry((oldhexsha, newhexsha, actor, (time, tz_offset), message)) @classmethod - def from_line(cls, line): + def from_line(cls, line: bytes) -> 'RefLogEntry': """:return: New RefLogEntry instance from the given revlog line. :param line: line bytes without trailing newline :raise ValueError: If line could not be parsed""" - line = line.decode(defenc) - fields = line.split('\t', 1) + line_str = line.decode(defenc) + fields = line_str.split('\t', 1) if len(fields) == 1: info, msg = fields[0], None elif len(fields) == 2: info, msg = fields else: raise ValueError("Line must have up to two TAB-separated fields." - " Got %s" % repr(line)) + " Got %s" % repr(line_str)) # END handle first split - oldhexsha = info[:40] # type: str - newhexsha = info[41:81] # type: str + oldhexsha = info[:40] + newhexsha = info[41:81] for hexsha in (oldhexsha, newhexsha): if not cls._re_hexsha_only.match(hexsha): raise ValueError("Invalid hexsha: %r" % (hexsha,)) diff --git a/git/refs/reference.py b/git/refs/reference.py index 9014f5558..8a9b04873 100644 --- a/git/refs/reference.py +++ b/git/refs/reference.py @@ -1,6 +1,6 @@ from git.util import ( LazyMixin, - Iterable, + IterableObj, ) from .symbolic import SymbolicReference @@ -23,7 +23,7 @@ def wrapper(self, *args): #}END utilities -class Reference(SymbolicReference, LazyMixin, Iterable): +class Reference(SymbolicReference, LazyMixin, IterableObj): """Represents a named reference to any object. Subclasses may apply restrictions though, i.e. Heads can only point to commits.""" diff --git a/git/remote.py b/git/remote.py index 6ea4b2a1a..a6232db32 100644 --- a/git/remote.py +++ b/git/remote.py @@ -13,7 +13,7 @@ from git.exc import GitCommandError from git.util import ( LazyMixin, - Iterable, + IterableObj, IterableList, RemoteProgress, CallableRemoteProgress, @@ -36,9 +36,9 @@ # typing------------------------------------------------------- -from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, TYPE_CHECKING, Union, cast, overload +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, TYPE_CHECKING, Union, overload -from git.types import PathLike, Literal, TBD +from git.types import PathLike, Literal, TBD, TypeGuard if TYPE_CHECKING: from git.repo.base import Repo @@ -47,9 +47,16 @@ from git.objects.tree import Tree from git.objects.tag import TagObject -flagKeyLiteral = Literal[' ', '!', '+', '-', '*', '=', 't'] +flagKeyLiteral = Literal[' ', '!', '+', '-', '*', '=', 't', '?'] + + +def is_flagKeyLiteral(inp: str) -> TypeGuard[flagKeyLiteral]: + return inp in [' ', '!', '+', '-', '=', '*', 't', '?'] + + # ------------------------------------------------------------- + log = logging.getLogger('git.remote') log.addHandler(logging.NullHandler()) @@ -107,7 +114,7 @@ def to_progress_instance(progress: Union[Callable[..., Any], RemoteProgress, Non return progress -class PushInfo(object): +class PushInfo(IterableObj, object): """ Carries information about the result of a push operation of a single head:: @@ -220,7 +227,7 @@ def _from_line(cls, remote: 'Remote', line: str) -> 'PushInfo': return PushInfo(flags, from_ref, to_ref_string, remote, old_commit, summary) -class FetchInfo(object): +class FetchInfo(IterableObj, object): """ Carries information about the results of a fetch operation of a single head:: @@ -325,7 +332,7 @@ def _from_line(cls, repo: 'Repo', line: str, fetch_line: str) -> 'FetchInfo': # parse lines control_character, operation, local_remote_ref, remote_local_ref_str, note = match.groups() - control_character = cast(flagKeyLiteral, control_character) # can do this neater once 3.5 dropped + assert is_flagKeyLiteral(control_character), f"{control_character}" try: _new_hex_sha, _fetch_operation, fetch_note = fetch_line.split("\t") @@ -421,7 +428,7 @@ def _from_line(cls, repo: 'Repo', line: str, fetch_line: str) -> 'FetchInfo': return cls(remote_local_ref, flags, note, old_commit, local_remote_ref) -class Remote(LazyMixin, Iterable): +class Remote(LazyMixin, IterableObj): """Provides easy read and write access to a git remote. @@ -552,8 +559,8 @@ def delete_url(self, url: str, **kwargs: Any) -> 'Remote': def urls(self) -> Iterator[str]: """:return: Iterator yielding all configured URL targets on a remote as strings""" try: - # can replace cast with type assert? - remote_details = cast(str, self.repo.git.remote("get-url", "--all", self.name)) + remote_details = self.repo.git.remote("get-url", "--all", self.name) + assert isinstance(remote_details, str) for line in remote_details.split('\n'): yield line except GitCommandError as ex: @@ -564,14 +571,16 @@ def urls(self) -> Iterator[str]: # if 'Unknown subcommand: get-url' in str(ex): try: - remote_details = cast(str, self.repo.git.remote("show", self.name)) + remote_details = self.repo.git.remote("show", self.name) + assert isinstance(remote_details, str) for line in remote_details.split('\n'): if ' Push URL:' in line: yield line.split(': ')[-1] except GitCommandError as _ex: if any(msg in str(_ex) for msg in ['correct access rights', 'cannot run ssh']): # If ssh is not setup to access this repository, see issue 694 - remote_details = cast(str, self.repo.git.config('--get-all', 'remote.%s.url' % self.name)) + remote_details = self.repo.git.config('--get-all', 'remote.%s.url' % self.name) + assert isinstance(remote_details, str) for line in remote_details.split('\n'): yield line else: @@ -580,18 +589,18 @@ def urls(self) -> Iterator[str]: raise ex @property - def refs(self) -> IterableList: + def refs(self) -> IterableList[RemoteReference]: """ :return: IterableList of RemoteReference objects. It is prefixed, allowing you to omit the remote path portion, i.e.:: remote.refs.master # yields RemoteReference('/refs/remotes/origin/master')""" - out_refs = IterableList(RemoteReference._id_attribute_, "%s/" % self.name) + out_refs: IterableList[RemoteReference] = IterableList(RemoteReference._id_attribute_, "%s/" % self.name) out_refs.extend(RemoteReference.list_items(self.repo, remote=self.name)) return out_refs @property - def stale_refs(self) -> IterableList: + def stale_refs(self) -> IterableList[Reference]: """ :return: IterableList RemoteReference objects that do not have a corresponding @@ -606,7 +615,7 @@ def stale_refs(self) -> IterableList: as well. This is a fix for the issue described here: https://github.com/gitpython-developers/GitPython/issues/260 """ - out_refs = IterableList(RemoteReference._id_attribute_, "%s/" % self.name) + out_refs: IterableList[RemoteReference] = IterableList(RemoteReference._id_attribute_, "%s/" % self.name) for line in self.repo.git.remote("prune", "--dry-run", self).splitlines()[2:]: # expecting # * [would prune] origin/new_branch @@ -681,11 +690,12 @@ def update(self, **kwargs: Any) -> 'Remote': return self def _get_fetch_info_from_stderr(self, proc: TBD, - progress: Union[Callable[..., Any], RemoteProgress, None]) -> IterableList: + progress: Union[Callable[..., Any], RemoteProgress, None] + ) -> IterableList['FetchInfo']: progress = to_progress_instance(progress) # skip first line as it is some remote info we are not interested in - output = IterableList('name') + output: IterableList['FetchInfo'] = IterableList('name') # lines which are no progress are fetch info lines # this also waits for the command to finish @@ -741,7 +751,7 @@ def _get_fetch_info_from_stderr(self, proc: TBD, return output def _get_push_info(self, proc: TBD, - progress: Union[Callable[..., Any], RemoteProgress, None]) -> IterableList: + progress: Union[Callable[..., Any], RemoteProgress, None]) -> IterableList[PushInfo]: progress = to_progress_instance(progress) # read progress information from stderr @@ -749,7 +759,7 @@ def _get_push_info(self, proc: TBD, # read the lines manually as it will use carriage returns between the messages # to override the previous one. This is why we read the bytes manually progress_handler = progress.new_message_handler() - output = IterableList('push_infos') + output: IterableList[PushInfo] = IterableList('push_infos') def stdout_handler(line: str) -> None: try: @@ -785,7 +795,7 @@ def _assert_refspec(self) -> None: def fetch(self, refspec: Union[str, List[str], None] = None, progress: Union[Callable[..., Any], None] = None, - verbose: bool = True, **kwargs: Any) -> IterableList: + verbose: bool = True, **kwargs: Any) -> IterableList[FetchInfo]: """Fetch the latest changes for this remote :param refspec: @@ -832,7 +842,7 @@ def fetch(self, refspec: Union[str, List[str], None] = None, def pull(self, refspec: Union[str, List[str], None] = None, progress: Union[Callable[..., Any], None] = None, - **kwargs: Any) -> IterableList: + **kwargs: Any) -> IterableList[FetchInfo]: """Pull changes from the given branch, being the same as a fetch followed by a merge of branch with your local branch. @@ -853,7 +863,7 @@ def pull(self, refspec: Union[str, List[str], None] = None, def push(self, refspec: Union[str, List[str], None] = None, progress: Union[Callable[..., Any], None] = None, - **kwargs: Any) -> IterableList: + **kwargs: Any) -> IterableList[PushInfo]: """Push changes from source branch in refspec to target branch in refspec. :param refspec: see 'fetch' method diff --git a/git/repo/base.py b/git/repo/base.py index 5abd49618..52727504b 100644 --- a/git/repo/base.py +++ b/git/repo/base.py @@ -7,6 +7,7 @@ import os import re import warnings +from gitdb.db.loose import LooseObjectDB from gitdb.exc import BadObject @@ -100,7 +101,7 @@ class Repo(object): # Subclasses may easily bring in their own custom types by placing a constructor or type here GitCommandWrapperType = Git - def __init__(self, path: Optional[PathLike] = None, odbt: Type[GitCmdObjectDB] = GitCmdObjectDB, + def __init__(self, path: Optional[PathLike] = None, odbt: Type[LooseObjectDB] = GitCmdObjectDB, search_parent_directories: bool = False, expand_vars: bool = True) -> None: """Create a new Repo instance @@ -308,7 +309,7 @@ def bare(self) -> bool: return self._bare @property - def heads(self) -> 'IterableList': + def heads(self) -> 'IterableList[Head]': """A list of ``Head`` objects representing the branch heads in this repo @@ -316,7 +317,7 @@ def heads(self) -> 'IterableList': return Head.list_items(self) @property - def references(self) -> 'IterableList': + def references(self) -> 'IterableList[Reference]': """A list of Reference objects representing tags, heads and remote references. :return: IterableList(Reference, ...)""" @@ -341,7 +342,7 @@ def head(self) -> 'HEAD': return HEAD(self, 'HEAD') @property - def remotes(self) -> 'IterableList': + def remotes(self) -> 'IterableList[Remote]': """A list of Remote objects allowing to access and manipulate remotes :return: ``git.IterableList(Remote, ...)``""" return Remote.list_items(self) @@ -357,13 +358,13 @@ def remote(self, name: str = 'origin') -> 'Remote': #{ Submodules @property - def submodules(self) -> 'IterableList': + def submodules(self) -> 'IterableList[Submodule]': """ :return: git.IterableList(Submodule, ...) of direct submodules available from the current head""" return Submodule.list_items(self) - def submodule(self, name: str) -> 'IterableList': + def submodule(self, name: str) -> 'Submodule': """ :return: Submodule with the given name :raise ValueError: If no such submodule exists""" try: @@ -395,7 +396,7 @@ def submodule_update(self, *args: Any, **kwargs: Any) -> Iterator: #}END submodules @property - def tags(self) -> 'IterableList': + def tags(self) -> 'IterableList[TagReference]': """A list of ``Tag`` objects that are available in this repo :return: ``git.IterableList(TagReference, ...)`` """ return TagReference.list_items(self) diff --git a/git/types.py b/git/types.py index a410cb366..e3b49170d 100644 --- a/git/types.py +++ b/git/types.py @@ -4,12 +4,17 @@ import os import sys -from typing import Union, Any +from typing import Dict, Union, Any if sys.version_info[:2] >= (3, 8): - from typing import Final, Literal, SupportsIndex # noqa: F401 + from typing import Final, Literal, SupportsIndex, TypedDict # noqa: F401 else: - from typing_extensions import Final, Literal, SupportsIndex # noqa: F401 + from typing_extensions import Final, Literal, SupportsIndex, TypedDict # noqa: F401 + +if sys.version_info[:2] >= (3, 10): + from typing import TypeGuard # noqa: F401 +else: + from typing_extensions import TypeGuard # noqa: F401 if sys.version_info[:2] < (3, 9): @@ -22,3 +27,21 @@ TBD = Any Lit_config_levels = Literal['system', 'global', 'user', 'repository'] + + +class Files_TD(TypedDict): + insertions: int + deletions: int + lines: int + + +class Total_TD(TypedDict): + insertions: int + deletions: int + lines: int + files: int + + +class HSH_TD(TypedDict): + total: Total_TD + files: Dict[PathLike, Files_TD] diff --git a/git/util.py b/git/util.py index 516c315c1..eccaa74ed 100644 --- a/git/util.py +++ b/git/util.py @@ -18,18 +18,21 @@ import time from unittest import SkipTest from urllib.parse import urlsplit, urlunsplit +import warnings # typing --------------------------------------------------------- from typing import (Any, AnyStr, BinaryIO, Callable, Dict, Generator, IO, Iterator, List, - Optional, Pattern, Sequence, Tuple, Union, cast, TYPE_CHECKING, overload) + Optional, Pattern, Sequence, Tuple, TypeVar, Union, cast, TYPE_CHECKING, overload) import pathlib if TYPE_CHECKING: from git.remote import Remote from git.repo.base import Repo -from .types import PathLike, TBD, Literal, SupportsIndex + from git.config import GitConfigParser, SectionConstraint + +from .types import PathLike, Literal, SupportsIndex, HSH_TD, Files_TD # --------------------------------------------------------------------- @@ -81,7 +84,7 @@ def unbare_repo(func: Callable) -> Callable: encounter a bare repository""" @wraps(func) - def wrapper(self: 'Remote', *args: Any, **kwargs: Any) -> TBD: + def wrapper(self: 'Remote', *args: Any, **kwargs: Any) -> Callable: if self.repo.bare: raise InvalidGitRepositoryError("Method '%s' cannot operate on bare repositories" % func.__name__) # END bare method @@ -107,7 +110,7 @@ def rmtree(path: PathLike) -> None: :note: we use shutil rmtree but adjust its behaviour to see whether files that couldn't be deleted are read-only. Windows will not remove them in that case""" - def onerror(func: Callable, path: PathLike, exc_info: TBD) -> None: + def onerror(func: Callable, path: PathLike, exc_info: str) -> None: # Is the error an access error ? os.chmod(path, stat.S_IWUSR) @@ -447,7 +450,7 @@ class RemoteProgress(object): re_op_relative = re.compile(r"(remote: )?([\w\s]+):\s+(\d+)% \((\d+)/(\d+)\)(.*)") def __init__(self) -> None: - self._seen_ops = [] # type: List[TBD] + self._seen_ops = [] # type: List[int] self._cur_line = None # type: Optional[str] self.error_lines = [] # type: List[str] self.other_lines = [] # type: List[str] @@ -668,7 +671,8 @@ def _from_string(cls, string: str) -> 'Actor': # END handle name/email matching @classmethod - def _main_actor(cls, env_name: str, env_email: str, config_reader: Optional[TBD] = None) -> 'Actor': + def _main_actor(cls, env_name: str, env_email: str, + config_reader: Union[None, 'GitConfigParser', 'SectionConstraint'] = None) -> 'Actor': actor = Actor('', '') user_id = None # We use this to avoid multiple calls to getpass.getuser() @@ -697,7 +701,7 @@ def default_name() -> str: return actor @classmethod - def committer(cls, config_reader: Optional[TBD] = None) -> 'Actor': + def committer(cls, config_reader: Union[None, 'GitConfigParser', 'SectionConstraint'] = None) -> 'Actor': """ :return: Actor instance corresponding to the configured committer. It behaves similar to the git implementation, such that the environment will override @@ -708,7 +712,7 @@ def committer(cls, config_reader: Optional[TBD] = None) -> 'Actor': return cls._main_actor(cls.env_committer_name, cls.env_committer_email, config_reader) @classmethod - def author(cls, config_reader: Optional[TBD] = None) -> 'Actor': + def author(cls, config_reader: Union[None, 'GitConfigParser', 'SectionConstraint'] = None) -> 'Actor': """Same as committer(), but defines the main author. It may be specified in the environment, but defaults to the committer""" return cls._main_actor(cls.env_author_name, cls.env_author_email, config_reader) @@ -742,7 +746,9 @@ class Stats(object): files = number of changed files as int""" __slots__ = ("total", "files") - def __init__(self, total: Dict[str, Dict[str, int]], files: Dict[str, Dict[str, int]]): + from git.types import Total_TD, Files_TD + + def __init__(self, total: Total_TD, files: Dict[PathLike, Files_TD]): self.total = total self.files = files @@ -751,9 +757,13 @@ def _list_from_string(cls, repo: 'Repo', text: str) -> 'Stats': """Create a Stat object from output retrieved by git-diff. :return: git.Stat""" - hsh = {'total': {'insertions': 0, 'deletions': 0, 'lines': 0, 'files': 0}, - 'files': {} - } # type: Dict[str, Dict[str, TBD]] ## need typeddict or refactor for mypy + + hsh: HSH_TD = {'total': {'insertions': 0, + 'deletions': 0, + 'lines': 0, + 'files': 0}, + 'files': {} + } for line in text.splitlines(): (raw_insertions, raw_deletions, filename) = line.split("\t") insertions = raw_insertions != '-' and int(raw_insertions) or 0 @@ -762,9 +772,10 @@ def _list_from_string(cls, repo: 'Repo', text: str) -> 'Stats': hsh['total']['deletions'] += deletions hsh['total']['lines'] += insertions + deletions hsh['total']['files'] += 1 - hsh['files'][filename.strip()] = {'insertions': insertions, - 'deletions': deletions, - 'lines': insertions + deletions} + files_dict: Files_TD = {'insertions': insertions, + 'deletions': deletions, + 'lines': insertions + deletions} + hsh['files'][filename.strip()] = files_dict return Stats(hsh['total'], hsh['files']) @@ -920,7 +931,10 @@ def _obtain_lock(self) -> None: # END endless loop -class IterableList(list): +T = TypeVar('T', bound='IterableObj') + + +class IterableList(List[T]): """ List of iterable objects allowing to query an object by id or by named index:: @@ -930,6 +944,9 @@ class IterableList(list): heads['master'] heads[0] + Iterable parent objects = [Commit, SubModule, Reference, FetchInfo, PushInfo] + Iterable via inheritance = [Head, TagReference, RemoteReference] + ] It requires an id_attribute name to be set which will be queried from its contained items to have a means for comparison. @@ -938,7 +955,7 @@ class IterableList(list): can be left out.""" __slots__ = ('_id_attr', '_prefix') - def __new__(cls, id_attr: str, prefix: str = '') -> 'IterableList': + def __new__(cls, id_attr: str, prefix: str = '') -> 'IterableList[IterableObj]': return super(IterableList, cls).__new__(cls) def __init__(self, id_attr: str, prefix: str = '') -> None: @@ -1007,16 +1024,29 @@ def __delitem__(self, index: Union[SupportsIndex, int, slice, str]) -> Any: list.__delitem__(self, delindex) +class IterableClassWatcher(type): + def __init__(cls, name, bases, clsdict): + for base in bases: + if type(base) == IterableClassWatcher: + warnings.warn(f"GitPython Iterable subclassed by {name}. " + "Iterable is deprecated due to naming clash, " + "Use IterableObj instead \n", + DeprecationWarning, + stacklevel=2) + + class Iterable(object): """Defines an interface for iterable items which is to assure a uniform way to retrieve and iterate items within the git repository""" __slots__ = () _id_attribute_ = "attribute that most suitably identifies your instance" + __metaclass__ = IterableClassWatcher @classmethod - def list_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> 'IterableList': + def list_items(cls, repo, *args, **kwargs): """ + Deprecaated, use IterableObj instead. Find all items of this type - subclasses can specify args and kwargs differently. If no args are given, subclasses are obliged to return all items if no additional arguments arg given. @@ -1029,7 +1059,35 @@ def list_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> 'IterableList': return out_list @classmethod - def iter_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> Iterator[TBD]: + def iter_items(cls, repo: 'Repo', *args: Any, **kwargs: Any): + # return typed to be compatible with subtypes e.g. Remote + """For more information about the arguments, see list_items + :return: iterator yielding Items""" + raise NotImplementedError("To be implemented by Subclass") + + +class IterableObj(): + """Defines an interface for iterable items which is to assure a uniform + way to retrieve and iterate items within the git repository""" + __slots__ = () + _id_attribute_ = "attribute that most suitably identifies your instance" + + @classmethod + def list_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> IterableList[T]: + """ + Find all items of this type - subclasses can specify args and kwargs differently. + If no args are given, subclasses are obliged to return all items if no additional + arguments arg given. + + :note: Favor the iter_items method as it will + + :return:list(Item,...) list of item instances""" + out_list: IterableList = IterableList(cls._id_attribute_) + out_list.extend(cls.iter_items(repo, *args, **kwargs)) + return out_list + + @classmethod + def iter_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> Iterator[T]: # return typed to be compatible with subtypes e.g. Remote """For more information about the arguments, see list_items :return: iterator yielding Items""" diff --git a/requirements.txt b/requirements.txt index 7159416a9..a20310fb2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ gitdb>=4.0.1,<5 -typing-extensions>=3.7.4.3;python_version<"3.8" +typing-extensions>=3.7.4.3;python_version<"3.10" diff --git a/test-requirements.txt b/test-requirements.txt index 16dc0d2c1..ab3f86109 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -5,4 +5,4 @@ tox virtualenv nose gitdb>=4.0.1,<5 -typing-extensions>=3.7.4.3;python_version<"3.8" +typing-extensions>=3.7.4.3;python_version<"3.10"