13
13
14
14
15
15
class Expander :
16
+ local_include = re .compile (
17
+ r'#include\s*"([a-z_]*(|.hpp))"\s*' )
16
18
atcoder_include = re .compile (
17
- '#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*' )
19
+ r '#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*' )
18
20
19
- include_guard = re .compile ('#.*ATCODER_[A-Z_]*_HPP' )
21
+ include_guard = re .compile (r '#.*ATCODER_[A-Z_]*_HPP' )
20
22
21
23
def is_ignored_line (self , line ) -> bool :
22
24
if self .include_guard .match (line ):
@@ -30,27 +32,24 @@ def is_ignored_line(self, line) -> bool:
30
32
def __init__ (self , lib_paths : List [Path ]):
31
33
self .lib_paths = lib_paths
32
34
33
- included = set () # type: Set[str ]
35
+ included = set () # type: Set[Path ]
34
36
35
- def find_acl (self , acl_name : str ) -> Optional [ Path ] :
37
+ def find_acl (self , acl_name : str ) -> Path :
36
38
for lib_path in self .lib_paths :
37
39
path = lib_path / acl_name
38
40
if path .exists ():
39
41
return path
40
- return None
42
+ logger .error ('cannot find: {}' .format (acl_name ))
43
+ raise FileNotFoundError ()
41
44
42
- def expand_acl (self , acl_name : str ) -> List [str ]:
43
- if acl_name in self .included :
44
- logger .info ('already included: {}' .format (acl_name ))
45
+ def expand_acl (self , acl_file_path : Path ) -> List [str ]:
46
+ if acl_file_path in self .included :
47
+ logger .info ('already included: {}' .format (acl_file_path . name ))
45
48
return []
46
- self .included .add (acl_name )
47
- logger .info ('include: {}' .format (acl_name ))
48
- acl_path = self .find_acl (acl_name )
49
- if not acl_path :
50
- logger .warning ('cannot find: {}' .format (acl_name ))
51
- raise FileNotFoundError ()
49
+ self .included .add (acl_file_path )
50
+ logger .info ('include: {}' .format (acl_file_path .name ))
52
51
53
- acl_source = open (str (acl_path )).read ()
52
+ acl_source = open (str (acl_file_path )).read ()
54
53
55
54
result = [] # type: List[str]
56
55
for line in acl_source .splitlines ():
@@ -59,7 +58,14 @@ def expand_acl(self, acl_name: str) -> List[str]:
59
58
60
59
m = self .atcoder_include .match (line )
61
60
if m :
62
- result .extend (self .expand_acl (m .group (1 )))
61
+ name = m .group (1 )
62
+ result .extend (self .expand_acl (self .find_acl (name )))
63
+ continue
64
+
65
+ m = self .local_include .match (line )
66
+ if m :
67
+ name = m .group (1 )
68
+ result .extend (self .expand_acl (acl_file_path .parent / name ))
63
69
continue
64
70
65
71
result .append (line )
@@ -71,10 +77,11 @@ def expand(self, source: str) -> str:
71
77
result = [] # type: List[str]
72
78
for line in source .splitlines ():
73
79
m = self .atcoder_include .match (line )
74
-
75
80
if m :
76
- result .extend (self .expand_acl (m .group (1 )))
81
+ acl_path = self .find_acl (m .group (1 ))
82
+ result .extend (self .expand_acl (acl_path ))
77
83
continue
84
+
78
85
result .append (line )
79
86
return '\n ' .join (result )
80
87
0 commit comments