Skip to content

Commit 8ebb888

Browse files
refactor: move find files operations to standalone files
1 parent 19e3baa commit 8ebb888

File tree

2 files changed

+253
-241
lines changed

2 files changed

+253
-241
lines changed

crates/napi/src/find_files.rs

+246
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
use ast_grep_config::RuleCore;
2+
use ast_grep_core::pinned::{NodeData, PinnedNodeData};
3+
use ast_grep_core::{AstGrep, NodeMatch};
4+
use ast_grep_language::SupportLang;
5+
use ignore::{WalkBuilder, WalkParallel, WalkState};
6+
use napi::anyhow::{anyhow, Context, Result as Ret};
7+
use napi::bindgen_prelude::*;
8+
use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode};
9+
use napi::{JsNumber, Task};
10+
use napi_derive::napi;
11+
use std::collections::HashMap;
12+
use std::sync::atomic::{AtomicU32, Ordering};
13+
use std::sync::mpsc::channel;
14+
15+
use crate::doc::{JsDoc, NapiConfig};
16+
use crate::fe_lang::{build_files, FrontEndLanguage, LangOption};
17+
use crate::sg_node::{SgNode, SgRoot};
18+
19+
pub struct ParseAsync {
20+
pub src: String,
21+
pub lang: FrontEndLanguage,
22+
}
23+
24+
impl Task for ParseAsync {
25+
type Output = SgRoot;
26+
type JsValue = SgRoot;
27+
28+
fn compute(&mut self) -> Result<Self::Output> {
29+
let src = std::mem::take(&mut self.src);
30+
let doc = JsDoc::new(src, self.lang.into());
31+
Ok(SgRoot(AstGrep::doc(doc), "anonymous".into()))
32+
}
33+
fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> {
34+
Ok(output)
35+
}
36+
}
37+
38+
type Entry = std::result::Result<ignore::DirEntry, ignore::Error>;
39+
40+
pub struct IterateFiles<D> {
41+
walk: WalkParallel,
42+
lang_option: LangOption,
43+
tsfn: D,
44+
producer: fn(&D, Entry, &LangOption) -> Ret<bool>,
45+
}
46+
47+
impl<T: 'static + Send + Sync> Task for IterateFiles<T> {
48+
type Output = u32;
49+
type JsValue = JsNumber;
50+
51+
fn compute(&mut self) -> Result<Self::Output> {
52+
let tsfn = &self.tsfn;
53+
let file_count = AtomicU32::new(0);
54+
let (tx, rx) = channel();
55+
let producer = self.producer;
56+
let walker = std::mem::replace(&mut self.walk, WalkBuilder::new(".").build_parallel());
57+
walker.run(|| {
58+
let tx = tx.clone();
59+
let file_count = &file_count;
60+
let lang_option = &self.lang_option;
61+
Box::new(move |entry| match producer(tsfn, entry, lang_option) {
62+
Ok(true) => {
63+
// file is sent to JS thread, increment file count
64+
if tx.send(()).is_ok() {
65+
file_count.fetch_add(1, Ordering::AcqRel);
66+
WalkState::Continue
67+
} else {
68+
WalkState::Quit
69+
}
70+
}
71+
Ok(false) => WalkState::Continue,
72+
Err(_) => WalkState::Skip,
73+
})
74+
});
75+
// Drop the last sender to stop `rx` waiting for message.
76+
// The program will not complete if we comment this out.
77+
drop(tx);
78+
while rx.recv().is_ok() {
79+
// pass
80+
}
81+
Ok(file_count.load(Ordering::Acquire))
82+
}
83+
fn resolve(&mut self, env: Env, output: Self::Output) -> Result<Self::JsValue> {
84+
env.create_uint32(output)
85+
}
86+
}
87+
88+
// See https://github.com/ast-grep/ast-grep/issues/206
89+
// NodeJS has a 1000 file limitation on sync iteration count.
90+
// https://github.com/nodejs/node/blob/8ba54e50496a6a5c21d93133df60a9f7cb6c46ce/src/node_api.cc#L336
91+
const THREAD_FUNC_QUEUE_SIZE: usize = 1000;
92+
93+
type ParseFiles = IterateFiles<ThreadsafeFunction<SgRoot, ErrorStrategy::CalleeHandled>>;
94+
95+
#[napi(object)]
96+
pub struct FileOption {
97+
pub paths: Vec<String>,
98+
pub language_globs: HashMap<String, Vec<String>>,
99+
}
100+
101+
#[napi(ts_return_type = "Promise<number>")]
102+
pub fn parse_files(
103+
paths: Either<Vec<String>, FileOption>,
104+
#[napi(ts_arg_type = "(err: null | Error, result: SgRoot) => void")] callback: JsFunction,
105+
) -> Result<AsyncTask<ParseFiles>> {
106+
let tsfn: ThreadsafeFunction<SgRoot, ErrorStrategy::CalleeHandled> =
107+
callback.create_threadsafe_function(THREAD_FUNC_QUEUE_SIZE, |ctx| Ok(vec![ctx.value]))?;
108+
let (paths, globs) = match paths {
109+
Either::A(v) => (v, HashMap::new()),
110+
Either::B(FileOption {
111+
paths,
112+
language_globs,
113+
}) => (paths, FrontEndLanguage::lang_globs(language_globs)),
114+
};
115+
let walk = build_files(paths, &globs)?;
116+
Ok(AsyncTask::new(ParseFiles {
117+
walk,
118+
tsfn,
119+
lang_option: LangOption::infer(&globs),
120+
producer: call_sg_root,
121+
}))
122+
}
123+
124+
// returns if the entry is a file and sent to JavaScript queue
125+
fn call_sg_root(
126+
tsfn: &ThreadsafeFunction<SgRoot, ErrorStrategy::CalleeHandled>,
127+
entry: std::result::Result<ignore::DirEntry, ignore::Error>,
128+
lang_option: &LangOption,
129+
) -> Ret<bool> {
130+
let entry = entry?;
131+
if !entry
132+
.file_type()
133+
.context("could not use stdin as file")?
134+
.is_file()
135+
{
136+
return Ok(false);
137+
}
138+
let (root, path) = get_root(entry, lang_option)?;
139+
let sg = SgRoot(root, path);
140+
tsfn.call(Ok(sg), ThreadsafeFunctionCallMode::Blocking);
141+
Ok(true)
142+
}
143+
144+
fn get_root(entry: ignore::DirEntry, lang_option: &LangOption) -> Ret<(AstGrep<JsDoc>, String)> {
145+
let path = entry.into_path();
146+
let file_content = std::fs::read_to_string(&path)?;
147+
let lang = lang_option
148+
.get_lang(&path)
149+
.context(anyhow!("file not recognized"))?;
150+
let doc = JsDoc::new(file_content, lang);
151+
Ok((AstGrep::doc(doc), path.to_string_lossy().into()))
152+
}
153+
154+
pub type FindInFiles = IterateFiles<(
155+
ThreadsafeFunction<PinnedNodes, ErrorStrategy::CalleeHandled>,
156+
RuleCore<SupportLang>,
157+
)>;
158+
159+
pub struct PinnedNodes(
160+
PinnedNodeData<JsDoc, Vec<NodeMatch<'static, JsDoc>>>,
161+
String,
162+
);
163+
unsafe impl Send for PinnedNodes {}
164+
unsafe impl Sync for PinnedNodes {}
165+
166+
#[napi(object)]
167+
pub struct FindConfig {
168+
/// specify the file paths to recursively find files
169+
pub paths: Vec<String>,
170+
/// a Rule object to find what nodes will match
171+
pub matcher: NapiConfig,
172+
/// An list of pattern globs to treat of certain files in the specified language.
173+
/// eg. ['*.vue', '*.svelte'] for html.findFiles, or ['*.ts'] for tsx.findFiles.
174+
/// It is slightly different from https://ast-grep.github.io/reference/sgconfig.html#languageglobs
175+
pub language_globs: Option<Vec<String>>,
176+
}
177+
178+
pub fn find_in_files_impl(
179+
lang: FrontEndLanguage,
180+
config: FindConfig,
181+
callback: JsFunction,
182+
) -> Result<AsyncTask<FindInFiles>> {
183+
let tsfn = callback.create_threadsafe_function(THREAD_FUNC_QUEUE_SIZE, |ctx| {
184+
from_pinned_data(ctx.value, ctx.env)
185+
})?;
186+
let FindConfig {
187+
paths,
188+
matcher,
189+
language_globs,
190+
} = config;
191+
let rule = matcher.parse_with(lang)?;
192+
let walk = lang.find_files(paths, language_globs)?;
193+
Ok(AsyncTask::new(FindInFiles {
194+
walk,
195+
tsfn: (tsfn, rule),
196+
lang_option: LangOption::Specified(lang),
197+
producer: call_sg_node,
198+
}))
199+
}
200+
201+
// TODO: optimize
202+
fn from_pinned_data(pinned: PinnedNodes, env: napi::Env) -> Result<Vec<Vec<SgNode>>> {
203+
let (root, nodes) = pinned.0.into_raw();
204+
let sg_root = SgRoot(AstGrep { inner: root }, pinned.1);
205+
let reference = SgRoot::into_reference(sg_root, env)?;
206+
let mut v = vec![];
207+
for mut node in nodes {
208+
let root_ref = reference.clone(env)?;
209+
let sg_node = SgNode {
210+
inner: root_ref.share_with(env, |root| {
211+
let r = &root.0.inner;
212+
node.visit_nodes(|n| unsafe { r.readopt(n) });
213+
Ok(node)
214+
})?,
215+
};
216+
v.push(sg_node);
217+
}
218+
Ok(vec![v])
219+
}
220+
221+
fn call_sg_node(
222+
(tsfn, rule): &(
223+
ThreadsafeFunction<PinnedNodes, ErrorStrategy::CalleeHandled>,
224+
RuleCore<SupportLang>,
225+
),
226+
entry: std::result::Result<ignore::DirEntry, ignore::Error>,
227+
lang_option: &LangOption,
228+
) -> Ret<bool> {
229+
let entry = entry?;
230+
if !entry
231+
.file_type()
232+
.context("could not use stdin as file")?
233+
.is_file()
234+
{
235+
return Ok(false);
236+
}
237+
let (root, path) = get_root(entry, lang_option)?;
238+
let mut pinned = PinnedNodeData::new(root.inner, |r| r.root().find_all(rule).collect());
239+
let hits: &Vec<_> = pinned.get_data();
240+
if hits.is_empty() {
241+
return Ok(false);
242+
}
243+
let pinned = PinnedNodes(pinned, path);
244+
tsfn.call(Ok(pinned), ThreadsafeFunctionCallMode::Blocking);
245+
Ok(true)
246+
}

0 commit comments

Comments
 (0)