6
6
from logging import Logger , basicConfig , getLogger
7
7
from os import getenv , environ
8
8
from pathlib import Path
9
- from typing import List
9
+ from typing import List , Set , Optional
10
10
11
11
12
12
logger = getLogger (__name__ ) # type: Logger
13
13
14
- atcoder_include = re .compile ('#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*' )
15
14
16
- include_guard = re .compile ('#.*ATCODER_[A-Z_]*_HPP' )
17
-
18
- lib_path = Path .cwd ()
19
-
20
- defined = set ()
21
-
22
- def dfs (f : str ) -> List [str ]:
23
- global defined
24
- if f in defined :
25
- logger .info ('already included {}, skip' .format (f ))
26
- return []
27
- defined .add (f )
28
-
29
- logger .info ('include {}' .format (f ))
30
-
31
- s = open (str (lib_path / f )).read ()
32
- result = []
33
- for line in s .splitlines ():
34
- if include_guard .match (line ):
35
- continue
36
-
37
- m = atcoder_include .match (line )
38
- if m :
39
- result .extend (dfs (m .group (1 )))
40
- continue
41
- result .append (line )
42
- return result
15
+ class Expander :
16
+ atcoder_include = re .compile (
17
+ '#include\s*["<](atcoder/[a-z_]*(|.hpp))[">]\s*' )
18
+
19
+ include_guard = re .compile ('#.*ATCODER_[A-Z_]*_HPP' )
20
+
21
+ def __init__ (self , lib_paths : List [Path ] = None ):
22
+ if lib_paths :
23
+ self .lib_paths = lib_paths
24
+ else :
25
+ self .lib_paths = [Path .cwd ()]
26
+
27
+ included = set () # type: Set[str]
28
+
29
+ def find_acl (self , acl_name : str ) -> Optional [Path ]:
30
+ for lib_path in self .lib_paths :
31
+ path = lib_path / acl_name
32
+ if path .exists ():
33
+ return path
34
+ return None
35
+
36
+ def expand_acl (self , acl_name : str ) -> List [str ]:
37
+ if acl_name in self .included :
38
+ logger .info ('already included: {}' .format (acl_name ))
39
+ return []
40
+ self .included .add (acl_name )
41
+ logger .info ('include: {}' .format (acl_name ))
42
+ acl_path = self .find_acl (acl_name )
43
+ if not acl_path :
44
+ logger .warning ('cannot find: {}' .format (acl_name ))
45
+ raise FileNotFoundError ()
46
+
47
+ acl_source = open (str (acl_path )).read ()
48
+
49
+ result = [] # type: List[str]
50
+ for line in acl_source .splitlines ():
51
+ if self .include_guard .match (line ):
52
+ continue
53
+
54
+ m = self .atcoder_include .match (line )
55
+ if m :
56
+ result .extend (self .expand_acl (m .group (1 )))
57
+ continue
58
+ result .append (line )
59
+ return result
60
+
61
+ def expand (self , source : str ) -> str :
62
+ self .included = set ()
63
+ result = [] # type: List[str]
64
+ for line in source .splitlines ():
65
+ m = self .atcoder_include .match (line )
66
+
67
+ if m :
68
+ result .extend (self .expand_acl (m .group (1 )))
69
+ continue
70
+ result .append (line )
71
+ return '\n ' .join (result )
43
72
44
73
45
74
if __name__ == "__main__" :
@@ -55,22 +84,16 @@ def dfs(f: str) -> List[str]:
55
84
parser .add_argument ('--lib' , help = 'Path to Atcoder Library' )
56
85
opts = parser .parse_args ()
57
86
87
+ lib_path = Path .cwd ()
58
88
if opts .lib :
59
89
lib_path = Path (opts .lib )
60
90
elif 'CPLUS_INCLUDE_PATH' in environ :
61
91
lib_path = Path (environ ['CPLUS_INCLUDE_PATH' ])
62
- s = open (opts .source ).read ()
63
-
64
- result = []
65
- for line in s .splitlines ():
66
- m = atcoder_include .match (line )
67
92
68
- if m :
69
- result .extend (dfs (m .group (1 )))
70
- continue
71
- result .append (line )
93
+ expander = Expander ([lib_path ])
94
+ source = open (opts .source ).read ()
95
+ output = expander .expand (source )
72
96
73
- output = '\n ' .join (result ) + '\n '
74
97
if opts .console :
75
98
print (output )
76
99
else :
0 commit comments