diff --git a/src/io/stdin.rs b/src/io/stdin.rs index bd4c1118d..9595973e9 100644 --- a/src/io/stdin.rs +++ b/src/io/stdin.rs @@ -1,9 +1,10 @@ +use futures::lock::Mutex; use std::io; use std::pin::Pin; -use std::sync::Mutex; +use std::sync::Arc; +use std::sync::Mutex as StdMutex; use cfg_if::cfg_if; -use futures::future; use futures::io::{AsyncRead, Initializer}; use crate::future::Future; @@ -29,11 +30,10 @@ use crate::task::{blocking, Context, Poll}; /// # Ok(()) }) } /// ``` pub fn stdin() -> Stdin { - Stdin(Mutex::new(State::Idle(Some(Inner { + Stdin(Mutex::new(Arc::new(StdMutex::new(Inner { stdin: io::stdin(), - line: String::new(), - buf: Vec::new(), - last_op: None, + line: Default::default(), + buf: Default::default(), })))) } @@ -46,21 +46,7 @@ pub fn stdin() -> Stdin { /// [`stdin`]: fn.stdin.html /// [`std::io::Stdin`]: https://doc.rust-lang.org/std/io/struct.Stdin.html #[derive(Debug)] -pub struct Stdin(Mutex); - -/// The state of the asynchronous stdin. -/// -/// The stdin can be either idle or busy performing an asynchronous operation. -#[derive(Debug)] -enum State { - /// The stdin is idle. - Idle(Option), - - /// The stdin is blocked on an asynchronous operation. - /// - /// Awaiting this operation will result in the new state of the stdin. - Busy(blocking::JoinHandle), -} +pub struct Stdin(Mutex>>); /// Inner representation of the asynchronous stdin. #[derive(Debug)] @@ -73,16 +59,6 @@ struct Inner { /// The write buffer. buf: Vec, - - /// The result of the last asynchronous operation on the stdin. - last_op: Option, -} - -/// Possible results of an asynchronous operation on the stdin. -#[derive(Debug)] -enum Operation { - ReadLine(io::Result), - Read(io::Result), } impl Stdin { @@ -102,89 +78,67 @@ impl Stdin { /// # Ok(()) }) } /// ``` pub async fn read_line(&self, buf: &mut String) -> io::Result { - future::poll_fn(|cx| { - let state = &mut *self.0.lock().unwrap(); - - loop { - match state { - State::Idle(opt) => { - let inner = opt.as_mut().unwrap(); - - // Check if the operation has completed. - if let Some(Operation::ReadLine(res)) = inner.last_op.take() { - let n = res?; - - // Copy the read data into the buffer and return. - buf.push_str(&inner.line); - return Poll::Ready(Ok(n)); - } else { - let mut inner = opt.take().unwrap(); - - // Start the operation asynchronously. - *state = State::Busy(blocking::spawn(async move { - inner.line.clear(); - let res = inner.stdin.read_line(&mut inner.line); - inner.last_op = Some(Operation::ReadLine(res)); - State::Idle(Some(inner)) - })); - } - } - // Poll the asynchronous operation the stdin is currently blocked on. - State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), - } - } - }) - .await + let future_lock = self.0.lock().await; + + let mutex = future_lock.clone(); + // Start the operation asynchronously. + let handle = blocking::spawn(async move { + let mut guard = mutex.lock().unwrap(); + let inner: &mut Inner = &mut guard; + + inner.line.clear(); + inner.stdin.read_line(&mut inner.line) + }); + + let res = handle.await; + let n = res?; + + let mutex = future_lock.clone(); + let inner = mutex.lock().unwrap(); + + // Copy the read data into the buffer and return. + buf.push_str(&inner.line); + + Ok(n) } } impl AsyncRead for Stdin { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - let state = &mut *self.0.lock().unwrap(); - - loop { - match state { - State::Idle(opt) => { - let inner = opt.as_mut().unwrap(); - - // Check if the operation has completed. - if let Some(Operation::Read(res)) = inner.last_op.take() { - let n = res?; - - // If more data was read than fits into the buffer, let's retry the read - // operation. - if n <= buf.len() { - // Copy the read data into the buffer and return. - buf[..n].copy_from_slice(&inner.buf[..n]); - return Poll::Ready(Ok(n)); - } - } else { - let mut inner = opt.take().unwrap(); - - // Set the length of the inner buffer to the length of the provided buffer. - if inner.buf.len() < buf.len() { - inner.buf.reserve(buf.len() - inner.buf.len()); - } - unsafe { - inner.buf.set_len(buf.len()); - } - - // Start the operation asynchronously. - *state = State::Busy(blocking::spawn(async move { - let res = io::Read::read(&mut inner.stdin, &mut inner.buf); - inner.last_op = Some(Operation::Read(res)); - State::Idle(Some(inner)) - })); - } - } - // Poll the asynchronous operation the stdin is currently blocked on. - State::Busy(task) => *state = futures::ready!(Pin::new(task).poll(cx)), + let len = buf.len(); + + let future_lock = self.0.lock(); + pin_utils::pin_mut!(future_lock); + let future_lock = futures::ready!(future_lock.poll(cx)); + + let mutex = future_lock.clone(); + let handle = blocking::spawn(async move { + let mut guard = mutex.lock().unwrap(); + let inner: &mut Inner = &mut guard; + + // Set the length of the inner buffer to the length of the provided buffer. + if inner.buf.len() < len { + inner.buf.reserve(len - inner.buf.len()); } - } + unsafe { + inner.buf.set_len(len); + } + + io::Read::read(&mut inner.stdin, &mut inner.buf) + }); + pin_utils::pin_mut!(handle); + handle.poll(cx).map_ok(|n| { + let mutex = future_lock.clone(); + let inner = mutex.lock().unwrap(); + + // Copy the read data into the buffer and return. + buf[..n].copy_from_slice(&inner.buf[..n]); + n + }) } #[inline]