diff --git a/git/objects/commit.py b/git/objects/commit.py index 302798cb2..45e6d772c 100644 --- a/git/objects/commit.py +++ b/git/objects/commit.py @@ -136,6 +136,41 @@ def __init__(self, repo, binsha, tree=None, author=None, authored_date=None, aut def _get_intermediate_items(cls, commit): return commit.parents + @classmethod + def _calculate_sha_(cls, repo, commit): + '''Calculate the sha of a commit. + + :param repo: Repo object the commit should be part of + :param commit: Commit object for which to generate the sha + ''' + + stream = BytesIO() + commit._serialize(stream) + streamlen = stream.tell() + stream.seek(0) + + istream = repo.odb.store(IStream(cls.type, streamlen, stream)) + return istream.binsha + + def replace(self, **kwargs): + '''Create new commit object from existing commit object. + + Any values provided as keyword arguments will replace the + corresponding attribute in the new object. + ''' + + attrs = {k: getattr(self, k) for k in self.__slots__} + + for attrname in kwargs: + if attrname not in self.__slots__: + raise ValueError('invalid attribute name') + + attrs.update(kwargs) + new_commit = self.__class__(self.repo, self.NULL_BIN_SHA, **attrs) + new_commit.binsha = self._calculate_sha_(self.repo, new_commit) + + return new_commit + def _set_cache_(self, attr): if attr in Commit.__slots__: # read the data in a chunk, its faster - then provide a file wrapper @@ -375,13 +410,7 @@ def create_from_tree(cls, repo, tree, message, parent_commits=None, head=False, committer, committer_time, committer_offset, message, parent_commits, conf_encoding) - stream = BytesIO() - new_commit._serialize(stream) - streamlen = stream.tell() - stream.seek(0) - - istream = repo.odb.store(IStream(cls.type, streamlen, stream)) - new_commit.binsha = istream.binsha + new_commit.binsha = cls._calculate_sha_(repo, new_commit) if head: # need late import here, importing git at the very beginning throws diff --git a/test/test_commit.py b/test/test_commit.py index 7260a2337..2fe80530d 100644 --- a/test/test_commit.py +++ b/test/test_commit.py @@ -101,6 +101,26 @@ def test_bake(self): assert isinstance(commit.author_tz_offset, int) and isinstance(commit.committer_tz_offset, int) self.assertEqual(commit.message, "Added missing information to docstrings of commit and stats module\n") + def test_replace_no_changes(self): + old_commit = self.rorepo.commit('2454ae89983a4496a445ce347d7a41c0bb0ea7ae') + new_commit = old_commit.replace() + + for attr in old_commit.__slots__: + assert getattr(new_commit, attr) == getattr(old_commit, attr) + + def test_replace_new_sha(self): + commit = self.rorepo.commit('2454ae89983a4496a445ce347d7a41c0bb0ea7ae') + new_commit = commit.replace(message='Added replace method') + + assert new_commit.hexsha == 'fc84cbecac1bd4ba4deaac07c1044889edd536e6' + assert new_commit.message == 'Added replace method' + + def test_replace_invalid_attribute(self): + commit = self.rorepo.commit('2454ae89983a4496a445ce347d7a41c0bb0ea7ae') + + with self.assertRaises(ValueError): + commit.replace(badattr='This will never work') + def test_stats(self): commit = self.rorepo.commit('33ebe7acec14b25c5f84f35a664803fcab2f7781') stats = commit.stats