Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2d2ff03

Browse files
authoredJun 26, 2021
Merge pull request #1279 from Yobmod/main
Finish typing object, improve verious other types.
2 parents 703280b + 5d7b8ba commit 2d2ff03

File tree

16 files changed

+255
-148
lines changed

16 files changed

+255
-148
lines changed
 

‎git/index/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def write_tree(self) -> Tree:
568568
# note: additional deserialization could be saved if write_tree_from_cache
569569
# would return sorted tree entries
570570
root_tree = Tree(self.repo, binsha, path='')
571-
root_tree._cache = tree_items
571+
root_tree._cache = tree_items # type: ignore
572572
return root_tree
573573

574574
def _process_diff_args(self, args: List[Union[str, diff.Diffable, object]]

‎git/index/fun.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353

5454
from typing import (Dict, IO, List, Sequence, TYPE_CHECKING, Tuple, Type, Union, cast)
5555

56-
from git.types import PathLike
56+
from git.types import PathLike, TypeGuard
5757

5858
if TYPE_CHECKING:
5959
from .base import IndexFile
@@ -185,11 +185,17 @@ def read_header(stream: IO[bytes]) -> Tuple[int, int]:
185185
def entry_key(*entry: Union[BaseIndexEntry, PathLike, int]) -> Tuple[PathLike, int]:
186186
""":return: Key suitable to be used for the index.entries dictionary
187187
:param entry: One instance of type BaseIndexEntry or the path and the stage"""
188+
189+
def is_entry_tuple(entry: Tuple) -> TypeGuard[Tuple[PathLike, int]]:
190+
return isinstance(entry, tuple) and len(entry) == 2
191+
188192
if len(entry) == 1:
189-
entry_first = cast(BaseIndexEntry, entry[0]) # type: BaseIndexEntry
193+
entry_first = entry[0]
194+
assert isinstance(entry_first, BaseIndexEntry)
190195
return (entry_first.path, entry_first.stage)
191196
else:
192-
entry = cast(Tuple[PathLike, int], tuple(entry))
197+
# entry = tuple(entry)
198+
assert is_entry_tuple(entry)
193199
return entry
194200
# END handle entry
195201

@@ -293,7 +299,7 @@ def write_tree_from_cache(entries: List[IndexEntry], odb, sl: slice, si: int = 0
293299
# finally create the tree
294300
sio = BytesIO()
295301
tree_to_stream(tree_items, sio.write) # converts bytes of each item[0] to str
296-
tree_items_stringified = cast(List[Tuple[str, int, str]], tree_items) # type: List[Tuple[str, int, str]]
302+
tree_items_stringified = cast(List[Tuple[str, int, str]], tree_items)
297303
sio.seek(0)
298304

299305
istream = odb.store(IStream(str_tree_type, len(sio.getvalue()), sio))

‎git/objects/commit.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from git.util import (
99
hex_to_bin,
1010
Actor,
11-
Iterable,
11+
IterableObj,
1212
Stats,
1313
finalize_process
1414
)
@@ -47,7 +47,7 @@
4747
__all__ = ('Commit', )
4848

4949

50-
class Commit(base.Object, Iterable, Diffable, Traversable, Serializable):
50+
class Commit(base.Object, IterableObj, Diffable, Traversable, Serializable):
5151

5252
"""Wraps a git Commit object.
5353

‎git/objects/fun.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
"""Module with functions which are supposed to be as fast as possible"""
22
from stat import S_ISDIR
3+
34
from git.compat import (
45
safe_decode,
56
defenc
67
)
78

9+
# typing ----------------------------------------------
10+
11+
from typing import List, Tuple
12+
13+
14+
# ---------------------------------------------------
15+
16+
817
__all__ = ('tree_to_stream', 'tree_entries_from_data', 'traverse_trees_recursive',
918
'traverse_tree_recursive')
1019

@@ -38,7 +47,7 @@ def tree_to_stream(entries, write):
3847
# END for each item
3948

4049

41-
def tree_entries_from_data(data):
50+
def tree_entries_from_data(data: bytes) -> List[Tuple[bytes, int, str]]:
4251
"""Reads the binary representation of a tree and returns tuples of Tree items
4352
:param data: data block with tree data (as bytes)
4453
:return: list(tuple(binsha, mode, tree_relative_path), ...)"""
@@ -72,8 +81,8 @@ def tree_entries_from_data(data):
7281

7382
# default encoding for strings in git is utf8
7483
# Only use the respective unicode object if the byte stream was encoded
75-
name = data[ns:i]
76-
name = safe_decode(name)
84+
name_bytes = data[ns:i]
85+
name = safe_decode(name_bytes)
7786

7887
# byte is NULL, get next 20
7988
i += 1

‎git/objects/submodule/base.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import os
55
import stat
6-
from typing import List
76
from unittest import SkipTest
87
import uuid
98

@@ -27,12 +26,13 @@
2726
from git.objects.base import IndexObject, Object
2827
from git.objects.util import Traversable
2928
from git.util import (
30-
Iterable,
29+
IterableObj,
3130
join_path_native,
3231
to_native_path_linux,
3332
RemoteProgress,
3433
rmtree,
35-
unbare_repo
34+
unbare_repo,
35+
IterableList
3636
)
3737
from git.util import HIDE_WINDOWS_KNOWN_ERRORS
3838

@@ -47,6 +47,11 @@
4747
)
4848

4949

50+
# typing ----------------------------------------------------------------------
51+
52+
53+
# -----------------------------------------------------------------------------
54+
5055
__all__ = ["Submodule", "UpdateProgress"]
5156

5257

@@ -74,7 +79,7 @@ class UpdateProgress(RemoteProgress):
7479
# IndexObject comes via util module, its a 'hacky' fix thanks to pythons import
7580
# mechanism which cause plenty of trouble of the only reason for packages and
7681
# modules is refactoring - subpackages shouldn't depend on parent packages
77-
class Submodule(IndexObject, Iterable, Traversable):
82+
class Submodule(IndexObject, IterableObj, Traversable):
7883

7984
"""Implements access to a git submodule. They are special in that their sha
8085
represents a commit in the submodule's repository which is to be checked out
@@ -136,12 +141,12 @@ def _set_cache_(self, attr):
136141
# END handle attribute name
137142

138143
@classmethod
139-
def _get_intermediate_items(cls, item: 'Submodule') -> List['Submodule']: # type: ignore
144+
def _get_intermediate_items(cls, item: 'Submodule') -> IterableList['Submodule']:
140145
""":return: all the submodules of our module repository"""
141146
try:
142147
return cls.list_items(item.module())
143148
except InvalidGitRepositoryError:
144-
return []
149+
return IterableList('')
145150
# END handle intermediate items
146151

147152
@classmethod
@@ -1153,7 +1158,7 @@ def name(self):
11531158
"""
11541159
return self._name
11551160

1156-
def config_reader(self):
1161+
def config_reader(self) -> SectionConstraint:
11571162
"""
11581163
:return: ConfigReader instance which allows you to qurey the configuration values
11591164
of this submodule, as provided by the .gitmodules file
@@ -1163,7 +1168,7 @@ def config_reader(self):
11631168
:raise IOError: If the .gitmodules file/blob could not be read"""
11641169
return self._config_parser_constrained(read_only=True)
11651170

1166-
def children(self):
1171+
def children(self) -> IterableList['Submodule']:
11671172
"""
11681173
:return: IterableList(Submodule, ...) an iterable list of submodules instances
11691174
which are children of this submodule or 0 if the submodule is not checked out"""

‎git/objects/submodule/util.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
from io import BytesIO
55
import weakref
66

7+
from typing import TYPE_CHECKING
8+
9+
if TYPE_CHECKING:
10+
from .base import Submodule
11+
712
__all__ = ('sm_section', 'sm_name', 'mkhead', 'find_first_remote_branch',
813
'SubmoduleConfigParser')
914

@@ -60,12 +65,12 @@ def __init__(self, *args, **kwargs):
6065
super(SubmoduleConfigParser, self).__init__(*args, **kwargs)
6166

6267
#{ Interface
63-
def set_submodule(self, submodule):
68+
def set_submodule(self, submodule: 'Submodule') -> None:
6469
"""Set this instance's submodule. It must be called before
6570
the first write operation begins"""
6671
self._smref = weakref.ref(submodule)
6772

68-
def flush_to_index(self):
73+
def flush_to_index(self) -> None:
6974
"""Flush changes in our configuration file to the index"""
7075
assert self._smref is not None
7176
# should always have a file here

‎git/objects/tree.py

+50-41
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,27 @@
2020

2121
# typing -------------------------------------------------
2222

23-
from typing import Iterable, Iterator, Tuple, Union, cast, TYPE_CHECKING
23+
from typing import Callable, Dict, Generic, Iterable, Iterator, List, Tuple, Type, TypeVar, Union, cast, TYPE_CHECKING
24+
25+
from git.types import PathLike
2426

2527
if TYPE_CHECKING:
28+
from git.repo import Repo
2629
from io import BytesIO
2730

2831
#--------------------------------------------------------
2932

3033

31-
cmp = lambda a, b: (a > b) - (a < b)
34+
cmp: Callable[[str, str], int] = lambda a, b: (a > b) - (a < b)
3235

3336
__all__ = ("TreeModifier", "Tree")
3437

38+
T_Tree_cache = TypeVar('T_Tree_cache', bound=Union[Tuple[bytes, int, str]])
39+
3540

36-
def git_cmp(t1, t2):
41+
def git_cmp(t1: T_Tree_cache, t2: T_Tree_cache) -> int:
3742
a, b = t1[2], t2[2]
43+
assert isinstance(a, str) and isinstance(b, str) # Need as mypy 9.0 cannot unpack TypeVar properly
3844
len_a, len_b = len(a), len(b)
3945
min_len = min(len_a, len_b)
4046
min_cmp = cmp(a[:min_len], b[:min_len])
@@ -45,9 +51,10 @@ def git_cmp(t1, t2):
4551
return len_a - len_b
4652

4753

48-
def merge_sort(a, cmp):
54+
def merge_sort(a: List[T_Tree_cache],
55+
cmp: Callable[[T_Tree_cache, T_Tree_cache], int]) -> None:
4956
if len(a) < 2:
50-
return
57+
return None
5158

5259
mid = len(a) // 2
5360
lefthalf = a[:mid]
@@ -80,18 +87,18 @@ def merge_sort(a, cmp):
8087
k = k + 1
8188

8289

83-
class TreeModifier(object):
90+
class TreeModifier(Generic[T_Tree_cache], object):
8491

8592
"""A utility class providing methods to alter the underlying cache in a list-like fashion.
8693
8794
Once all adjustments are complete, the _cache, which really is a reference to
8895
the cache of a tree, will be sorted. Assuring it will be in a serializable state"""
8996
__slots__ = '_cache'
9097

91-
def __init__(self, cache):
98+
def __init__(self, cache: List[T_Tree_cache]) -> None:
9299
self._cache = cache
93100

94-
def _index_by_name(self, name):
101+
def _index_by_name(self, name: str) -> int:
95102
""":return: index of an item with name, or -1 if not found"""
96103
for i, t in enumerate(self._cache):
97104
if t[2] == name:
@@ -101,7 +108,7 @@ def _index_by_name(self, name):
101108
return -1
102109

103110
#{ Interface
104-
def set_done(self):
111+
def set_done(self) -> 'TreeModifier':
105112
"""Call this method once you are done modifying the tree information.
106113
It may be called several times, but be aware that each call will cause
107114
a sort operation
@@ -111,7 +118,7 @@ def set_done(self):
111118
#} END interface
112119

113120
#{ Mutators
114-
def add(self, sha, mode, name, force=False):
121+
def add(self, sha: bytes, mode: int, name: str, force: bool = False) -> 'TreeModifier':
115122
"""Add the given item to the tree. If an item with the given name already
116123
exists, nothing will be done, but a ValueError will be raised if the
117124
sha and mode of the existing item do not match the one you add, unless
@@ -129,7 +136,9 @@ def add(self, sha, mode, name, force=False):
129136

130137
sha = to_bin_sha(sha)
131138
index = self._index_by_name(name)
132-
item = (sha, mode, name)
139+
140+
assert isinstance(sha, bytes) and isinstance(mode, int) and isinstance(name, str)
141+
item = cast(T_Tree_cache, (sha, mode, name)) # use Typeguard from typing-extensions 3.10.0
133142
if index == -1:
134143
self._cache.append(item)
135144
else:
@@ -144,14 +153,17 @@ def add(self, sha, mode, name, force=False):
144153
# END handle name exists
145154
return self
146155

147-
def add_unchecked(self, binsha, mode, name):
156+
def add_unchecked(self, binsha: bytes, mode: int, name: str) -> None:
148157
"""Add the given item to the tree, its correctness is assumed, which
149158
puts the caller into responsibility to assure the input is correct.
150159
For more information on the parameters, see ``add``
151160
:param binsha: 20 byte binary sha"""
152-
self._cache.append((binsha, mode, name))
161+
assert isinstance(binsha, bytes) and isinstance(mode, int) and isinstance(name, str)
162+
tree_cache = cast(T_Tree_cache, (binsha, mode, name))
153163

154-
def __delitem__(self, name):
164+
self._cache.append(tree_cache)
165+
166+
def __delitem__(self, name: str) -> None:
155167
"""Deletes an item with the given name if it exists"""
156168
index = self._index_by_name(name)
157169
if index > -1:
@@ -182,29 +194,29 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
182194
symlink_id = 0o12
183195
tree_id = 0o04
184196

185-
_map_id_to_type = {
197+
_map_id_to_type: Dict[int, Union[Type[Submodule], Type[Blob], Type['Tree']]] = {
186198
commit_id: Submodule,
187199
blob_id: Blob,
188200
symlink_id: Blob
189201
# tree id added once Tree is defined
190202
}
191203

192-
def __init__(self, repo, binsha, mode=tree_id << 12, path=None):
204+
def __init__(self, repo: 'Repo', binsha: bytes, mode: int = tree_id << 12, path: Union[PathLike, None] = None):
193205
super(Tree, self).__init__(repo, binsha, mode, path)
194206

195-
@classmethod
207+
@ classmethod
196208
def _get_intermediate_items(cls, index_object: 'Tree', # type: ignore
197-
) -> Tuple['Tree', ...]:
209+
) -> Union[Tuple['Tree', ...], Tuple[()]]:
198210
if index_object.type == "tree":
199211
index_object = cast('Tree', index_object)
200212
return tuple(index_object._iter_convert_to_object(index_object._cache))
201213
return ()
202214

203-
def _set_cache_(self, attr):
215+
def _set_cache_(self, attr: str) -> None:
204216
if attr == "_cache":
205217
# Set the data when we need it
206218
ostream = self.repo.odb.stream(self.binsha)
207-
self._cache = tree_entries_from_data(ostream.read())
219+
self._cache: List[Tuple[bytes, int, str]] = tree_entries_from_data(ostream.read())
208220
else:
209221
super(Tree, self)._set_cache_(attr)
210222
# END handle attribute
@@ -221,7 +233,7 @@ def _iter_convert_to_object(self, iterable: Iterable[Tuple[bytes, int, str]]
221233
raise TypeError("Unknown mode %o found in tree data for path '%s'" % (mode, path)) from e
222234
# END for each item
223235

224-
def join(self, file):
236+
def join(self, file: str) -> Union[Blob, 'Tree', Submodule]:
225237
"""Find the named object in this tree's contents
226238
:return: ``git.Blob`` or ``git.Tree`` or ``git.Submodule``
227239
@@ -254,26 +266,22 @@ def join(self, file):
254266
raise KeyError(msg % file)
255267
# END handle long paths
256268

257-
def __div__(self, file):
258-
"""For PY2 only"""
259-
return self.join(file)
260-
261-
def __truediv__(self, file):
269+
def __truediv__(self, file: str) -> Union['Tree', Blob, Submodule]:
262270
"""For PY3 only"""
263271
return self.join(file)
264272

265-
@property
266-
def trees(self):
273+
@ property
274+
def trees(self) -> List['Tree']:
267275
""":return: list(Tree, ...) list of trees directly below this tree"""
268276
return [i for i in self if i.type == "tree"]
269277

270-
@property
271-
def blobs(self):
278+
@ property
279+
def blobs(self) -> List['Blob']:
272280
""":return: list(Blob, ...) list of blobs directly below this tree"""
273281
return [i for i in self if i.type == "blob"]
274282

275-
@property
276-
def cache(self):
283+
@ property
284+
def cache(self) -> TreeModifier:
277285
"""
278286
:return: An object allowing to modify the internal cache. This can be used
279287
to change the tree's contents. When done, make sure you call ``set_done``
@@ -289,16 +297,16 @@ def traverse(self, predicate=lambda i, d: True,
289297
return super(Tree, self).traverse(predicate, prune, depth, branch_first, visit_once, ignore_self)
290298

291299
# List protocol
292-
def __getslice__(self, i, j):
300+
def __getslice__(self, i: int, j: int) -> List[Union[Blob, 'Tree', Submodule]]:
293301
return list(self._iter_convert_to_object(self._cache[i:j]))
294302

295-
def __iter__(self):
303+
def __iter__(self) -> Iterator[Union[Blob, 'Tree', Submodule]]:
296304
return self._iter_convert_to_object(self._cache)
297305

298-
def __len__(self):
306+
def __len__(self) -> int:
299307
return len(self._cache)
300308

301-
def __getitem__(self, item):
309+
def __getitem__(self, item: Union[str, int, slice]) -> Union[Blob, 'Tree', Submodule]:
302310
if isinstance(item, int):
303311
info = self._cache[item]
304312
return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join_path(self.path, info[2]))
@@ -310,7 +318,7 @@ def __getitem__(self, item):
310318

311319
raise TypeError("Invalid index type: %r" % item)
312320

313-
def __contains__(self, item):
321+
def __contains__(self, item: Union[IndexObject, PathLike]) -> bool:
314322
if isinstance(item, IndexObject):
315323
for info in self._cache:
316324
if item.binsha == info[0]:
@@ -321,10 +329,11 @@ def __contains__(self, item):
321329
# compatibility
322330

323331
# treat item as repo-relative path
324-
path = self.path
325-
for info in self._cache:
326-
if item == join_path(path, info[2]):
327-
return True
332+
else:
333+
path = self.path
334+
for info in self._cache:
335+
if item == join_path(path, info[2]):
336+
return True
328337
# END for each item
329338
return False
330339

‎git/objects/util.py

+6-25
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,18 @@
1919
from datetime import datetime, timedelta, tzinfo
2020

2121
# typing ------------------------------------------------------------
22-
from typing import (Any, Callable, Deque, Iterator, Sequence, TYPE_CHECKING, Tuple, Type, Union, cast, overload)
22+
from typing import (Any, Callable, Deque, Iterator, TypeVar, TYPE_CHECKING, Tuple, Type, Union, cast)
2323

2424
if TYPE_CHECKING:
2525
from io import BytesIO, StringIO
26-
from .submodule.base import Submodule
26+
from .submodule.base import Submodule # noqa: F401
2727
from .commit import Commit
2828
from .blob import Blob
2929
from .tag import TagObject
3030
from .tree import Tree
3131
from subprocess import Popen
32+
33+
T_Iterableobj = TypeVar('T_Iterableobj')
3234

3335
# --------------------------------------------------------------------
3436

@@ -284,29 +286,8 @@ class Traversable(object):
284286
"""
285287
__slots__ = ()
286288

287-
@overload
288289
@classmethod
289-
def _get_intermediate_items(cls, item: 'Commit') -> Tuple['Commit', ...]:
290-
...
291-
292-
@overload
293-
@classmethod
294-
def _get_intermediate_items(cls, item: 'Submodule') -> Tuple['Submodule', ...]:
295-
...
296-
297-
@overload
298-
@classmethod
299-
def _get_intermediate_items(cls, item: 'Tree') -> Tuple['Tree', ...]:
300-
...
301-
302-
@overload
303-
@classmethod
304-
def _get_intermediate_items(cls, item: 'Traversable') -> Tuple['Traversable', ...]:
305-
...
306-
307-
@classmethod
308-
def _get_intermediate_items(cls, item: 'Traversable'
309-
) -> Sequence['Traversable']:
290+
def _get_intermediate_items(cls, item):
310291
"""
311292
Returns:
312293
Tuple of items connected to the given item.
@@ -322,7 +303,7 @@ def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList:
322303
"""
323304
:return: IterableList with the results of the traversal as produced by
324305
traverse()"""
325-
out = IterableList(self._id_attribute_) # type: ignore[attr-defined] # defined in sublcasses
306+
out: IterableList = IterableList(self._id_attribute_) # type: ignore[attr-defined] # defined in sublcasses
326307
out.extend(self.traverse(*args, **kwargs))
327308
return out
328309

‎git/refs/log.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -82,23 +82,23 @@ def new(cls, oldhexsha, newhexsha, actor, time, tz_offset, message): # skipcq:
8282
return RefLogEntry((oldhexsha, newhexsha, actor, (time, tz_offset), message))
8383

8484
@classmethod
85-
def from_line(cls, line):
85+
def from_line(cls, line: bytes) -> 'RefLogEntry':
8686
""":return: New RefLogEntry instance from the given revlog line.
8787
:param line: line bytes without trailing newline
8888
:raise ValueError: If line could not be parsed"""
89-
line = line.decode(defenc)
90-
fields = line.split('\t', 1)
89+
line_str = line.decode(defenc)
90+
fields = line_str.split('\t', 1)
9191
if len(fields) == 1:
9292
info, msg = fields[0], None
9393
elif len(fields) == 2:
9494
info, msg = fields
9595
else:
9696
raise ValueError("Line must have up to two TAB-separated fields."
97-
" Got %s" % repr(line))
97+
" Got %s" % repr(line_str))
9898
# END handle first split
9999

100-
oldhexsha = info[:40] # type: str
101-
newhexsha = info[41:81] # type: str
100+
oldhexsha = info[:40]
101+
newhexsha = info[41:81]
102102
for hexsha in (oldhexsha, newhexsha):
103103
if not cls._re_hexsha_only.match(hexsha):
104104
raise ValueError("Invalid hexsha: %r" % (hexsha,))

‎git/refs/reference.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from git.util import (
22
LazyMixin,
3-
Iterable,
3+
IterableObj,
44
)
55
from .symbolic import SymbolicReference
66

@@ -23,7 +23,7 @@ def wrapper(self, *args):
2323
#}END utilities
2424

2525

26-
class Reference(SymbolicReference, LazyMixin, Iterable):
26+
class Reference(SymbolicReference, LazyMixin, IterableObj):
2727

2828
"""Represents a named reference to any object. Subclasses may apply restrictions though,
2929
i.e. Heads can only point to commits."""

‎git/remote.py

+33-23
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from git.exc import GitCommandError
1414
from git.util import (
1515
LazyMixin,
16-
Iterable,
16+
IterableObj,
1717
IterableList,
1818
RemoteProgress,
1919
CallableRemoteProgress,
@@ -36,9 +36,9 @@
3636

3737
# typing-------------------------------------------------------
3838

39-
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, TYPE_CHECKING, Union, cast, overload
39+
from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, TYPE_CHECKING, Union, overload
4040

41-
from git.types import PathLike, Literal, TBD
41+
from git.types import PathLike, Literal, TBD, TypeGuard
4242

4343
if TYPE_CHECKING:
4444
from git.repo.base import Repo
@@ -47,9 +47,16 @@
4747
from git.objects.tree import Tree
4848
from git.objects.tag import TagObject
4949

50-
flagKeyLiteral = Literal[' ', '!', '+', '-', '*', '=', 't']
50+
flagKeyLiteral = Literal[' ', '!', '+', '-', '*', '=', 't', '?']
51+
52+
53+
def is_flagKeyLiteral(inp: str) -> TypeGuard[flagKeyLiteral]:
54+
return inp in [' ', '!', '+', '-', '=', '*', 't', '?']
55+
56+
5157
# -------------------------------------------------------------
5258

59+
5360
log = logging.getLogger('git.remote')
5461
log.addHandler(logging.NullHandler())
5562

@@ -107,7 +114,7 @@ def to_progress_instance(progress: Union[Callable[..., Any], RemoteProgress, Non
107114
return progress
108115

109116

110-
class PushInfo(object):
117+
class PushInfo(IterableObj, object):
111118
"""
112119
Carries information about the result of a push operation of a single head::
113120
@@ -220,7 +227,7 @@ def _from_line(cls, remote: 'Remote', line: str) -> 'PushInfo':
220227
return PushInfo(flags, from_ref, to_ref_string, remote, old_commit, summary)
221228

222229

223-
class FetchInfo(object):
230+
class FetchInfo(IterableObj, object):
224231

225232
"""
226233
Carries information about the results of a fetch operation of a single head::
@@ -325,7 +332,7 @@ def _from_line(cls, repo: 'Repo', line: str, fetch_line: str) -> 'FetchInfo':
325332

326333
# parse lines
327334
control_character, operation, local_remote_ref, remote_local_ref_str, note = match.groups()
328-
control_character = cast(flagKeyLiteral, control_character) # can do this neater once 3.5 dropped
335+
assert is_flagKeyLiteral(control_character), f"{control_character}"
329336

330337
try:
331338
_new_hex_sha, _fetch_operation, fetch_note = fetch_line.split("\t")
@@ -421,7 +428,7 @@ def _from_line(cls, repo: 'Repo', line: str, fetch_line: str) -> 'FetchInfo':
421428
return cls(remote_local_ref, flags, note, old_commit, local_remote_ref)
422429

423430

424-
class Remote(LazyMixin, Iterable):
431+
class Remote(LazyMixin, IterableObj):
425432

426433
"""Provides easy read and write access to a git remote.
427434
@@ -552,8 +559,8 @@ def delete_url(self, url: str, **kwargs: Any) -> 'Remote':
552559
def urls(self) -> Iterator[str]:
553560
""":return: Iterator yielding all configured URL targets on a remote as strings"""
554561
try:
555-
# can replace cast with type assert?
556-
remote_details = cast(str, self.repo.git.remote("get-url", "--all", self.name))
562+
remote_details = self.repo.git.remote("get-url", "--all", self.name)
563+
assert isinstance(remote_details, str)
557564
for line in remote_details.split('\n'):
558565
yield line
559566
except GitCommandError as ex:
@@ -564,14 +571,16 @@ def urls(self) -> Iterator[str]:
564571
#
565572
if 'Unknown subcommand: get-url' in str(ex):
566573
try:
567-
remote_details = cast(str, self.repo.git.remote("show", self.name))
574+
remote_details = self.repo.git.remote("show", self.name)
575+
assert isinstance(remote_details, str)
568576
for line in remote_details.split('\n'):
569577
if ' Push URL:' in line:
570578
yield line.split(': ')[-1]
571579
except GitCommandError as _ex:
572580
if any(msg in str(_ex) for msg in ['correct access rights', 'cannot run ssh']):
573581
# If ssh is not setup to access this repository, see issue 694
574-
remote_details = cast(str, self.repo.git.config('--get-all', 'remote.%s.url' % self.name))
582+
remote_details = self.repo.git.config('--get-all', 'remote.%s.url' % self.name)
583+
assert isinstance(remote_details, str)
575584
for line in remote_details.split('\n'):
576585
yield line
577586
else:
@@ -580,18 +589,18 @@ def urls(self) -> Iterator[str]:
580589
raise ex
581590

582591
@property
583-
def refs(self) -> IterableList:
592+
def refs(self) -> IterableList[RemoteReference]:
584593
"""
585594
:return:
586595
IterableList of RemoteReference objects. It is prefixed, allowing
587596
you to omit the remote path portion, i.e.::
588597
remote.refs.master # yields RemoteReference('/refs/remotes/origin/master')"""
589-
out_refs = IterableList(RemoteReference._id_attribute_, "%s/" % self.name)
598+
out_refs: IterableList[RemoteReference] = IterableList(RemoteReference._id_attribute_, "%s/" % self.name)
590599
out_refs.extend(RemoteReference.list_items(self.repo, remote=self.name))
591600
return out_refs
592601

593602
@property
594-
def stale_refs(self) -> IterableList:
603+
def stale_refs(self) -> IterableList[Reference]:
595604
"""
596605
:return:
597606
IterableList RemoteReference objects that do not have a corresponding
@@ -606,7 +615,7 @@ def stale_refs(self) -> IterableList:
606615
as well. This is a fix for the issue described here:
607616
https://github.com/gitpython-developers/GitPython/issues/260
608617
"""
609-
out_refs = IterableList(RemoteReference._id_attribute_, "%s/" % self.name)
618+
out_refs: IterableList[RemoteReference] = IterableList(RemoteReference._id_attribute_, "%s/" % self.name)
610619
for line in self.repo.git.remote("prune", "--dry-run", self).splitlines()[2:]:
611620
# expecting
612621
# * [would prune] origin/new_branch
@@ -681,11 +690,12 @@ def update(self, **kwargs: Any) -> 'Remote':
681690
return self
682691

683692
def _get_fetch_info_from_stderr(self, proc: TBD,
684-
progress: Union[Callable[..., Any], RemoteProgress, None]) -> IterableList:
693+
progress: Union[Callable[..., Any], RemoteProgress, None]
694+
) -> IterableList['FetchInfo']:
685695
progress = to_progress_instance(progress)
686696

687697
# skip first line as it is some remote info we are not interested in
688-
output = IterableList('name')
698+
output: IterableList['FetchInfo'] = IterableList('name')
689699

690700
# lines which are no progress are fetch info lines
691701
# this also waits for the command to finish
@@ -741,15 +751,15 @@ def _get_fetch_info_from_stderr(self, proc: TBD,
741751
return output
742752

743753
def _get_push_info(self, proc: TBD,
744-
progress: Union[Callable[..., Any], RemoteProgress, None]) -> IterableList:
754+
progress: Union[Callable[..., Any], RemoteProgress, None]) -> IterableList[PushInfo]:
745755
progress = to_progress_instance(progress)
746756

747757
# read progress information from stderr
748758
# we hope stdout can hold all the data, it should ...
749759
# read the lines manually as it will use carriage returns between the messages
750760
# to override the previous one. This is why we read the bytes manually
751761
progress_handler = progress.new_message_handler()
752-
output = IterableList('push_infos')
762+
output: IterableList[PushInfo] = IterableList('push_infos')
753763

754764
def stdout_handler(line: str) -> None:
755765
try:
@@ -785,7 +795,7 @@ def _assert_refspec(self) -> None:
785795

786796
def fetch(self, refspec: Union[str, List[str], None] = None,
787797
progress: Union[Callable[..., Any], None] = None,
788-
verbose: bool = True, **kwargs: Any) -> IterableList:
798+
verbose: bool = True, **kwargs: Any) -> IterableList[FetchInfo]:
789799
"""Fetch the latest changes for this remote
790800
791801
:param refspec:
@@ -832,7 +842,7 @@ def fetch(self, refspec: Union[str, List[str], None] = None,
832842

833843
def pull(self, refspec: Union[str, List[str], None] = None,
834844
progress: Union[Callable[..., Any], None] = None,
835-
**kwargs: Any) -> IterableList:
845+
**kwargs: Any) -> IterableList[FetchInfo]:
836846
"""Pull changes from the given branch, being the same as a fetch followed
837847
by a merge of branch with your local branch.
838848
@@ -853,7 +863,7 @@ def pull(self, refspec: Union[str, List[str], None] = None,
853863

854864
def push(self, refspec: Union[str, List[str], None] = None,
855865
progress: Union[Callable[..., Any], None] = None,
856-
**kwargs: Any) -> IterableList:
866+
**kwargs: Any) -> IterableList[PushInfo]:
857867
"""Push changes from source branch in refspec to target branch in refspec.
858868
859869
:param refspec: see 'fetch' method

‎git/repo/base.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import re
99
import warnings
10+
from gitdb.db.loose import LooseObjectDB
1011

1112
from gitdb.exc import BadObject
1213

@@ -100,7 +101,7 @@ class Repo(object):
100101
# Subclasses may easily bring in their own custom types by placing a constructor or type here
101102
GitCommandWrapperType = Git
102103

103-
def __init__(self, path: Optional[PathLike] = None, odbt: Type[GitCmdObjectDB] = GitCmdObjectDB,
104+
def __init__(self, path: Optional[PathLike] = None, odbt: Type[LooseObjectDB] = GitCmdObjectDB,
104105
search_parent_directories: bool = False, expand_vars: bool = True) -> None:
105106
"""Create a new Repo instance
106107
@@ -308,15 +309,15 @@ def bare(self) -> bool:
308309
return self._bare
309310

310311
@property
311-
def heads(self) -> 'IterableList':
312+
def heads(self) -> 'IterableList[Head]':
312313
"""A list of ``Head`` objects representing the branch heads in
313314
this repo
314315
315316
:return: ``git.IterableList(Head, ...)``"""
316317
return Head.list_items(self)
317318

318319
@property
319-
def references(self) -> 'IterableList':
320+
def references(self) -> 'IterableList[Reference]':
320321
"""A list of Reference objects representing tags, heads and remote references.
321322
322323
:return: IterableList(Reference, ...)"""
@@ -341,7 +342,7 @@ def head(self) -> 'HEAD':
341342
return HEAD(self, 'HEAD')
342343

343344
@property
344-
def remotes(self) -> 'IterableList':
345+
def remotes(self) -> 'IterableList[Remote]':
345346
"""A list of Remote objects allowing to access and manipulate remotes
346347
:return: ``git.IterableList(Remote, ...)``"""
347348
return Remote.list_items(self)
@@ -357,13 +358,13 @@ def remote(self, name: str = 'origin') -> 'Remote':
357358
#{ Submodules
358359

359360
@property
360-
def submodules(self) -> 'IterableList':
361+
def submodules(self) -> 'IterableList[Submodule]':
361362
"""
362363
:return: git.IterableList(Submodule, ...) of direct submodules
363364
available from the current head"""
364365
return Submodule.list_items(self)
365366

366-
def submodule(self, name: str) -> 'IterableList':
367+
def submodule(self, name: str) -> 'Submodule':
367368
""" :return: Submodule with the given name
368369
:raise ValueError: If no such submodule exists"""
369370
try:
@@ -395,7 +396,7 @@ def submodule_update(self, *args: Any, **kwargs: Any) -> Iterator:
395396
#}END submodules
396397

397398
@property
398-
def tags(self) -> 'IterableList':
399+
def tags(self) -> 'IterableList[TagReference]':
399400
"""A list of ``Tag`` objects that are available in this repo
400401
:return: ``git.IterableList(TagReference, ...)`` """
401402
return TagReference.list_items(self)

‎git/types.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44

55
import os
66
import sys
7-
from typing import Union, Any
7+
from typing import Dict, Union, Any
88

99
if sys.version_info[:2] >= (3, 8):
10-
from typing import Final, Literal, SupportsIndex # noqa: F401
10+
from typing import Final, Literal, SupportsIndex, TypedDict # noqa: F401
1111
else:
12-
from typing_extensions import Final, Literal, SupportsIndex # noqa: F401
12+
from typing_extensions import Final, Literal, SupportsIndex, TypedDict # noqa: F401
13+
14+
if sys.version_info[:2] >= (3, 10):
15+
from typing import TypeGuard # noqa: F401
16+
else:
17+
from typing_extensions import TypeGuard # noqa: F401
1318

1419

1520
if sys.version_info[:2] < (3, 9):
@@ -22,3 +27,21 @@
2227
TBD = Any
2328

2429
Lit_config_levels = Literal['system', 'global', 'user', 'repository']
30+
31+
32+
class Files_TD(TypedDict):
33+
insertions: int
34+
deletions: int
35+
lines: int
36+
37+
38+
class Total_TD(TypedDict):
39+
insertions: int
40+
deletions: int
41+
lines: int
42+
files: int
43+
44+
45+
class HSH_TD(TypedDict):
46+
total: Total_TD
47+
files: Dict[PathLike, Files_TD]

‎git/util.py

+77-19
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,21 @@
1818
import time
1919
from unittest import SkipTest
2020
from urllib.parse import urlsplit, urlunsplit
21+
import warnings
2122

2223
# typing ---------------------------------------------------------
2324

2425
from typing import (Any, AnyStr, BinaryIO, Callable, Dict, Generator, IO, Iterator, List,
25-
Optional, Pattern, Sequence, Tuple, Union, cast, TYPE_CHECKING, overload)
26+
Optional, Pattern, Sequence, Tuple, TypeVar, Union, cast, TYPE_CHECKING, overload)
2627

2728
import pathlib
2829

2930
if TYPE_CHECKING:
3031
from git.remote import Remote
3132
from git.repo.base import Repo
32-
from .types import PathLike, TBD, Literal, SupportsIndex
33+
from git.config import GitConfigParser, SectionConstraint
34+
35+
from .types import PathLike, Literal, SupportsIndex, HSH_TD, Files_TD
3336

3437
# ---------------------------------------------------------------------
3538

@@ -81,7 +84,7 @@ def unbare_repo(func: Callable) -> Callable:
8184
encounter a bare repository"""
8285

8386
@wraps(func)
84-
def wrapper(self: 'Remote', *args: Any, **kwargs: Any) -> TBD:
87+
def wrapper(self: 'Remote', *args: Any, **kwargs: Any) -> Callable:
8588
if self.repo.bare:
8689
raise InvalidGitRepositoryError("Method '%s' cannot operate on bare repositories" % func.__name__)
8790
# END bare method
@@ -107,7 +110,7 @@ def rmtree(path: PathLike) -> None:
107110
:note: we use shutil rmtree but adjust its behaviour to see whether files that
108111
couldn't be deleted are read-only. Windows will not remove them in that case"""
109112

110-
def onerror(func: Callable, path: PathLike, exc_info: TBD) -> None:
113+
def onerror(func: Callable, path: PathLike, exc_info: str) -> None:
111114
# Is the error an access error ?
112115
os.chmod(path, stat.S_IWUSR)
113116

@@ -447,7 +450,7 @@ class RemoteProgress(object):
447450
re_op_relative = re.compile(r"(remote: )?([\w\s]+):\s+(\d+)% \((\d+)/(\d+)\)(.*)")
448451

449452
def __init__(self) -> None:
450-
self._seen_ops = [] # type: List[TBD]
453+
self._seen_ops = [] # type: List[int]
451454
self._cur_line = None # type: Optional[str]
452455
self.error_lines = [] # type: List[str]
453456
self.other_lines = [] # type: List[str]
@@ -668,7 +671,8 @@ def _from_string(cls, string: str) -> 'Actor':
668671
# END handle name/email matching
669672

670673
@classmethod
671-
def _main_actor(cls, env_name: str, env_email: str, config_reader: Optional[TBD] = None) -> 'Actor':
674+
def _main_actor(cls, env_name: str, env_email: str,
675+
config_reader: Union[None, 'GitConfigParser', 'SectionConstraint'] = None) -> 'Actor':
672676
actor = Actor('', '')
673677
user_id = None # We use this to avoid multiple calls to getpass.getuser()
674678

@@ -697,7 +701,7 @@ def default_name() -> str:
697701
return actor
698702

699703
@classmethod
700-
def committer(cls, config_reader: Optional[TBD] = None) -> 'Actor':
704+
def committer(cls, config_reader: Union[None, 'GitConfigParser', 'SectionConstraint'] = None) -> 'Actor':
701705
"""
702706
:return: Actor instance corresponding to the configured committer. It behaves
703707
similar to the git implementation, such that the environment will override
@@ -708,7 +712,7 @@ def committer(cls, config_reader: Optional[TBD] = None) -> 'Actor':
708712
return cls._main_actor(cls.env_committer_name, cls.env_committer_email, config_reader)
709713

710714
@classmethod
711-
def author(cls, config_reader: Optional[TBD] = None) -> 'Actor':
715+
def author(cls, config_reader: Union[None, 'GitConfigParser', 'SectionConstraint'] = None) -> 'Actor':
712716
"""Same as committer(), but defines the main author. It may be specified in the environment,
713717
but defaults to the committer"""
714718
return cls._main_actor(cls.env_author_name, cls.env_author_email, config_reader)
@@ -742,7 +746,9 @@ class Stats(object):
742746
files = number of changed files as int"""
743747
__slots__ = ("total", "files")
744748

745-
def __init__(self, total: Dict[str, Dict[str, int]], files: Dict[str, Dict[str, int]]):
749+
from git.types import Total_TD, Files_TD
750+
751+
def __init__(self, total: Total_TD, files: Dict[PathLike, Files_TD]):
746752
self.total = total
747753
self.files = files
748754

@@ -751,9 +757,13 @@ def _list_from_string(cls, repo: 'Repo', text: str) -> 'Stats':
751757
"""Create a Stat object from output retrieved by git-diff.
752758
753759
:return: git.Stat"""
754-
hsh = {'total': {'insertions': 0, 'deletions': 0, 'lines': 0, 'files': 0},
755-
'files': {}
756-
} # type: Dict[str, Dict[str, TBD]] ## need typeddict or refactor for mypy
760+
761+
hsh: HSH_TD = {'total': {'insertions': 0,
762+
'deletions': 0,
763+
'lines': 0,
764+
'files': 0},
765+
'files': {}
766+
}
757767
for line in text.splitlines():
758768
(raw_insertions, raw_deletions, filename) = line.split("\t")
759769
insertions = raw_insertions != '-' and int(raw_insertions) or 0
@@ -762,9 +772,10 @@ def _list_from_string(cls, repo: 'Repo', text: str) -> 'Stats':
762772
hsh['total']['deletions'] += deletions
763773
hsh['total']['lines'] += insertions + deletions
764774
hsh['total']['files'] += 1
765-
hsh['files'][filename.strip()] = {'insertions': insertions,
766-
'deletions': deletions,
767-
'lines': insertions + deletions}
775+
files_dict: Files_TD = {'insertions': insertions,
776+
'deletions': deletions,
777+
'lines': insertions + deletions}
778+
hsh['files'][filename.strip()] = files_dict
768779
return Stats(hsh['total'], hsh['files'])
769780

770781

@@ -920,7 +931,10 @@ def _obtain_lock(self) -> None:
920931
# END endless loop
921932

922933

923-
class IterableList(list):
934+
T = TypeVar('T', bound='IterableObj')
935+
936+
937+
class IterableList(List[T]):
924938

925939
"""
926940
List of iterable objects allowing to query an object by id or by named index::
@@ -930,6 +944,9 @@ class IterableList(list):
930944
heads['master']
931945
heads[0]
932946
947+
Iterable parent objects = [Commit, SubModule, Reference, FetchInfo, PushInfo]
948+
Iterable via inheritance = [Head, TagReference, RemoteReference]
949+
]
933950
It requires an id_attribute name to be set which will be queried from its
934951
contained items to have a means for comparison.
935952
@@ -938,7 +955,7 @@ class IterableList(list):
938955
can be left out."""
939956
__slots__ = ('_id_attr', '_prefix')
940957

941-
def __new__(cls, id_attr: str, prefix: str = '') -> 'IterableList':
958+
def __new__(cls, id_attr: str, prefix: str = '') -> 'IterableList[IterableObj]':
942959
return super(IterableList, cls).__new__(cls)
943960

944961
def __init__(self, id_attr: str, prefix: str = '') -> None:
@@ -1007,16 +1024,29 @@ def __delitem__(self, index: Union[SupportsIndex, int, slice, str]) -> Any:
10071024
list.__delitem__(self, delindex)
10081025

10091026

1027+
class IterableClassWatcher(type):
1028+
def __init__(cls, name, bases, clsdict):
1029+
for base in bases:
1030+
if type(base) == IterableClassWatcher:
1031+
warnings.warn(f"GitPython Iterable subclassed by {name}. "
1032+
"Iterable is deprecated due to naming clash, "
1033+
"Use IterableObj instead \n",
1034+
DeprecationWarning,
1035+
stacklevel=2)
1036+
1037+
10101038
class Iterable(object):
10111039

10121040
"""Defines an interface for iterable items which is to assure a uniform
10131041
way to retrieve and iterate items within the git repository"""
10141042
__slots__ = ()
10151043
_id_attribute_ = "attribute that most suitably identifies your instance"
1044+
__metaclass__ = IterableClassWatcher
10161045

10171046
@classmethod
1018-
def list_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> 'IterableList':
1047+
def list_items(cls, repo, *args, **kwargs):
10191048
"""
1049+
Deprecaated, use IterableObj instead.
10201050
Find all items of this type - subclasses can specify args and kwargs differently.
10211051
If no args are given, subclasses are obliged to return all items if no additional
10221052
arguments arg given.
@@ -1029,7 +1059,35 @@ def list_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> 'IterableList':
10291059
return out_list
10301060

10311061
@classmethod
1032-
def iter_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> Iterator[TBD]:
1062+
def iter_items(cls, repo: 'Repo', *args: Any, **kwargs: Any):
1063+
# return typed to be compatible with subtypes e.g. Remote
1064+
"""For more information about the arguments, see list_items
1065+
:return: iterator yielding Items"""
1066+
raise NotImplementedError("To be implemented by Subclass")
1067+
1068+
1069+
class IterableObj():
1070+
"""Defines an interface for iterable items which is to assure a uniform
1071+
way to retrieve and iterate items within the git repository"""
1072+
__slots__ = ()
1073+
_id_attribute_ = "attribute that most suitably identifies your instance"
1074+
1075+
@classmethod
1076+
def list_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> IterableList[T]:
1077+
"""
1078+
Find all items of this type - subclasses can specify args and kwargs differently.
1079+
If no args are given, subclasses are obliged to return all items if no additional
1080+
arguments arg given.
1081+
1082+
:note: Favor the iter_items method as it will
1083+
1084+
:return:list(Item,...) list of item instances"""
1085+
out_list: IterableList = IterableList(cls._id_attribute_)
1086+
out_list.extend(cls.iter_items(repo, *args, **kwargs))
1087+
return out_list
1088+
1089+
@classmethod
1090+
def iter_items(cls, repo: 'Repo', *args: Any, **kwargs: Any) -> Iterator[T]:
10331091
# return typed to be compatible with subtypes e.g. Remote
10341092
"""For more information about the arguments, see list_items
10351093
:return: iterator yielding Items"""

‎requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
gitdb>=4.0.1,<5
2-
typing-extensions>=3.7.4.3;python_version<"3.8"
2+
typing-extensions>=3.7.4.3;python_version<"3.10"

‎test-requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ tox
55
virtualenv
66
nose
77
gitdb>=4.0.1,<5
8-
typing-extensions>=3.7.4.3;python_version<"3.8"
8+
typing-extensions>=3.7.4.3;python_version<"3.10"

0 commit comments

Comments
 (0)
Please sign in to comment.