Skip to content

Commit 2a7c1b6

Browse files
authored
feat: allow pluggable tower layers in connector service stack (#2496)
Co-authored-by: Jess Izen <[email protected]>
1 parent 8a2174f commit 2a7c1b6

File tree

10 files changed

+1136
-131
lines changed

10 files changed

+1136
-131
lines changed

Cargo.toml

+11-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ authors = ["Sean McArthur <[email protected]>"]
1010
readme = "README.md"
1111
license = "MIT OR Apache-2.0"
1212
edition = "2021"
13-
rust-version = "1.63.0"
13+
rust-version = "1.64.0"
1414
autotests = true
1515

1616
[package.metadata.docs.rs]
@@ -105,6 +105,7 @@ url = "2.4"
105105
bytes = "1.0"
106106
serde = "1.0"
107107
serde_urlencoded = "0.7.1"
108+
tower = { version = "0.5.2", default-features = false, features = ["timeout", "util"] }
108109
tower-service = "0.3"
109110
futures-core = { version = "0.3.28", default-features = false }
110111
futures-util = { version = "0.3.28", default-features = false }
@@ -169,7 +170,6 @@ quinn = { version = "0.11.1", default-features = false, features = ["rustls", "r
169170
slab = { version = "0.4.9", optional = true } # just to get minimal versions working with quinn
170171
futures-channel = { version = "0.3", optional = true }
171172

172-
173173
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
174174
env_logger = "0.10"
175175
hyper = { version = "1.1.0", default-features = false, features = ["http1", "http2", "client", "server"] }
@@ -222,6 +222,11 @@ features = [
222222
wasm-bindgen = { version = "0.2.89", features = ["serde-serialize"] }
223223
wasm-bindgen-test = "0.3"
224224

225+
[dev-dependencies]
226+
tower = { version = "0.5.2", default-features = false, features = ["limit"] }
227+
num_cpus = "1.0"
228+
libc = "0"
229+
225230
[lints.rust]
226231
unexpected_cfgs = { level = "warn", check-cfg = ['cfg(reqwest_unstable)'] }
227232

@@ -253,6 +258,10 @@ path = "examples/form.rs"
253258
name = "simple"
254259
path = "examples/simple.rs"
255260

261+
[[example]]
262+
name = "connect_via_lower_priority_tokio_runtime"
263+
path = "examples/connect_via_lower_priority_tokio_runtime.rs"
264+
256265
[[test]]
257266
name = "blocking"
258267
path = "tests/blocking.rs"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
#![deny(warnings)]
2+
// This example demonstrates how to delegate the connect calls, which contain TLS handshakes,
3+
// to a secondary tokio runtime of lower OS thread priority using a custom tower layer.
4+
// This helps to ensure that long-running futures during handshake crypto operations don't block other I/O futures.
5+
//
6+
// This does introduce overhead of additional threads, channels, extra vtables, etc,
7+
// so it is best suited to services with large numbers of incoming connections or that
8+
// are otherwise very sensitive to any blocking futures. Or, you might want fewer threads
9+
// and/or to use the current_thread runtime.
10+
//
11+
// This is using the `tokio` runtime and certain other dependencies:
12+
//
13+
// `tokio = { version = "1", features = ["full"] }`
14+
// `num_cpus = "1.0"`
15+
// `libc = "0"`
16+
// `pin-project-lite = "0.2"`
17+
// `tower = { version = "0.5", default-features = false}`
18+
19+
#[cfg(not(target_arch = "wasm32"))]
20+
#[tokio::main]
21+
async fn main() -> Result<(), reqwest::Error> {
22+
background_threadpool::init_background_runtime();
23+
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
24+
25+
let client = reqwest::Client::builder()
26+
.connector_layer(background_threadpool::BackgroundProcessorLayer::new())
27+
.build()
28+
.expect("should be able to build reqwest client");
29+
30+
let url = if let Some(url) = std::env::args().nth(1) {
31+
url
32+
} else {
33+
println!("No CLI URL provided, using default.");
34+
"https://hyper.rs".into()
35+
};
36+
37+
eprintln!("Fetching {url:?}...");
38+
39+
let res = client.get(url).send().await?;
40+
41+
eprintln!("Response: {:?} {}", res.version(), res.status());
42+
eprintln!("Headers: {:#?}\n", res.headers());
43+
44+
let body = res.text().await?;
45+
46+
println!("{body}");
47+
48+
Ok(())
49+
}
50+
51+
// separating out for convenience to avoid a million #[cfg(not(target_arch = "wasm32"))]
52+
#[cfg(not(target_arch = "wasm32"))]
53+
mod background_threadpool {
54+
use std::{
55+
future::Future,
56+
pin::Pin,
57+
sync::OnceLock,
58+
task::{Context, Poll},
59+
};
60+
61+
use futures_util::TryFutureExt;
62+
use pin_project_lite::pin_project;
63+
use tokio::{runtime::Handle, select, sync::mpsc::error::TrySendError};
64+
use tower::{BoxError, Layer, Service};
65+
66+
static CPU_HEAVY_THREAD_POOL: OnceLock<
67+
tokio::sync::mpsc::Sender<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
68+
> = OnceLock::new();
69+
70+
pub(crate) fn init_background_runtime() {
71+
std::thread::Builder::new()
72+
.name("cpu-heavy-background-threadpool".to_string())
73+
.spawn(move || {
74+
let rt = tokio::runtime::Builder::new_multi_thread()
75+
.thread_name("cpu-heavy-background-pool-thread")
76+
.worker_threads(num_cpus::get() as usize)
77+
// ref: https://github.com/tokio-rs/tokio/issues/4941
78+
// consider uncommenting if seeing heavy task contention
79+
// .disable_lifo_slot()
80+
.on_thread_start(move || {
81+
#[cfg(target_os = "linux")]
82+
unsafe {
83+
// Increase thread pool thread niceness, so they are lower priority
84+
// than the foreground executor and don't interfere with I/O tasks
85+
{
86+
*libc::__errno_location() = 0;
87+
if libc::nice(10) == -1 && *libc::__errno_location() != 0 {
88+
let error = std::io::Error::last_os_error();
89+
log::error!("failed to set threadpool niceness: {}", error);
90+
}
91+
}
92+
}
93+
})
94+
.enable_all()
95+
.build()
96+
.unwrap_or_else(|e| panic!("cpu heavy runtime failed_to_initialize: {}", e));
97+
rt.block_on(async {
98+
log::debug!("starting background cpu-heavy work");
99+
process_cpu_work().await;
100+
});
101+
})
102+
.unwrap_or_else(|e| panic!("cpu heavy thread failed_to_initialize: {}", e));
103+
}
104+
105+
#[cfg(not(target_arch = "wasm32"))]
106+
async fn process_cpu_work() {
107+
// we only use this channel for routing work, it should move pretty quick, it can be small
108+
let (tx, mut rx) = tokio::sync::mpsc::channel(10);
109+
// share the handle to the background channel globally
110+
CPU_HEAVY_THREAD_POOL.set(tx).unwrap();
111+
112+
while let Some(work) = rx.recv().await {
113+
tokio::task::spawn(work);
114+
}
115+
}
116+
117+
// retrieve the sender to the background channel, and send the future over to it for execution
118+
fn send_to_background_runtime(future: impl Future<Output = ()> + Send + 'static) {
119+
let tx = CPU_HEAVY_THREAD_POOL.get().expect(
120+
"start up the secondary tokio runtime before sending to `CPU_HEAVY_THREAD_POOL`",
121+
);
122+
123+
match tx.try_send(Box::pin(future)) {
124+
Ok(_) => (),
125+
Err(TrySendError::Closed(_)) => {
126+
panic!("background cpu heavy runtime channel is closed")
127+
}
128+
Err(TrySendError::Full(msg)) => {
129+
log::warn!(
130+
"background cpu heavy runtime channel is full, task spawning loop delayed"
131+
);
132+
let tx = tx.clone();
133+
Handle::current().spawn(async move {
134+
tx.send(msg)
135+
.await
136+
.expect("background cpu heavy runtime channel is closed")
137+
});
138+
}
139+
}
140+
}
141+
142+
// This tower layer injects futures with a oneshot channel, and then sends them to the background runtime for processing.
143+
// We don't use the Buffer service because that is intended to process sequentially on a single task, whereas we want to
144+
// spawn a new task per call.
145+
#[derive(Copy, Clone)]
146+
pub struct BackgroundProcessorLayer {}
147+
impl BackgroundProcessorLayer {
148+
pub fn new() -> Self {
149+
Self {}
150+
}
151+
}
152+
impl<S> Layer<S> for BackgroundProcessorLayer {
153+
type Service = BackgroundProcessor<S>;
154+
fn layer(&self, service: S) -> Self::Service {
155+
BackgroundProcessor::new(service)
156+
}
157+
}
158+
159+
impl std::fmt::Debug for BackgroundProcessorLayer {
160+
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
161+
f.debug_struct("BackgroundProcessorLayer").finish()
162+
}
163+
}
164+
165+
// This tower service injects futures with a oneshot channel, and then sends them to the background runtime for processing.
166+
#[derive(Debug, Clone)]
167+
pub struct BackgroundProcessor<S> {
168+
inner: S,
169+
}
170+
171+
impl<S> BackgroundProcessor<S> {
172+
pub fn new(inner: S) -> Self {
173+
BackgroundProcessor { inner }
174+
}
175+
}
176+
177+
impl<S, Request> Service<Request> for BackgroundProcessor<S>
178+
where
179+
S: Service<Request>,
180+
S::Response: Send + 'static,
181+
S::Error: Into<BoxError> + Send,
182+
S::Future: Send + 'static,
183+
{
184+
type Response = S::Response;
185+
186+
type Error = BoxError;
187+
188+
type Future = BackgroundResponseFuture<S::Response>;
189+
190+
fn poll_ready(
191+
&mut self,
192+
cx: &mut std::task::Context<'_>,
193+
) -> std::task::Poll<Result<(), Self::Error>> {
194+
match self.inner.poll_ready(cx) {
195+
Poll::Pending => Poll::Pending,
196+
Poll::Ready(r) => Poll::Ready(r.map_err(Into::into)),
197+
}
198+
}
199+
200+
fn call(&mut self, req: Request) -> Self::Future {
201+
let response = self.inner.call(req);
202+
203+
// wrap our inner service's future with a future that writes to this oneshot channel
204+
let (mut tx, rx) = tokio::sync::oneshot::channel();
205+
let future = async move {
206+
select!(
207+
_ = tx.closed() => {
208+
// receiver already dropped, don't need to do anything
209+
}
210+
result = response.map_err(|err| Into::<BoxError>::into(err)) => {
211+
// if this fails, the receiver already dropped, so we don't need to do anything
212+
let _ = tx.send(result);
213+
}
214+
)
215+
};
216+
// send the wrapped future to the background
217+
send_to_background_runtime(future);
218+
219+
BackgroundResponseFuture::new(rx)
220+
}
221+
}
222+
223+
// `BackgroundProcessor` response future
224+
pin_project! {
225+
#[derive(Debug)]
226+
pub struct BackgroundResponseFuture<S> {
227+
#[pin]
228+
rx: tokio::sync::oneshot::Receiver<Result<S, BoxError>>,
229+
}
230+
}
231+
232+
impl<S> BackgroundResponseFuture<S> {
233+
pub(crate) fn new(rx: tokio::sync::oneshot::Receiver<Result<S, BoxError>>) -> Self {
234+
BackgroundResponseFuture { rx }
235+
}
236+
}
237+
238+
impl<S> Future for BackgroundResponseFuture<S>
239+
where
240+
S: Send + 'static,
241+
{
242+
type Output = Result<S, BoxError>;
243+
244+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
245+
let this = self.project();
246+
247+
// now poll on the receiver end of the oneshot to get the result
248+
match this.rx.poll(cx) {
249+
Poll::Ready(v) => match v {
250+
Ok(v) => Poll::Ready(v.map_err(Into::into)),
251+
Err(err) => Poll::Ready(Err(Box::new(err) as BoxError)),
252+
},
253+
Poll::Pending => Poll::Pending,
254+
}
255+
}
256+
}
257+
}
258+
259+
// The [cfg(not(target_arch = "wasm32"))] above prevent building the tokio::main function
260+
// for wasm32 target, because tokio isn't compatible with wasm32.
261+
// If you aren't building for wasm32, you don't need that line.
262+
// The two lines below avoid the "'main' function not found" error when building for wasm32 target.
263+
#[cfg(any(target_arch = "wasm32"))]
264+
fn main() {}

0 commit comments

Comments
 (0)