Skip to content

Commit 5402a16

Browse files
committed
Add types to objects _get_intermediate_items()
1 parent c242b55 commit 5402a16

File tree

4 files changed

+68
-22
lines changed

4 files changed

+68
-22
lines changed

git/objects/commit.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This module is part of GitPython and is released under
55
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
66

7+
from typing import Tuple, Union
78
from gitdb import IStream
89
from git.util import (
910
hex_to_bin,
@@ -70,7 +71,8 @@ class Commit(base.Object, Iterable, Diffable, Traversable, Serializable):
7071

7172
def __init__(self, repo, binsha, tree=None, author=None, authored_date=None, author_tz_offset=None,
7273
committer=None, committed_date=None, committer_tz_offset=None,
73-
message=None, parents=None, encoding=None, gpgsig=None):
74+
message=None, parents: Union[Tuple['Commit', ...], None] = None,
75+
encoding=None, gpgsig=None):
7476
"""Instantiate a new Commit. All keyword arguments taking None as default will
7577
be implicitly set on first query.
7678
@@ -133,7 +135,7 @@ def __init__(self, repo, binsha, tree=None, author=None, authored_date=None, aut
133135
self.gpgsig = gpgsig
134136

135137
@classmethod
136-
def _get_intermediate_items(cls, commit):
138+
def _get_intermediate_items(cls, commit: 'Commit') -> Tuple['Commit', ...]: # type: ignore
137139
return commit.parents
138140

139141
@classmethod
@@ -477,17 +479,17 @@ def _deserialize(self, stream):
477479
readline = stream.readline
478480
self.tree = Tree(self.repo, hex_to_bin(readline().split()[1]), Tree.tree_id << 12, '')
479481

480-
self.parents = []
482+
self.parents_list = [] # List['Commit']
481483
next_line = None
482484
while True:
483485
parent_line = readline()
484486
if not parent_line.startswith(b'parent'):
485487
next_line = parent_line
486488
break
487489
# END abort reading parents
488-
self.parents.append(type(self)(self.repo, hex_to_bin(parent_line.split()[-1].decode('ascii'))))
490+
self.parents_list.append(type(self)(self.repo, hex_to_bin(parent_line.split()[-1].decode('ascii'))))
489491
# END for each parent line
490-
self.parents = tuple(self.parents)
492+
self.parents = tuple(self.parents_list) # type: Tuple['Commit', ...]
491493

492494
# we don't know actual author encoding before we have parsed it, so keep the lines around
493495
author_line = next_line

git/objects/submodule/base.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import os
55
import stat
6+
from typing import List
67
from unittest import SkipTest
78
import uuid
89

@@ -134,10 +135,11 @@ def _set_cache_(self, attr):
134135
super(Submodule, self)._set_cache_(attr)
135136
# END handle attribute name
136137

137-
def _get_intermediate_items(self, item):
138+
@classmethod
139+
def _get_intermediate_items(cls, item: 'Submodule') -> List['Submodule']: # type: ignore
138140
""":return: all the submodules of our module repository"""
139141
try:
140-
return type(self).list_items(item.module())
142+
return cls.list_items(item.module())
141143
except InvalidGitRepositoryError:
142144
return []
143145
# END handle intermediate items

git/objects/tree.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This module is part of GitPython and is released under
55
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
6+
from typing import Iterable, Iterator, Tuple, Union, cast
67
from git.util import join_path
78
import git.diff as diff
89
from git.util import to_bin_sha
@@ -182,8 +183,10 @@ def __init__(self, repo, binsha, mode=tree_id << 12, path=None):
182183
super(Tree, self).__init__(repo, binsha, mode, path)
183184

184185
@classmethod
185-
def _get_intermediate_items(cls, index_object):
186+
def _get_intermediate_items(cls, index_object: 'Tree', # type: ignore
187+
) -> Tuple['Tree', ...]:
186188
if index_object.type == "tree":
189+
index_object = cast('Tree', index_object)
187190
return tuple(index_object._iter_convert_to_object(index_object._cache))
188191
return ()
189192

@@ -196,7 +199,8 @@ def _set_cache_(self, attr):
196199
super(Tree, self)._set_cache_(attr)
197200
# END handle attribute
198201

199-
def _iter_convert_to_object(self, iterable):
202+
def _iter_convert_to_object(self, iterable: Iterable[Tuple[bytes, int, str]]
203+
) -> Iterator[Union[Blob, 'Tree', Submodule]]:
200204
"""Iterable yields tuples of (binsha, mode, name), which will be converted
201205
to the respective object representation"""
202206
for binsha, mode, name in iterable:

git/objects/util.py

+51-13
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This module is part of GitPython and is released under
55
# the BSD License: http://www.opensource.org/licenses/bsd-license.php
66
"""Module for general utility functions"""
7+
8+
79
from git.util import (
810
IterableList,
911
Actor
@@ -18,9 +20,10 @@
1820
from datetime import datetime, timedelta, tzinfo
1921

2022
# typing ------------------------------------------------------------
21-
from typing import Any, IO, TYPE_CHECKING, Tuple, Type, Union, cast
23+
from typing import Any, Callable, IO, Iterator, Sequence, TYPE_CHECKING, Tuple, Type, Union, cast, overload
2224

2325
if TYPE_CHECKING:
26+
from .submodule.base import Submodule
2427
from .commit import Commit
2528
from .blob import Blob
2629
from .tag import TagObject
@@ -115,7 +118,7 @@ def verify_utctz(offset: str) -> str:
115118

116119

117120
class tzoffset(tzinfo):
118-
121+
119122
def __init__(self, secs_west_of_utc: float, name: Union[None, str] = None) -> None:
120123
self._offset = timedelta(seconds=-secs_west_of_utc)
121124
self._name = name or 'fixed'
@@ -275,29 +278,61 @@ class Traversable(object):
275278
"""Simple interface to perform depth-first or breadth-first traversals
276279
into one direction.
277280
Subclasses only need to implement one function.
278-
Instances of the Subclass must be hashable"""
281+
Instances of the Subclass must be hashable
282+
283+
Defined subclasses = [Commit, Tree, SubModule]
284+
"""
279285
__slots__ = ()
280286

287+
@overload
288+
@classmethod
289+
def _get_intermediate_items(cls, item: 'Commit') -> Tuple['Commit', ...]:
290+
...
291+
292+
@overload
281293
@classmethod
282-
def _get_intermediate_items(cls, item):
294+
def _get_intermediate_items(cls, item: 'Submodule') -> Tuple['Submodule', ...]:
295+
...
296+
297+
@overload
298+
@classmethod
299+
def _get_intermediate_items(cls, item: 'Tree') -> Tuple['Tree', ...]:
300+
...
301+
302+
@overload
303+
@classmethod
304+
def _get_intermediate_items(cls, item: 'Traversable') -> Tuple['Traversable', ...]:
305+
...
306+
307+
@classmethod
308+
def _get_intermediate_items(cls, item: 'Traversable'
309+
) -> Sequence['Traversable']:
283310
"""
284311
Returns:
285-
List of items connected to the given item.
312+
Tuple of items connected to the given item.
286313
Must be implemented in subclass
314+
315+
class Commit:: (cls, Commit) -> Tuple[Commit, ...]
316+
class Submodule:: (cls, Submodule) -> Iterablelist[Submodule]
317+
class Tree:: (cls, Tree) -> Tuple[Tree, ...]
287318
"""
288319
raise NotImplementedError("To be implemented in subclass")
289320

290-
def list_traverse(self, *args, **kwargs):
321+
def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList:
291322
"""
292323
:return: IterableList with the results of the traversal as produced by
293324
traverse()"""
294-
out = IterableList(self._id_attribute_)
325+
out = IterableList(self._id_attribute_) # type: ignore[attr-defined] # defined in sublcasses
295326
out.extend(self.traverse(*args, **kwargs))
296327
return out
297328

298-
def traverse(self, predicate=lambda i, d: True,
299-
prune=lambda i, d: False, depth=-1, branch_first=True,
300-
visit_once=True, ignore_self=1, as_edge=False):
329+
def traverse(self,
330+
predicate: Callable[[object, int], bool] = lambda i, d: True,
331+
prune: Callable[[object, int], bool] = lambda i, d: False,
332+
depth: int = -1,
333+
branch_first: bool = True,
334+
visit_once: bool = True, ignore_self: int = 1, as_edge: bool = False
335+
) -> Union[Iterator['Traversable'], Iterator[Tuple['Traversable', 'Traversable']]]:
301336
""":return: iterator yielding of items found when traversing self
302337
303338
:param predicate: f(i,d) returns False if item i at depth d should not be included in the result
@@ -329,13 +364,16 @@ def traverse(self, predicate=lambda i, d: True,
329364
destination, i.e. tuple(src, dest) with the edge spanning from
330365
source to destination"""
331366
visited = set()
332-
stack = Deque()
367+
stack = Deque() # type: Deque[Tuple[int, Traversable, Union[Traversable, None]]]
333368
stack.append((0, self, None)) # self is always depth level 0
334369

335-
def addToStack(stack, item, branch_first, depth):
370+
def addToStack(stack: Deque[Tuple[int, 'Traversable', Union['Traversable', None]]],
371+
item: 'Traversable',
372+
branch_first: bool,
373+
depth) -> None:
336374
lst = self._get_intermediate_items(item)
337375
if not lst:
338-
return
376+
return None
339377
if branch_first:
340378
stack.extendleft((depth, i, item) for i in lst)
341379
else:

0 commit comments

Comments
 (0)