Skip to content

Commit 93bc282

Browse files
committed
#71: fix expander.py for local includes
1 parent f5812bd commit 93bc282

File tree

1 file changed

+25
-18
lines changed

1 file changed

+25
-18
lines changed

expander.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313

1414

1515
class Expander:
16+
local_include = re.compile(
17+
r'#include\s*"([a-z_]*(|.hpp))"\s*')
1618
atcoder_include = re.compile(
17-
'#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*')
19+
r'#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*')
1820

19-
include_guard = re.compile('#.*ATCODER_[A-Z_]*_HPP')
21+
include_guard = re.compile(r'#.*ATCODER_[A-Z_]*_HPP')
2022

2123
def is_ignored_line(self, line) -> bool:
2224
if self.include_guard.match(line):
@@ -30,27 +32,24 @@ def is_ignored_line(self, line) -> bool:
3032
def __init__(self, lib_paths: List[Path]):
3133
self.lib_paths = lib_paths
3234

33-
included = set() # type: Set[str]
35+
included = set() # type: Set[Path]
3436

35-
def find_acl(self, acl_name: str) -> Optional[Path]:
37+
def find_acl(self, acl_name: str) -> Path:
3638
for lib_path in self.lib_paths:
3739
path = lib_path / acl_name
3840
if path.exists():
3941
return path
40-
return None
42+
logger.error('cannot find: {}'.format(acl_name))
43+
raise FileNotFoundError()
4144

42-
def expand_acl(self, acl_name: str) -> List[str]:
43-
if acl_name in self.included:
44-
logger.info('already included: {}'.format(acl_name))
45+
def expand_acl(self, acl_file_path: Path) -> List[str]:
46+
if acl_file_path in self.included:
47+
logger.info('already included: {}'.format(acl_file_path.name))
4548
return []
46-
self.included.add(acl_name)
47-
logger.info('include: {}'.format(acl_name))
48-
acl_path = self.find_acl(acl_name)
49-
if not acl_path:
50-
logger.warning('cannot find: {}'.format(acl_name))
51-
raise FileNotFoundError()
49+
self.included.add(acl_file_path)
50+
logger.info('include: {}'.format(acl_file_path.name))
5251

53-
acl_source = open(str(acl_path)).read()
52+
acl_source = open(str(acl_file_path)).read()
5453

5554
result = [] # type: List[str]
5655
for line in acl_source.splitlines():
@@ -59,7 +58,14 @@ def expand_acl(self, acl_name: str) -> List[str]:
5958

6059
m = self.atcoder_include.match(line)
6160
if m:
62-
result.extend(self.expand_acl(m.group(1)))
61+
name = m.group(1)
62+
result.extend(self.expand_acl(self.find_acl(name)))
63+
continue
64+
65+
m = self.local_include.match(line)
66+
if m:
67+
name = m.group(1)
68+
result.extend(self.expand_acl(acl_file_path.parent / name))
6369
continue
6470

6571
result.append(line)
@@ -71,10 +77,11 @@ def expand(self, source: str) -> str:
7177
result = [] # type: List[str]
7278
for line in source.splitlines():
7379
m = self.atcoder_include.match(line)
74-
7580
if m:
76-
result.extend(self.expand_acl(m.group(1)))
81+
acl_path = self.find_acl(m.group(1))
82+
result.extend(self.expand_acl(acl_path))
7783
continue
84+
7885
result.append(line)
7986
return '\n'.join(result)
8087

0 commit comments

Comments
 (0)