Skip to content

Commit b2ea40b

Browse files
authored
net: add handling for abstract socket name (#6772)
1 parent f602eae commit b2ea40b

File tree

3 files changed

+55
-4
lines changed

3 files changed

+55
-4
lines changed

tokio/src/net/unix/listener.rs

+21-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@ use crate::net::unix::{SocketAddr, UnixStream};
33

44
use std::fmt;
55
use std::io;
6+
#[cfg(target_os = "android")]
7+
use std::os::android::net::SocketAddrExt;
8+
#[cfg(target_os = "linux")]
9+
use std::os::linux::net::SocketAddrExt;
10+
#[cfg(any(target_os = "linux", target_os = "android"))]
11+
use std::os::unix::ffi::OsStrExt;
612
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd};
7-
use std::os::unix::net;
13+
use std::os::unix::net::{self, SocketAddr as StdSocketAddr};
814
use std::path::Path;
915
use std::task::{Context, Poll};
1016

@@ -70,7 +76,20 @@ impl UnixListener {
7076
where
7177
P: AsRef<Path>,
7278
{
73-
let listener = mio::net::UnixListener::bind(path)?;
79+
// For now, we handle abstract socket paths on linux here.
80+
#[cfg(any(target_os = "linux", target_os = "android"))]
81+
let addr = {
82+
let os_str_bytes = path.as_ref().as_os_str().as_bytes();
83+
if os_str_bytes.starts_with(b"\0") {
84+
StdSocketAddr::from_abstract_name(os_str_bytes)?
85+
} else {
86+
StdSocketAddr::from_pathname(path)?
87+
}
88+
};
89+
#[cfg(not(any(target_os = "linux", target_os = "android")))]
90+
let addr = StdSocketAddr::from_pathname(path)?;
91+
92+
let listener = mio::net::UnixListener::bind_addr(&addr)?;
7493
let io = PollEvented::new(listener)?;
7594
Ok(UnixListener { io })
7695
}

tokio/src/net/unix/stream.rs

+21-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,14 @@ use crate::net::unix::SocketAddr;
88
use std::fmt;
99
use std::io::{self, Read, Write};
1010
use std::net::Shutdown;
11+
#[cfg(target_os = "android")]
12+
use std::os::android::net::SocketAddrExt;
13+
#[cfg(target_os = "linux")]
14+
use std::os::linux::net::SocketAddrExt;
15+
#[cfg(any(target_os = "linux", target_os = "android"))]
16+
use std::os::unix::ffi::OsStrExt;
1117
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd};
12-
use std::os::unix::net;
18+
use std::os::unix::net::{self, SocketAddr as StdSocketAddr};
1319
use std::path::Path;
1420
use std::pin::Pin;
1521
use std::task::{Context, Poll};
@@ -66,7 +72,20 @@ impl UnixStream {
6672
where
6773
P: AsRef<Path>,
6874
{
69-
let stream = mio::net::UnixStream::connect(path)?;
75+
// On linux, abstract socket paths need to be considered.
76+
#[cfg(any(target_os = "linux", target_os = "android"))]
77+
let addr = {
78+
let os_str_bytes = path.as_ref().as_os_str().as_bytes();
79+
if os_str_bytes.starts_with(b"\0") {
80+
StdSocketAddr::from_abstract_name(os_str_bytes)?
81+
} else {
82+
StdSocketAddr::from_pathname(path)?
83+
}
84+
};
85+
#[cfg(not(any(target_os = "linux", target_os = "android")))]
86+
let addr = StdSocketAddr::from_pathname(path)?;
87+
88+
let stream = mio::net::UnixStream::connect_addr(&addr)?;
7089
let stream = UnixStream::new(stream)?;
7190

7291
poll_fn(|cx| stream.io.registration().poll_write_ready(cx)).await?;

tokio/tests/uds_stream.rs

+13
Original file line numberDiff line numberDiff line change
@@ -409,3 +409,16 @@ async fn epollhup() -> io::Result<()> {
409409
assert_eq!(err.kind(), io::ErrorKind::ConnectionReset);
410410
Ok(())
411411
}
412+
413+
// test for https://github.com/tokio-rs/tokio/issues/6767
414+
#[tokio::test]
415+
#[cfg(any(target_os = "linux", target_os = "android"))]
416+
async fn abstract_socket_name() {
417+
let socket_path = "\0aaa";
418+
let listener = UnixListener::bind(socket_path).unwrap();
419+
420+
let accept = listener.accept();
421+
let connect = UnixStream::connect(&socket_path);
422+
423+
try_join(accept, connect).await.unwrap();
424+
}

0 commit comments

Comments
 (0)