diff --git a/git/index/base.py b/git/index/base.py index 1c56a219b..10f8b8b25 100644 --- a/git/index/base.py +++ b/git/index/base.py @@ -52,6 +52,7 @@ from .typ import ( BaseIndexEntry, IndexEntry, + StageType, ) from .util import TemporaryFileSwap, post_clear_cache, default_index, git_working_dir @@ -83,13 +84,12 @@ from git.util import Actor -StageType = int Treeish = Union[Tree, Commit, str, bytes] # ------------------------------------------------------------------------------------ -__all__ = ("IndexFile", "CheckoutError") +__all__ = ("IndexFile", "CheckoutError", "StageType") class IndexFile(LazyMixin, git_diff.Diffable, Serializable): diff --git a/git/index/typ.py b/git/index/typ.py index 6371953bb..b2c6c371b 100644 --- a/git/index/typ.py +++ b/git/index/typ.py @@ -1,6 +1,7 @@ """Module with additional types used by the index""" from binascii import b2a_hex +from pathlib import Path from .util import pack, unpack from git.objects import Blob @@ -8,16 +9,18 @@ # typing ---------------------------------------------------------------------- -from typing import NamedTuple, Sequence, TYPE_CHECKING, Tuple, Union, cast +from typing import NamedTuple, Sequence, TYPE_CHECKING, Tuple, Union, cast, List from git.types import PathLike if TYPE_CHECKING: from git.repo import Repo +StageType = int + # --------------------------------------------------------------------------------- -__all__ = ("BlobFilter", "BaseIndexEntry", "IndexEntry") +__all__ = ("BlobFilter", "BaseIndexEntry", "IndexEntry", "StageType") # { Invariants CE_NAMEMASK = 0x0FFF @@ -48,12 +51,18 @@ def __init__(self, paths: Sequence[PathLike]) -> None: """ self.paths = paths - def __call__(self, stage_blob: Blob) -> bool: - path = stage_blob[1].path - for p in self.paths: - if path.startswith(p): + def __call__(self, stage_blob: Tuple[StageType, Blob]) -> bool: + blob_pathlike: PathLike = stage_blob[1].path + blob_path: Path = blob_pathlike if isinstance(blob_pathlike, Path) else Path(blob_pathlike) + for pathlike in self.paths: + path: Path = pathlike if isinstance(pathlike, Path) else Path(pathlike) + # TODO: Change to use `PosixPath.is_relative_to` once Python 3.8 is no longer supported. + filter_parts: List[str] = path.parts + blob_parts: List[str] = blob_path.parts + if len(filter_parts) > len(blob_parts): + continue + if all(i == j for i, j in zip(filter_parts, blob_parts)): return True - # END for each path in filter paths return False diff --git a/test/test_blob_filter.py b/test/test_blob_filter.py new file mode 100644 index 000000000..cbaa30b8b --- /dev/null +++ b/test/test_blob_filter.py @@ -0,0 +1,32 @@ +"""Test the blob filter.""" +from pathlib import Path +from typing import Sequence, Tuple +from unittest.mock import MagicMock + +import pytest + +from git.index.typ import BlobFilter, StageType +from git.objects import Blob +from git.types import PathLike + + +# fmt: off +@pytest.mark.parametrize('paths, path, expected_result', [ + ((Path("foo"),), Path("foo"), True), + ((Path("foo"),), Path("foo/bar"), True), + ((Path("foo/bar"),), Path("foo"), False), + ((Path("foo"), Path("bar")), Path("foo"), True), +]) +# fmt: on +def test_blob_filter(paths: Sequence[PathLike], path: PathLike, expected_result: bool) -> None: + """Test the blob filter.""" + blob_filter = BlobFilter(paths) + + binsha = MagicMock(__len__=lambda self: 20) + stage_type: StageType = 0 + blob: Blob = Blob(repo=MagicMock(), binsha=binsha, path=path) + stage_blob: Tuple[StageType, Blob] = (stage_type, blob) + + result = blob_filter(stage_blob) + + assert result == expected_result