Skip to content

Use futures::lock::Mutex in Stdin #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 59 additions & 105 deletions src/io/stdin.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use futures::lock::Mutex;
use std::io;
use std::pin::Pin;
use std::sync::Mutex;
use std::sync::Arc;
use std::sync::Mutex as StdMutex;

use cfg_if::cfg_if;
use futures::future;
use futures::io::{AsyncRead, Initializer};

use crate::future::Future;
Expand All @@ -29,11 +30,10 @@ use crate::task::{blocking, Context, Poll};
/// # Ok(()) }) }
/// ```
pub fn stdin() -> Stdin {
Stdin(Mutex::new(State::Idle(Some(Inner {
Stdin(Mutex::new(Arc::new(StdMutex::new(Inner {
stdin: io::stdin(),
line: String::new(),
buf: Vec::new(),
last_op: None,
line: Default::default(),
buf: Default::default(),
}))))
}

Expand All @@ -46,21 +46,7 @@ pub fn stdin() -> Stdin {
/// [`stdin`]: fn.stdin.html
/// [`std::io::Stdin`]: https://doc.rust-lang.org/std/io/struct.Stdin.html
#[derive(Debug)]
pub struct Stdin(Mutex<State>);

/// The state of the asynchronous stdin.
///
/// The stdin can be either idle or busy performing an asynchronous operation.
#[derive(Debug)]
enum State {
/// The stdin is idle.
Idle(Option<Inner>),

/// The stdin is blocked on an asynchronous operation.
///
/// Awaiting this operation will result in the new state of the stdin.
Busy(blocking::JoinHandle<State>),
}
pub struct Stdin(Mutex<Arc<StdMutex<Inner>>>);

/// Inner representation of the asynchronous stdin.
#[derive(Debug)]
Expand All @@ -73,16 +59,6 @@ struct Inner {

/// The write buffer.
buf: Vec<u8>,

/// The result of the last asynchronous operation on the stdin.
last_op: Option<Operation>,
}

/// Possible results of an asynchronous operation on the stdin.
#[derive(Debug)]
enum Operation {
ReadLine(io::Result<usize>),
Read(io::Result<usize>),
}

impl Stdin {
Expand All @@ -102,89 +78,67 @@ impl Stdin {
/// # Ok(()) }) }
/// ```
pub async fn read_line(&self, buf: &mut String) -> io::Result<usize> {
future::poll_fn(|cx| {
let state = &mut *self.0.lock().unwrap();

loop {
match state {
State::Idle(opt) => {
let inner = opt.as_mut().unwrap();

// Check if the operation has completed.
if let Some(Operation::ReadLine(res)) = inner.last_op.take() {
let n = res?;

// Copy the read data into the buffer and return.
buf.push_str(&inner.line);
return Poll::Ready(Ok(n));
} else {
let mut inner = opt.take().unwrap();

// Start the operation asynchronously.
*state = State::Busy(blocking::spawn(async move {
inner.line.clear();
let res = inner.stdin.read_line(&mut inner.line);
inner.last_op = Some(Operation::ReadLine(res));
State::Idle(Some(inner))
}));
}
}
// Poll the asynchronous operation the stdin is currently blocked on.
State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)),
}
}
})
.await
let future_lock = self.0.lock().await;

let mutex = future_lock.clone();
// Start the operation asynchronously.
let handle = blocking::spawn(async move {
let mut guard = mutex.lock().unwrap();
let inner: &mut Inner = &mut guard;

inner.line.clear();
inner.stdin.read_line(&mut inner.line)
});

let res = handle.await;
let n = res?;

let mutex = future_lock.clone();
let inner = mutex.lock().unwrap();

// Copy the read data into the buffer and return.
buf.push_str(&inner.line);

Ok(n)
}
}

impl AsyncRead for Stdin {
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let state = &mut *self.0.lock().unwrap();

loop {
match state {
State::Idle(opt) => {
let inner = opt.as_mut().unwrap();

// Check if the operation has completed.
if let Some(Operation::Read(res)) = inner.last_op.take() {
let n = res?;

// If more data was read than fits into the buffer, let's retry the read
// operation.
if n <= buf.len() {
// Copy the read data into the buffer and return.
buf[..n].copy_from_slice(&inner.buf[..n]);
return Poll::Ready(Ok(n));
}
} else {
let mut inner = opt.take().unwrap();

// Set the length of the inner buffer to the length of the provided buffer.
if inner.buf.len() < buf.len() {
inner.buf.reserve(buf.len() - inner.buf.len());
}
unsafe {
inner.buf.set_len(buf.len());
}

// Start the operation asynchronously.
*state = State::Busy(blocking::spawn(async move {
let res = io::Read::read(&mut inner.stdin, &mut inner.buf);
inner.last_op = Some(Operation::Read(res));
State::Idle(Some(inner))
}));
}
}
// Poll the asynchronous operation the stdin is currently blocked on.
State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)),
let len = buf.len();

let future_lock = self.0.lock();
pin_utils::pin_mut!(future_lock);
let future_lock = futures::ready!(future_lock.poll(cx));

let mutex = future_lock.clone();
let handle = blocking::spawn(async move {
let mut guard = mutex.lock().unwrap();
let inner: &mut Inner = &mut guard;

// Set the length of the inner buffer to the length of the provided buffer.
if inner.buf.len() < len {
inner.buf.reserve(len - inner.buf.len());
}
}
unsafe {
inner.buf.set_len(len);
}

io::Read::read(&mut inner.stdin, &mut inner.buf)
});
pin_utils::pin_mut!(handle);
handle.poll(cx).map_ok(|n| {
let mutex = future_lock.clone();
let inner = mutex.lock().unwrap();

// Copy the read data into the buffer and return.
buf[..n].copy_from_slice(&inner.buf[..n]);
n
})
}

#[inline]
Expand Down