Skip to content

Commit 54e3fe7

Browse files
committed
Auto merge of #43062 - sfackler:connect-timeout, r=alexcrichton
Implement TcpStream::connect_timeout This breaks the "single syscall rule", but it's really annoying to hand write and is pretty foundational. r? @alexcrichton cc @rust-lang/libs
2 parents b80b659 + 8c92da3 commit 54e3fe7

File tree

6 files changed

+193
-2
lines changed

6 files changed

+193
-2
lines changed

src/libstd/net/tcp.rs

+33
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,24 @@ impl TcpStream {
134134
super::each_addr(addr, net_imp::TcpStream::connect).map(TcpStream)
135135
}
136136

137+
/// Opens a TCP connection to a remote host with a timeout.
138+
///
139+
/// Unlike `connect`, `connect_timeout` takes a single [`SocketAddr`] since
140+
/// timeout must be applied to individual addresses.
141+
///
142+
/// It is an error to pass a zero `Duration` to this function.
143+
///
144+
/// Unlike other methods on `TcpStream`, this does not correspond to a
145+
/// single system call. It instead calls `connect` in nonblocking mode and
146+
/// then uses an OS-specific mechanism to await the completion of the
147+
/// connection request.
148+
///
149+
/// [`SocketAddr`]: ../../std/net/enum.SocketAddr.html
150+
#[unstable(feature = "tcpstream_connect_timeout", issue = "43709")]
151+
pub fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
152+
net_imp::TcpStream::connect_timeout(addr, timeout).map(TcpStream)
153+
}
154+
137155
/// Returns the socket address of the remote peer of this TCP connection.
138156
///
139157
/// # Examples
@@ -1509,4 +1527,19 @@ mod tests {
15091527
t!(txdone.send(()));
15101528
})
15111529
}
1530+
1531+
#[test]
1532+
fn connect_timeout_unroutable() {
1533+
// this IP is unroutable, so connections should always time out.
1534+
let addr = "10.255.255.1:80".parse().unwrap();
1535+
let e = TcpStream::connect_timeout(&addr, Duration::from_millis(250)).unwrap_err();
1536+
assert_eq!(e.kind(), io::ErrorKind::TimedOut);
1537+
}
1538+
1539+
#[test]
1540+
fn connect_timeout_valid() {
1541+
let listener = TcpListener::bind("127.0.0.1:0").unwrap();
1542+
let addr = listener.local_addr().unwrap();
1543+
TcpStream::connect_timeout(&addr, Duration::from_secs(2)).unwrap();
1544+
}
15121545
}

src/libstd/sys/redox/net/tcp.rs

+4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ impl TcpStream {
3232
Ok(TcpStream(File::open(&Path::new(path.as_str()), &options)?))
3333
}
3434

35+
pub fn connect_timeout(_addr: &SocketAddr, _timeout: Duration) -> Result<()> {
36+
Err(Error::new(ErrorKind::Other, "TcpStream::connect_timeout not implemented"))
37+
}
38+
3539
pub fn duplicate(&self) -> Result<TcpStream> {
3640
Ok(TcpStream(self.0.dup(&[])?))
3741
}

src/libstd/sys/unix/net.rs

+66-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ use str;
1717
use sys::fd::FileDesc;
1818
use sys_common::{AsInner, FromInner, IntoInner};
1919
use sys_common::net::{getsockopt, setsockopt, sockaddr_to_addr};
20-
use time::Duration;
20+
use time::{Duration, Instant};
21+
use cmp;
2122

2223
pub use sys::{cvt, cvt_r};
2324
pub extern crate libc as netc;
@@ -122,6 +123,70 @@ impl Socket {
122123
}
123124
}
124125

126+
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
127+
self.set_nonblocking(true)?;
128+
let r = unsafe {
129+
let (addrp, len) = addr.into_inner();
130+
cvt(libc::connect(self.0.raw(), addrp, len))
131+
};
132+
self.set_nonblocking(false)?;
133+
134+
match r {
135+
Ok(_) => return Ok(()),
136+
// there's no ErrorKind for EINPROGRESS :(
137+
Err(ref e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {}
138+
Err(e) => return Err(e),
139+
}
140+
141+
let mut pollfd = libc::pollfd {
142+
fd: self.0.raw(),
143+
events: libc::POLLOUT,
144+
revents: 0,
145+
};
146+
147+
if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
148+
return Err(io::Error::new(io::ErrorKind::InvalidInput,
149+
"cannot set a 0 duration timeout"));
150+
}
151+
152+
let start = Instant::now();
153+
154+
loop {
155+
let elapsed = start.elapsed();
156+
if elapsed >= timeout {
157+
return Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out"));
158+
}
159+
160+
let timeout = timeout - elapsed;
161+
let mut timeout = timeout.as_secs()
162+
.saturating_mul(1_000)
163+
.saturating_add(timeout.subsec_nanos() as u64 / 1_000_000);
164+
if timeout == 0 {
165+
timeout = 1;
166+
}
167+
168+
let timeout = cmp::min(timeout, c_int::max_value() as u64) as c_int;
169+
170+
match unsafe { libc::poll(&mut pollfd, 1, timeout) } {
171+
-1 => {
172+
let err = io::Error::last_os_error();
173+
if err.kind() != io::ErrorKind::Interrupted {
174+
return Err(err);
175+
}
176+
}
177+
0 => {}
178+
_ => {
179+
if pollfd.revents & libc::POLLOUT == 0 {
180+
if let Some(e) = self.take_error()? {
181+
return Err(e);
182+
}
183+
}
184+
return Ok(());
185+
}
186+
}
187+
}
188+
}
189+
125190
pub fn accept(&self, storage: *mut sockaddr, len: *mut socklen_t)
126191
-> io::Result<Socket> {
127192
// Unfortunately the only known way right now to accept a socket and

src/libstd/sys/windows/c.rs

+27
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,8 @@ pub const PIPE_TYPE_BYTE: DWORD = 0x00000000;
298298
pub const PIPE_REJECT_REMOTE_CLIENTS: DWORD = 0x00000008;
299299
pub const PIPE_READMODE_BYTE: DWORD = 0x00000000;
300300

301+
pub const FD_SETSIZE: usize = 64;
302+
301303
#[repr(C)]
302304
#[cfg(target_arch = "x86")]
303305
pub struct WSADATA {
@@ -837,6 +839,26 @@ pub struct CONSOLE_READCONSOLE_CONTROL {
837839
}
838840
pub type PCONSOLE_READCONSOLE_CONTROL = *mut CONSOLE_READCONSOLE_CONTROL;
839841

842+
#[repr(C)]
843+
#[derive(Copy)]
844+
pub struct fd_set {
845+
pub fd_count: c_uint,
846+
pub fd_array: [SOCKET; FD_SETSIZE],
847+
}
848+
849+
impl Clone for fd_set {
850+
fn clone(&self) -> fd_set {
851+
*self
852+
}
853+
}
854+
855+
#[repr(C)]
856+
#[derive(Copy, Clone)]
857+
pub struct timeval {
858+
pub tv_sec: c_long,
859+
pub tv_usec: c_long,
860+
}
861+
840862
extern "system" {
841863
pub fn WSAStartup(wVersionRequested: WORD,
842864
lpWSAData: LPWSADATA) -> c_int;
@@ -1125,6 +1147,11 @@ extern "system" {
11251147
lpOverlapped: LPOVERLAPPED,
11261148
lpNumberOfBytesTransferred: LPDWORD,
11271149
bWait: BOOL) -> BOOL;
1150+
pub fn select(nfds: c_int,
1151+
readfds: *mut fd_set,
1152+
writefds: *mut fd_set,
1153+
exceptfds: *mut fd_set,
1154+
timeout: *const timeval) -> c_int;
11281155
}
11291156

11301157
// Functions that aren't available on Windows XP, but we still use them and just

src/libstd/sys/windows/net.rs

+55-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
use cmp;
1414
use io::{self, Read};
15-
use libc::{c_int, c_void, c_ulong};
15+
use libc::{c_int, c_void, c_ulong, c_long};
1616
use mem;
1717
use net::{SocketAddr, Shutdown};
1818
use ptr;
@@ -115,6 +115,60 @@ impl Socket {
115115
Ok(socket)
116116
}
117117

118+
pub fn connect_timeout(&self, addr: &SocketAddr, timeout: Duration) -> io::Result<()> {
119+
self.set_nonblocking(true)?;
120+
let r = unsafe {
121+
let (addrp, len) = addr.into_inner();
122+
cvt(c::connect(self.0, addrp, len))
123+
};
124+
self.set_nonblocking(false)?;
125+
126+
match r {
127+
Ok(_) => return Ok(()),
128+
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
129+
Err(e) => return Err(e),
130+
}
131+
132+
if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
133+
return Err(io::Error::new(io::ErrorKind::InvalidInput,
134+
"cannot set a 0 duration timeout"));
135+
}
136+
137+
let mut timeout = c::timeval {
138+
tv_sec: timeout.as_secs() as c_long,
139+
tv_usec: (timeout.subsec_nanos() / 1000) as c_long,
140+
};
141+
if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
142+
timeout.tv_usec = 1;
143+
}
144+
145+
let fds = unsafe {
146+
let mut fds = mem::zeroed::<c::fd_set>();
147+
fds.fd_count = 1;
148+
fds.fd_array[0] = self.0;
149+
fds
150+
};
151+
152+
let mut writefds = fds;
153+
let mut errorfds = fds;
154+
155+
let n = unsafe {
156+
cvt(c::select(1, ptr::null_mut(), &mut writefds, &mut errorfds, &timeout))?
157+
};
158+
159+
match n {
160+
0 => Err(io::Error::new(io::ErrorKind::TimedOut, "connection timed out")),
161+
_ => {
162+
if writefds.fd_count != 1 {
163+
if let Some(e) = self.take_error()? {
164+
return Err(e);
165+
}
166+
}
167+
Ok(())
168+
}
169+
}
170+
}
171+
118172
pub fn accept(&self, storage: *mut c::SOCKADDR,
119173
len: *mut c_int) -> io::Result<Socket> {
120174
let socket = unsafe {

src/libstd/sys_common/net.rs

+8
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,14 @@ impl TcpStream {
215215
Ok(TcpStream { inner: sock })
216216
}
217217

218+
pub fn connect_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
219+
init();
220+
221+
let sock = Socket::new(addr, c::SOCK_STREAM)?;
222+
sock.connect_timeout(addr, timeout)?;
223+
Ok(TcpStream { inner: sock })
224+
}
225+
218226
pub fn socket(&self) -> &Socket { &self.inner }
219227

220228
pub fn into_socket(self) -> Socket { self.inner }

0 commit comments

Comments
 (0)