Skip to content

Commit 9ee1f41

Browse files
authored
Merge pull request #48 from atcoder/patch/refactor
refactor expander
2 parents f7e193c + da1fec9 commit 9ee1f41

File tree

3 files changed

+73
-46
lines changed

3 files changed

+73
-46
lines changed

expander.py

+62-39
Original file line numberDiff line numberDiff line change
@@ -6,40 +6,69 @@
66
from logging import Logger, basicConfig, getLogger
77
from os import getenv, environ
88
from pathlib import Path
9-
from typing import List
9+
from typing import List, Set, Optional
1010

1111

1212
logger = getLogger(__name__) # type: Logger
1313

14-
atcoder_include = re.compile('#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*')
1514

16-
include_guard = re.compile('#.*ATCODER_[A-Z_]*_HPP')
17-
18-
lib_path = Path.cwd()
19-
20-
defined = set()
21-
22-
def dfs(f: str) -> List[str]:
23-
global defined
24-
if f in defined:
25-
logger.info('already included {}, skip'.format(f))
26-
return []
27-
defined.add(f)
28-
29-
logger.info('include {}'.format(f))
30-
31-
s = open(str(lib_path / f)).read()
32-
result = []
33-
for line in s.splitlines():
34-
if include_guard.match(line):
35-
continue
36-
37-
m = atcoder_include.match(line)
38-
if m:
39-
result.extend(dfs(m.group(1)))
40-
continue
41-
result.append(line)
42-
return result
15+
class Expander:
16+
atcoder_include = re.compile(
17+
'#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*')
18+
19+
include_guard = re.compile('#.*ATCODER_[A-Z_]*_HPP')
20+
21+
def __init__(self, lib_paths: List[Path] = None):
22+
if lib_paths:
23+
self.lib_paths = lib_paths
24+
else:
25+
self.lib_paths = [Path.cwd()]
26+
27+
included = set() # type: Set[str]
28+
29+
def find_acl(self, acl_name: str) -> Optional[Path]:
30+
for lib_path in self.lib_paths:
31+
path = lib_path / acl_name
32+
if path.exists():
33+
return path
34+
return None
35+
36+
def expand_acl(self, acl_name: str) -> List[str]:
37+
if acl_name in self.included:
38+
logger.info('already included: {}'.format(acl_name))
39+
return []
40+
self.included.add(acl_name)
41+
logger.info('include: {}'.format(acl_name))
42+
acl_path = self.find_acl(acl_name)
43+
if not acl_path:
44+
logger.warning('cannot find: {}'.format(acl_name))
45+
raise FileNotFoundError()
46+
47+
acl_source = open(str(acl_path)).read()
48+
49+
result = [] # type: List[str]
50+
for line in acl_source.splitlines():
51+
if self.include_guard.match(line):
52+
continue
53+
54+
m = self.atcoder_include.match(line)
55+
if m:
56+
result.extend(self.expand_acl(m.group(1)))
57+
continue
58+
result.append(line)
59+
return result
60+
61+
def expand(self, source: str) -> str:
62+
self.included = set()
63+
result = [] # type: List[str]
64+
for line in source.splitlines():
65+
m = self.atcoder_include.match(line)
66+
67+
if m:
68+
result.extend(self.expand_acl(m.group(1)))
69+
continue
70+
result.append(line)
71+
return '\n'.join(result)
4372

4473

4574
if __name__ == "__main__":
@@ -55,22 +84,16 @@ def dfs(f: str) -> List[str]:
5584
parser.add_argument('--lib', help='Path to Atcoder Library')
5685
opts = parser.parse_args()
5786

87+
lib_path = Path.cwd()
5888
if opts.lib:
5989
lib_path = Path(opts.lib)
6090
elif 'CPLUS_INCLUDE_PATH' in environ:
6191
lib_path = Path(environ['CPLUS_INCLUDE_PATH'])
62-
s = open(opts.source).read()
63-
64-
result = []
65-
for line in s.splitlines():
66-
m = atcoder_include.match(line)
6792

68-
if m:
69-
result.extend(dfs(m.group(1)))
70-
continue
71-
result.append(line)
93+
expander = Expander([lib_path])
94+
source = open(opts.source).read()
95+
output = expander.expand(source)
7296

73-
output = '\n'.join(result) + '\n'
7497
if opts.console:
7598
print(output)
7699
else:
File renamed without changes.

test/test_expander.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,34 @@
1313

1414

1515
class Test(unittest.TestCase):
16-
def compile_test(self, source: Path, env=None):
16+
def compile_test(self, source: Path, expander_args=[], env=None):
1717
if not env:
1818
env = environ.copy()
1919
lib_dir = Path.cwd().resolve()
2020
expander_path = Path('expander.py').resolve()
2121
with TemporaryDirectory() as new_dir:
2222
tmp = Path(new_dir)
2323
proc = run(['python', str(expander_path), str(
24-
source.resolve()), '--lib', str(lib_dir)], cwd=str(tmp), env=env)
24+
source.resolve())] + expander_args, cwd=str(tmp), env=env)
2525
self.assertEqual(proc.returncode, 0)
2626
proc = run(['g++', 'combined.cpp', '-std=c++14'], cwd=str(tmp))
2727
self.assertEqual(proc.returncode, 0)
2828

29-
def test_unionfind(self):
30-
self.compile_test(Path('test/expander/include_unionfind.cpp'))
29+
def test_dsu(self):
30+
self.compile_test(Path('test/expander/include_dsu.cpp'),
31+
expander_args=['--lib', str(Path.cwd().resolve())])
3132

3233
def test_unusual_format(self):
33-
self.compile_test(Path('test/expander/include_unusual_format.cpp'))
34+
self.compile_test(Path('test/expander/include_unusual_format.cpp'),
35+
expander_args=['--lib', str(Path.cwd().resolve())])
3436

3537
def test_all(self):
36-
self.compile_test(Path('test/expander/include_all.cpp'))
38+
self.compile_test(Path('test/expander/include_all.cpp'),
39+
expander_args=['--lib', str(Path.cwd().resolve())])
3740

3841
def test_comment_out(self):
39-
self.compile_test(Path('test/expander/comment_out.cpp'))
42+
self.compile_test(Path('test/expander/comment_out.cpp'),
43+
expander_args=['--lib', str(Path.cwd().resolve())])
4044

4145
def test_env_value(self):
4246
env = environ.copy()

0 commit comments

Comments
 (0)