Skip to content

Commit f98cb3b

Browse files
committed
Refactor violation diffing into separate function
ghstack-source-id: b91b636f15102432b2fa447e1dd1453a60308842 Pull Request resolved: #399
1 parent 2efadd0 commit f98cb3b

File tree

4 files changed

+85
-19
lines changed

4 files changed

+85
-19
lines changed

Diff for: src/fixit/engine.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,26 @@
2828
LOG = logging.getLogger(__name__)
2929

3030

31+
def diff_violation(path: Path, module: Module, violation: LintViolation) -> str:
32+
"""
33+
Generate string diff representation of a violation.
34+
"""
35+
36+
orig = module.code
37+
mod = module.deep_replace( # type:ignore # LibCST#906
38+
violation.node, violation.replacement
39+
)
40+
assert isinstance(mod, Module)
41+
change = mod.code
42+
43+
return unified_diff(
44+
orig,
45+
change,
46+
path.name,
47+
n=1,
48+
)
49+
50+
3151
class LintRunner:
3252
def __init__(self, path: Path, source: FileContent) -> None:
3353
self.path = path
@@ -87,19 +107,7 @@ def visit_hook(name: str) -> Iterator[None]:
87107
count += 1
88108

89109
if violation.replacement:
90-
orig = self.module.code
91-
mod = self.module.deep_replace( # type:ignore # LibCST#906
92-
violation.node, violation.replacement
93-
)
94-
assert isinstance(mod, Module)
95-
change = mod.code
96-
97-
diff = unified_diff(
98-
orig,
99-
change,
100-
self.path.name,
101-
n=1,
102-
)
110+
diff = diff_violation(self.path, self.module, violation)
103111
violation = replace(violation, diff=diff)
104112

105113
yield violation

Diff for: src/fixit/testing.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
from pathlib import Path
1111
from typing import Any, Callable, Collection, Dict, List, Mapping, Sequence, Type, Union
1212

13-
from moreorless import unified_diff
14-
15-
from .engine import LintRunner
13+
from .engine import diff_violation, LintRunner
1614
from .ftypes import Config
1715
from .rule import Invalid, LintRule, Valid
1816

@@ -112,9 +110,7 @@ def _test_method(
112110

113111
if len(reports) == 1:
114112
# make sure we generated a reasonable diff
115-
expected_diff = unified_diff(
116-
source_code, expected_code, filename=path.name, n=1
117-
)
113+
expected_diff = diff_violation(path, runner.module, reports[0])
118114
self.assertEqual(expected_diff, report.diff)
119115

120116

Diff for: src/fixit/tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from fixit.testing import add_lint_rule_tests_to_module
1010
from .config import ConfigTest
11+
from .engine import EngineTest
1112
from .ftypes import TypesTest
1213
from .rule import RuleTest, RunnerTest
1314
from .smoke import SmokeTest

Diff for: src/fixit/tests/engine.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from pathlib import Path
7+
from textwrap import dedent
8+
from unittest import TestCase
9+
10+
from libcst import (
11+
Call,
12+
ensure_type,
13+
Expr,
14+
parse_module,
15+
SimpleStatementLine,
16+
SimpleString,
17+
)
18+
from libcst.metadata import CodePosition, CodeRange
19+
20+
from ..engine import diff_violation
21+
from ..ftypes import LintViolation
22+
23+
24+
class EngineTest(TestCase):
25+
def test_diff_violation(self):
26+
src = dedent(
27+
"""\
28+
import sys
29+
print("hello world")
30+
"""
31+
)
32+
path = Path("foo.py")
33+
module = parse_module(src)
34+
node = ensure_type(
35+
ensure_type(
36+
ensure_type(module.body[-1], SimpleStatementLine).body[0], Expr
37+
).value,
38+
Call,
39+
).args[0]
40+
repl = node.with_changes(value=SimpleString('"goodnight moon"'))
41+
42+
violation = LintViolation(
43+
"Fake",
44+
CodeRange(CodePosition(1, 1), CodePosition(2, 2)),
45+
message="some error",
46+
node=node,
47+
replacement=repl,
48+
)
49+
50+
expected = dedent(
51+
"""\
52+
--- a/foo.py
53+
+++ b/foo.py
54+
@@ -1,2 +1,2 @@
55+
import sys
56+
-print("hello world")
57+
+print("goodnight moon")
58+
"""
59+
)
60+
result = diff_violation(path, module, violation)
61+
self.assertEqual(expected, result)

0 commit comments

Comments
 (0)