From 46d3d0520d8877b1979e6385bada9d9e1e0731ec Mon Sep 17 00:00:00 2001
From: Facundo Tuesca <facundo.tuesca@trailofbits.com>
Date: Thu, 21 Sep 2023 12:00:59 +0200
Subject: [PATCH] Add more checks for the validity of refnames

This change adds checks based on the rules described in [0] in
order to more robustly check a refname's validity.

[0]: https://git-scm.com/docs/git-check-ref-format
---
 git/refs/symbolic.py | 50 ++++++++++++++++++++++++++++++++++++++++++--
 test/test_refs.py    | 36 +++++++++++++++++++++++++++++++
 2 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/git/refs/symbolic.py b/git/refs/symbolic.py
index 5c293aa7b..819615103 100644
--- a/git/refs/symbolic.py
+++ b/git/refs/symbolic.py
@@ -161,6 +161,51 @@ def dereference_recursive(cls, repo: "Repo", ref_path: Union[PathLike, None]) ->
                 return hexsha
         # END recursive dereferencing
 
+    @staticmethod
+    def _check_ref_name_valid(ref_path: PathLike) -> None:
+        # Based on the rules described in https://git-scm.com/docs/git-check-ref-format/#_description
+        previous: Union[str, None] = None
+        one_before_previous: Union[str, None] = None
+        for c in str(ref_path):
+            if c in " ~^:?*[\\":
+                raise ValueError(
+                    f"Invalid reference '{ref_path}': references cannot contain spaces, tildes (~), carets (^),"
+                    f" colons (:), question marks (?), asterisks (*), open brackets ([) or backslashes (\\)"
+                )
+            elif c == ".":
+                if previous is None or previous == "/":
+                    raise ValueError(
+                        f"Invalid reference '{ref_path}': references cannot start with a period (.) or contain '/.'"
+                    )
+                elif previous == ".":
+                    raise ValueError(f"Invalid reference '{ref_path}': references cannot contain '..'")
+            elif c == "/":
+                if previous == "/":
+                    raise ValueError(f"Invalid reference '{ref_path}': references cannot contain '//'")
+                elif previous is None:
+                    raise ValueError(
+                        f"Invalid reference '{ref_path}': references cannot start with forward slashes '/'"
+                    )
+            elif c == "{" and previous == "@":
+                raise ValueError(f"Invalid reference '{ref_path}': references cannot contain '@{{'")
+            elif ord(c) < 32 or ord(c) == 127:
+                raise ValueError(f"Invalid reference '{ref_path}': references cannot contain ASCII control characters")
+
+            one_before_previous = previous
+            previous = c
+
+        if previous == ".":
+            raise ValueError(f"Invalid reference '{ref_path}': references cannot end with a period (.)")
+        elif previous == "/":
+            raise ValueError(f"Invalid reference '{ref_path}': references cannot end with a forward slash (/)")
+        elif previous == "@" and one_before_previous is None:
+            raise ValueError(f"Invalid reference '{ref_path}': references cannot be '@'")
+        elif any([component.endswith(".lock") for component in str(ref_path).split("/")]):
+            raise ValueError(
+                f"Invalid reference '{ref_path}': references cannot have slash-separated components that end with"
+                f" '.lock'"
+            )
+
     @classmethod
     def _get_ref_info_helper(
         cls, repo: "Repo", ref_path: Union[PathLike, None]
@@ -168,8 +213,9 @@ def _get_ref_info_helper(
         """Return: (str(sha), str(target_ref_path)) if available, the sha the file at
         rela_path points to, or None. target_ref_path is the reference we
         point to, or None"""
-        if ".." in str(ref_path):
-            raise ValueError(f"Invalid reference '{ref_path}'")
+        if ref_path:
+            cls._check_ref_name_valid(ref_path)
+
         tokens: Union[None, List[str], Tuple[str, str]] = None
         repodir = _git_dir(repo, ref_path)
         try:
diff --git a/test/test_refs.py b/test/test_refs.py
index afd273df9..80166f651 100644
--- a/test/test_refs.py
+++ b/test/test_refs.py
@@ -631,3 +631,39 @@ def test_refs_outside_repo(self):
             ref_file.flush()
             ref_file_name = Path(ref_file.name).name
             self.assertRaises(BadName, self.rorepo.commit, f"../../{ref_file_name}")
+
+    def test_validity_ref_names(self):
+        check_ref = SymbolicReference._check_ref_name_valid
+        # Based on the rules specified in https://git-scm.com/docs/git-check-ref-format/#_description
+        # Rule 1
+        self.assertRaises(ValueError, check_ref, ".ref/begins/with/dot")
+        self.assertRaises(ValueError, check_ref, "ref/component/.begins/with/dot")
+        self.assertRaises(ValueError, check_ref, "ref/ends/with/a.lock")
+        self.assertRaises(ValueError, check_ref, "ref/component/ends.lock/with/period_lock")
+        # Rule 2
+        check_ref("valid_one_level_refname")
+        # Rule 3
+        self.assertRaises(ValueError, check_ref, "ref/contains/../double/period")
+        # Rule 4
+        for c in " ~^:":
+            self.assertRaises(ValueError, check_ref, f"ref/contains/invalid{c}/character")
+        for code in range(0, 32):
+            self.assertRaises(ValueError, check_ref, f"ref/contains/invalid{chr(code)}/ASCII/control_character")
+        self.assertRaises(ValueError, check_ref, f"ref/contains/invalid{chr(127)}/ASCII/control_character")
+        # Rule 5
+        for c in "*?[":
+            self.assertRaises(ValueError, check_ref, f"ref/contains/invalid{c}/character")
+        # Rule 6
+        self.assertRaises(ValueError, check_ref, "/ref/begins/with/slash")
+        self.assertRaises(ValueError, check_ref, "ref/ends/with/slash/")
+        self.assertRaises(ValueError, check_ref, "ref/contains//double/slash/")
+        # Rule 7
+        self.assertRaises(ValueError, check_ref, "ref/ends/with/dot.")
+        # Rule 8
+        self.assertRaises(ValueError, check_ref, "ref/contains@{/at_brace")
+        # Rule 9
+        self.assertRaises(ValueError, check_ref, "@")
+        # Rule 10
+        self.assertRaises(ValueError, check_ref, "ref/contain\\s/backslash")
+        # Valid reference name should not raise
+        check_ref("valid/ref/name")