diff --git a/library/std/src/os/windows/mod.rs b/library/std/src/os/windows/mod.rs index f452403ee8426..5740f65bacef1 100644 --- a/library/std/src/os/windows/mod.rs +++ b/library/std/src/os/windows/mod.rs @@ -29,6 +29,7 @@ pub mod ffi; pub mod fs; pub mod io; +pub mod net; pub mod process; pub mod raw; pub mod thread; diff --git a/library/std/src/os/windows/net/addr.rs b/library/std/src/os/windows/net/addr.rs new file mode 100644 index 0000000000000..9e84667785b8c --- /dev/null +++ b/library/std/src/os/windows/net/addr.rs @@ -0,0 +1,168 @@ +use crate::bstr::ByteStr; +use crate::ffi::OsStr; +use crate::path::Path; +#[cfg(not(doc))] +use crate::sys::c::{AF_UNIX, SOCKADDR, SOCKADDR_UN}; +use crate::sys::cvt_nz; +use crate::{fmt, io, mem, ptr}; + +#[cfg(not(doc))] +pub fn sockaddr_un(path: &Path) -> io::Result<(SOCKADDR_UN, usize)> { + // SAFETY: All zeros is a valid representation for `sockaddr_un`. + let mut addr: SOCKADDR_UN = unsafe { mem::zeroed() }; + addr.sun_family = AF_UNIX; + + // path to UTF-8 bytes + let bytes = path + .to_str() + .ok_or(io::const_error!(io::ErrorKind::InvalidInput, "path must be valid UTF-8"))? + .as_bytes(); + if bytes.len() >= addr.sun_path.len() { + return Err(io::const_error!(io::ErrorKind::InvalidInput, "path too long")); + } + // SAFETY: `bytes` and `addr.sun_path` are not overlapping and + // both point to valid memory. + // NOTE: We zeroed the memory above, so the path is already null + // terminated. + unsafe { + ptr::copy_nonoverlapping(bytes.as_ptr(), addr.sun_path.as_mut_ptr().cast(), bytes.len()) + }; + + let len = SUN_PATH_OFFSET + bytes.len() + 1; + Ok((addr, len)) +} +#[cfg(not(doc))] +const SUN_PATH_OFFSET: usize = mem::offset_of!(SOCKADDR_UN, sun_path); +pub struct SocketAddr { + #[cfg(not(doc))] + pub(super) addr: SOCKADDR_UN, + pub(super) len: u32, // Use u32 here as same as libc::socklen_t +} +impl fmt::Debug for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.address() { + AddressKind::Unnamed => write!(fmt, "(unnamed)"), + AddressKind::Abstract(name) => write!(fmt, "{name:?} (abstract)"), + AddressKind::Pathname(path) => write!(fmt, "{path:?} (pathname)"), + } + } +} + +impl SocketAddr { + #[cfg(not(doc))] + pub(super) fn new(f: F) -> io::Result + where + F: FnOnce(*mut SOCKADDR, *mut i32) -> i32, + { + unsafe { + let mut addr: SOCKADDR_UN = mem::zeroed(); + let mut len = mem::size_of::() as i32; + cvt_nz(f(&raw mut addr as *mut _, &mut len))?; + SocketAddr::from_parts(addr, len) + } + } + #[cfg(not(doc))] + pub(super) fn from_parts(addr: SOCKADDR_UN, len: i32) -> io::Result { + if addr.sun_family != AF_UNIX { + Err(io::const_error!(io::ErrorKind::InvalidInput, "invalid address family")) + } else if len < SUN_PATH_OFFSET as _ || len > mem::size_of::() as _ { + Err(io::const_error!(io::ErrorKind::InvalidInput, "invalid address length")) + } else { + Ok(SocketAddr { addr, len: len as _ }) + } + } + + /// Returns the contents of this address if it is a `pathname` address. + /// + /// # Examples + /// + /// With a pathname: + /// + /// ```no_run + /// use std::os::windows::net::UnixListener; + /// use std::path::Path; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixListener::bind("/tmp/sock")?; + /// let addr = socket.local_addr().expect("Couldn't get local address"); + /// assert_eq!(addr.as_pathname(), Some(Path::new("/tmp/sock"))); + /// Ok(()) + /// } + /// ``` + pub fn as_pathname(&self) -> Option<&Path> { + if let AddressKind::Pathname(path) = self.address() { Some(path) } else { None } + } + + /// Constructs a `SockAddr` with the family `AF_UNIX` and the provided path. + /// + /// # Errors + /// + /// Returns an error if the path is longer than `SUN_LEN` or if it contains + /// NULL bytes. + /// + /// # Examples + /// + /// ``` + /// use std::os::windows::net::SocketAddr; + /// use std::path::Path; + /// + /// # fn main() -> std::io::Result<()> { + /// let address = SocketAddr::from_pathname("/path/to/socket")?; + /// assert_eq!(address.as_pathname(), Some(Path::new("/path/to/socket"))); + /// # Ok(()) + /// # } + /// ``` + /// + /// Creating a `SocketAddr` with a NULL byte results in an error. + /// + /// ``` + /// use std::os::windows::net::SocketAddr; + /// + /// assert!(SocketAddr::from_pathname("/path/with/\0/bytes").is_err()); + /// ``` + pub fn from_pathname

(path: P) -> io::Result + where + P: AsRef, + { + sockaddr_un(path.as_ref()).map(|(addr, len)| SocketAddr { addr, len: len as _ }) + } + fn address(&self) -> AddressKind<'_> { + let len = self.len as usize - SUN_PATH_OFFSET; + let path = unsafe { mem::transmute::<&[i8], &[u8]>(&self.addr.sun_path) }; + + if len == 0 { + AddressKind::Unnamed + } else if self.addr.sun_path[0] == 0 { + AddressKind::Abstract(ByteStr::from_bytes(&path[1..len])) + } else { + AddressKind::Pathname(unsafe { + OsStr::from_encoded_bytes_unchecked(&path[..len - 1]).as_ref() + }) + } + } + + /// Returns `true` if the address is unnamed. + /// + /// # Examples + /// + /// A named address: + /// + /// ```no_run + /// use std::os::windows::net::UnixListener; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixListener::bind("/tmp/sock")?; + /// let addr = socket.local_addr().expect("Couldn't get local address"); + /// assert_eq!(addr.is_unnamed(), false); + /// Ok(()) + /// } + /// ``` + pub fn is_unnamed(&self) -> bool { + matches!(self.address(), AddressKind::Unnamed) + } +} +enum AddressKind<'a> { + Unnamed, + Pathname(&'a Path), + Abstract(&'a ByteStr), +} diff --git a/library/std/src/os/windows/net/listener.rs b/library/std/src/os/windows/net/listener.rs new file mode 100644 index 0000000000000..733d8dcc0481f --- /dev/null +++ b/library/std/src/os/windows/net/listener.rs @@ -0,0 +1,331 @@ +use crate::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; +use crate::os::windows::net::{SocketAddr, UnixStream}; +use crate::path::Path; +#[cfg(not(doc))] +use crate::sys::c::{AF_UNIX, SOCK_STREAM, SOCKADDR_UN, bind, getsockname, listen}; +use crate::sys::net::Socket; +#[cfg(not(doc))] +use crate::sys::winsock::startup; +use crate::sys::{AsInner, cvt_nz}; +use crate::{fmt, io}; + +/// A structure representing a Unix domain socket server. +/// +/// # Examples +/// +/// ```no_run +/// use std::thread; +/// use std::os::windows::net::{UnixStream, UnixListener}; +/// +/// fn handle_client(stream: UnixStream) { +/// // ... +/// } +/// +/// fn main() -> std::io::Result<()> { +/// let listener = UnixListener::bind("/path/to/the/socket")?; +/// +/// // accept connections and process them, spawning a new thread for each one +/// for stream in listener.incoming() { +/// match stream { +/// Ok(stream) => { +/// /* connection succeeded */ +/// thread::spawn(|| handle_client(stream)); +/// } +/// Err(err) => { +/// /* connection failed */ +/// break; +/// } +/// } +/// } +/// Ok(()) +/// } +/// ``` +pub struct UnixListener(Socket); + +impl fmt::Debug for UnixListener { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixListener"); + builder.field("sock", self.0.as_inner()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + builder.finish() + } +} +impl UnixListener { + /// Creates a new `UnixListener` bound to the specified socket. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixListener; + /// + /// let listener = match UnixListener::bind("/path/to/the/socket") { + /// Ok(sock) => sock, + /// Err(e) => { + /// println!("Couldn't connect: {e:?}"); + /// return + /// } + /// }; + /// ``` + pub fn bind>(path: P) -> io::Result { + let socket_addr = SocketAddr::from_pathname(path)?; + Self::bind_addr(&socket_addr) + } + + /// Creates a new `UnixListener` bound to the specified [`socket address`]. + /// + /// [`socket address`]: crate::os::windows::net::SocketAddr + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::{UnixListener}; + /// + /// fn main() -> std::io::Result<()> { + /// let listener1 = UnixListener::bind("path/to/socket")?; + /// let addr = listener1.local_addr()?; + /// + /// let listener2 = match UnixListener::bind_addr(&addr) { + /// Ok(sock) => sock, + /// Err(err) => { + /// println!("Couldn't bind: {err:?}"); + /// return Err(err); + /// } + /// }; + /// Ok(()) + /// } + /// ``` + pub fn bind_addr(socket_addr: &SocketAddr) -> io::Result { + startup(); + let inner = Socket::new(AF_UNIX as _, SOCK_STREAM)?; + unsafe { + cvt_nz(bind(inner.as_raw(), &raw const socket_addr.addr as _, socket_addr.len as _))?; + cvt_nz(listen(inner.as_raw(), 128))?; + } + Ok(UnixListener(inner)) + } + + /// Accepts a new incoming connection to this listener. + /// + /// This function will block the calling thread until a new Unix connection + /// is established. When established, the corresponding [`UnixStream`] and + /// the remote peer's address will be returned. + /// + /// [`UnixStream`]: crate::os::windows::net::UnixStream + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixListener; + /// + /// fn main() -> std::io::Result<()> { + /// let listener = UnixListener::bind("/path/to/the/socket")?; + /// + /// match listener.accept() { + /// Ok((socket, addr)) => println!("Got a client: {addr:?}"), + /// Err(e) => println!("accept function failed: {e:?}"), + /// } + /// Ok(()) + /// } + /// ``` + pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> { + let mut storage = SOCKADDR_UN::default(); + let mut len = size_of::() as _; + let inner = self.0.accept(&raw mut storage as *mut _, &raw mut len)?; + let addr = SocketAddr::from_parts(storage, len)?; + Ok((UnixStream(inner), addr)) + } + + /// Returns the local socket address of this listener. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixListener; + /// + /// fn main() -> std::io::Result<()> { + /// let listener = UnixListener::bind("/path/to/the/socket")?; + /// let addr = listener.local_addr().expect("Couldn't get local address"); + /// Ok(()) + /// } + /// ``` + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { getsockname(self.0.as_raw(), addr, len) }) + } + + /// Creates a new independently owned handle to the underlying socket. + /// + /// The returned `UnixListener` is a reference to the same socket that this + /// object references. Both handles can be used to accept incoming + /// connections and options set on one listener will affect the other. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixListener; + /// + /// fn main() -> std::io::Result<()> { + /// let listener = UnixListener::bind("/path/to/the/socket")?; + /// let listener_copy = listener.try_clone().expect("try_clone failed"); + /// Ok(()) + /// } + /// ``` + pub fn try_clone(&self) -> io::Result { + self.0.duplicate().map(UnixListener) + } + + /// Moves the socket into or out of nonblocking mode. + /// + /// This will result in the `accept` operation becoming nonblocking, + /// i.e., immediately returning from their calls. If the IO operation is + /// successful, `Ok` is returned and no further action is required. If the + /// IO operation could not be completed and needs to be retried, an error + /// with kind [`io::ErrorKind::WouldBlock`] is returned. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixListener; + /// + /// fn main() -> std::io::Result<()> { + /// let listener = UnixListener::bind("/path/to/the/socket")?; + /// listener.set_nonblocking(true).expect("Couldn't set non blocking"); + /// Ok(()) + /// } + /// ``` + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + + /// Returns the value of the `SO_ERROR` option. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixListener; + /// + /// fn main() -> std::io::Result<()> { + /// let listener = UnixListener::bind("/tmp/sock")?; + /// + /// if let Ok(Some(err)) = listener.take_error() { + /// println!("Got error: {err:?}"); + /// } + /// Ok(()) + /// } + /// ``` + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Returns an iterator over incoming connections. + /// + /// The iterator will never return [`None`] and will also not yield the + /// peer's [`SocketAddr`] structure. + /// + /// # Examples + /// + /// ```no_run + /// use std::thread; + /// use std::os::windows::net::{UnixStream, UnixListener}; + /// + /// fn handle_client(stream: UnixStream) { + /// // ... + /// } + /// + /// fn main() -> std::io::Result<()> { + /// let listener = UnixListener::bind("/path/to/the/socket")?; + /// + /// for stream in listener.incoming() { + /// match stream { + /// Ok(stream) => { + /// thread::spawn(|| handle_client(stream)); + /// } + /// Err(err) => { + /// break; + /// } + /// } + /// } + /// Ok(()) + /// } + /// ``` + pub fn incoming(&self) -> Incoming<'_> { + Incoming { listener: self } + } +} + +/// An iterator over incoming connections to a [`UnixListener`]. +/// +/// It will never return [`None`]. +/// +/// # Examples +/// +/// ```no_run +/// use std::thread; +/// use std::os::windows::net::{UnixStream, UnixListener}; +/// +/// fn handle_client(stream: UnixStream) { +/// // ... +/// } +/// +/// fn main() -> std::io::Result<()> { +/// let listener = UnixListener::bind("/path/to/the/socket")?; +/// +/// for stream in listener.incoming() { +/// match stream { +/// Ok(stream) => { +/// thread::spawn(|| handle_client(stream)); +/// } +/// Err(err) => { +/// break; +/// } +/// } +/// } +/// Ok(()) +/// } +/// ``` +pub struct Incoming<'a> { + listener: &'a UnixListener, +} + +impl<'a> Iterator for Incoming<'a> { + type Item = io::Result; + + fn next(&mut self) -> Option> { + Some(self.listener.accept().map(|s| s.0)) + } + + fn size_hint(&self) -> (usize, Option) { + (usize::MAX, None) + } +} + +impl AsRawSocket for UnixListener { + #[inline] + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixListener { + #[inline] + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + UnixListener(unsafe { Socket::from_raw_socket(sock) }) + } +} + +impl IntoRawSocket for UnixListener { + #[inline] + fn into_raw_socket(self) -> RawSocket { + self.0.into_raw_socket() + } +} + +impl<'a> IntoIterator for &'a UnixListener { + type Item = io::Result; + type IntoIter = Incoming<'a>; + + fn into_iter(self) -> Incoming<'a> { + self.incoming() + } +} diff --git a/library/std/src/os/windows/net/mod.rs b/library/std/src/os/windows/net/mod.rs new file mode 100644 index 0000000000000..ac6420cd9949a --- /dev/null +++ b/library/std/src/os/windows/net/mod.rs @@ -0,0 +1,7 @@ +#![unstable(feature = "windows_unix_domain_sockets", issue = "150487")] +mod addr; +mod listener; +mod stream; +pub use addr::*; +pub use listener::*; +pub use stream::*; diff --git a/library/std/src/os/windows/net/stream.rs b/library/std/src/os/windows/net/stream.rs new file mode 100644 index 0000000000000..ddc939e66dd18 --- /dev/null +++ b/library/std/src/os/windows/net/stream.rs @@ -0,0 +1,405 @@ +use crate::net::Shutdown; +use crate::os::windows::io::{ + AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, RawSocket, +}; +use crate::os::windows::net::SocketAddr; +use crate::path::Path; +#[cfg(not(doc))] +use crate::sys::c::{ + AF_UNIX, SO_RCVTIMEO, SO_SNDTIMEO, SOCK_STREAM, connect, getpeername, getsockname, +}; +use crate::sys::net::Socket; +#[cfg(not(doc))] +use crate::sys::winsock::startup; +use crate::sys::{AsInner, cvt_nz}; +use crate::time::Duration; +use crate::{fmt, io}; +/// A Unix stream socket. +/// +/// # Examples +/// +/// ```no_run +/// use std::os::windows::net::UnixStream; +/// use std::io::prelude::*; +/// +/// fn main() -> std::io::Result<()> { +/// let mut stream = UnixStream::connect("/path/to/my/socket")?; +/// stream.write_all(b"hello world")?; +/// let mut response = String::new(); +/// stream.read_to_string(&mut response)?; +/// println!("{response}"); +/// Ok(()) +/// } +/// ``` +pub struct UnixStream(pub(super) Socket); +impl fmt::Debug for UnixStream { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut builder = fmt.debug_struct("UnixStream"); + builder.field("sock", self.0.as_inner()); + if let Ok(addr) = self.local_addr() { + builder.field("local", &addr); + } + if let Ok(addr) = self.peer_addr() { + builder.field("peer", &addr); + } + builder.finish() + } +} +impl UnixStream { + /// Connects to the socket named by `path`. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixStream; + /// + /// let socket = match UnixStream::connect("/tmp/sock") { + /// Ok(sock) => sock, + /// Err(e) => { + /// println!("Couldn't connect: {e:?}"); + /// return + /// } + /// }; + /// ``` + pub fn connect>(path: P) -> io::Result { + let socket_addr = SocketAddr::from_pathname(path)?; + Self::connect_addr(&socket_addr) + } + + /// Connects to the socket specified by [`address`]. + /// + /// [`address`]: crate::os::windows::net::SocketAddr + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::{UnixListener, UnixStream}; + /// + /// fn main() -> std::io::Result<()> { + /// let listener = UnixListener::bind("/path/to/the/socket")?; + /// let addr = listener.local_addr()?; + /// + /// let sock = match UnixStream::connect_addr(&addr) { + /// Ok(sock) => sock, + /// Err(e) => { + /// println!("Couldn't connect: {e:?}"); + /// return Err(e) + /// } + /// }; + /// Ok(()) + /// } + /// ```` + pub fn connect_addr(socket_addr: &SocketAddr) -> io::Result { + startup(); + let inner = Socket::new(AF_UNIX as _, SOCK_STREAM)?; + unsafe { + cvt_nz(connect( + inner.as_raw(), + &raw const socket_addr.addr as *const _, + socket_addr.len as _, + ))?; + } + Ok(UnixStream(inner)) + } + + /// Returns the socket address of the local half of this connection. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixStream; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// let addr = socket.local_addr().expect("Couldn't get local address"); + /// Ok(()) + /// } + /// ``` + pub fn local_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { getsockname(self.0.as_raw(), addr, len) }) + } + + /// Returns the socket address of the remote half of this connection. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixStream; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// let addr = socket.peer_addr().expect("Couldn't get peer address"); + /// Ok(()) + /// } + /// ``` + pub fn peer_addr(&self) -> io::Result { + SocketAddr::new(|addr, len| unsafe { getpeername(self.0.as_raw(), addr, len) }) + } + + /// Returns the read timeout of this socket. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixStream; + /// use std::time::Duration; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// socket.set_read_timeout(Some(Duration::new(1, 0))).expect("Couldn't set read timeout"); + /// assert_eq!(socket.read_timeout()?, Some(Duration::new(1, 0))); + /// Ok(()) + /// } + /// ``` + pub fn read_timeout(&self) -> io::Result> { + self.0.timeout(SO_RCVTIMEO) + } + + /// Moves the socket into or out of nonblocking mode. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixStream; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// socket.set_nonblocking(true).expect("Couldn't set nonblocking"); + /// Ok(()) + /// } + /// ``` + pub fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> { + self.0.set_nonblocking(nonblocking) + } + + /// Sets the read timeout for the socket. + /// + /// If the provided value is [`None`], then [`read`] calls will block + /// indefinitely. An [`Err`] is returned if the zero [`Duration`] is passed to this + /// method. + /// + /// [`read`]: io::Read::read + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixStream; + /// use std::time::Duration; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// socket.set_read_timeout(Some(Duration::new(1, 0))).expect("Couldn't set read timeout"); + /// Ok(()) + /// } + /// ``` + /// + /// An [`Err`] is returned if the zero [`Duration`] is passed to this + /// method: + /// + /// ```no_run + /// use std::io; + /// use std::os::windows::net::UnixStream; + /// use std::time::Duration; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// let result = socket.set_read_timeout(Some(Duration::new(0, 0))); + /// let err = result.unwrap_err(); + /// assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + /// Ok(()) + /// } + /// ``` + pub fn set_read_timeout(&self, dur: Option) -> io::Result<()> { + self.0.set_timeout(dur, SO_RCVTIMEO) + } + + /// Sets the write timeout for the socket. + /// + /// If the provided value is [`None`], then [`write`] calls will block + /// indefinitely. An [`Err`] is returned if the zero [`Duration`] is + /// passed to this method. + /// + /// [`read`]: io::Read::read + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixStream; + /// use std::time::Duration; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// socket.set_write_timeout(Some(Duration::new(1, 0))) + /// .expect("Couldn't set write timeout"); + /// Ok(()) + /// } + /// ``` + /// + /// An [`Err`] is returned if the zero [`Duration`] is passed to this + /// method: + /// + /// ```no_run + /// use std::io; + /// use std::os::windows::net::UnixStream; + /// use std::time::Duration; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// let result = socket.set_write_timeout(Some(Duration::new(0, 0))); + /// let err = result.unwrap_err(); + /// assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + /// Ok(()) + /// } + /// ``` + pub fn set_write_timeout(&self, dur: Option) -> io::Result<()> { + self.0.set_timeout(dur, SO_SNDTIMEO) + } + + /// Shuts down the read, write, or both halves of this connection. + /// + /// This function will cause all pending and future I/O calls on the + /// specified portions to immediately return with an appropriate value + /// (see the documentation of [`Shutdown`]). + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixStream; + /// use std::net::Shutdown; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// socket.shutdown(Shutdown::Both).expect("shutdown function failed"); + /// Ok(()) + /// } + /// ``` + pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { + self.0.shutdown(how) + } + + /// Returns the value of the `SO_ERROR` option. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixStream; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// if let Ok(Some(err)) = socket.take_error() { + /// println!("Got error: {err:?}"); + /// } + /// Ok(()) + /// } + /// ``` + pub fn take_error(&self) -> io::Result> { + self.0.take_error() + } + + /// Creates a new independently owned handle to the underlying socket. + /// + /// The returned `UnixStream` is a reference to the same stream that this + /// object references. Both handles will read and write the same stream of + /// data, and options set on one stream will be propagated to the other + /// stream. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixStream; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// let sock_copy = socket.try_clone().expect("Couldn't clone socket"); + /// Ok(()) + /// } + /// ``` + pub fn try_clone(&self) -> io::Result { + self.0.duplicate().map(UnixStream) + } + + /// Returns the write timeout of this socket. + /// + /// # Examples + /// + /// ```no_run + /// use std::os::windows::net::UnixStream; + /// use std::time::Duration; + /// + /// fn main() -> std::io::Result<()> { + /// let socket = UnixStream::connect("/tmp/sock")?; + /// socket.set_write_timeout(Some(Duration::new(1, 0))) + /// .expect("Couldn't set write timeout"); + /// assert_eq!(socket.write_timeout()?, Some(Duration::new(1, 0))); + /// Ok(()) + /// } + /// ``` + pub fn write_timeout(&self) -> io::Result> { + self.0.timeout(SO_SNDTIMEO) + } +} + +impl io::Read for UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + io::Read::read(&mut &*self, buf) + } +} + +impl<'a> io::Read for &'a UnixStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + self.0.read(buf) + } +} + +impl io::Write for UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + io::Write::write(&mut &*self, buf) + } + + fn flush(&mut self) -> io::Result<()> { + io::Write::flush(&mut &*self) + } +} +impl<'a> io::Write for &'a UnixStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.write_vectored(&[io::IoSlice::new(buf)]) + } + #[inline] + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { + self.0.write_vectored(bufs) + } + #[inline] + fn is_write_vectored(&self) -> bool { + self.0.is_write_vectored() + } +} + +impl AsSocket for UnixStream { + #[inline] + fn as_socket(&self) -> BorrowedSocket<'_> { + self.0.as_socket() + } +} + +impl AsRawSocket for UnixStream { + #[inline] + fn as_raw_socket(&self) -> RawSocket { + self.0.as_raw_socket() + } +} + +impl FromRawSocket for UnixStream { + #[inline] + unsafe fn from_raw_socket(sock: RawSocket) -> Self { + unsafe { UnixStream(Socket::from_raw_socket(sock)) } + } +} + +impl IntoRawSocket for UnixStream { + fn into_raw_socket(self) -> RawSocket { + self.0.into_raw_socket() + } +} diff --git a/library/std/src/sys/pal/windows/mod.rs b/library/std/src/sys/pal/windows/mod.rs index 32bd6ea3a4f6c..af71f7be56408 100644 --- a/library/std/src/sys/pal/windows/mod.rs +++ b/library/std/src/sys/pal/windows/mod.rs @@ -311,6 +311,11 @@ pub fn cvt(i: I) -> crate::io::Result { if i.is_zero() { Err(crate::io::Error::last_os_error()) } else { Ok(i) } } +#[allow(dead_code)] +pub fn cvt_nz(i: I) -> crate::io::Result<()> { + if i.is_zero() { Ok(()) } else { Err(crate::io::Error::last_os_error()) } +} + pub fn dur2timeout(dur: Duration) -> u32 { // Note that a duration is a (u64, u32) (seconds, nanoseconds) pair, and the // timeouts in windows APIs are typically u32 milliseconds. To translate, we diff --git a/library/std/tests/windows_unix_socket.rs b/library/std/tests/windows_unix_socket.rs new file mode 100644 index 0000000000000..18f4c52e72c28 --- /dev/null +++ b/library/std/tests/windows_unix_socket.rs @@ -0,0 +1,86 @@ +#![cfg(windows)] +#![feature(windows_unix_domain_sockets)] +// Now only test windows_unix_domain_sockets feature +// in the future, will test both unix and windows uds +use std::io::{Read, Write}; +use std::os::windows::net::{UnixListener, UnixStream}; +use std::thread; + +#[test] +fn win_uds_smoke_bind_connect() { + let tmp = std::env::temp_dir(); + let sock_path = tmp.join("rust-test-uds-smoke.sock"); + let _ = std::fs::remove_file(&sock_path); + let listener = UnixListener::bind(&sock_path).expect("bind failed"); + let sock_path_clone = sock_path.clone(); + let tx = thread::spawn(move || { + let mut stream = UnixStream::connect(&sock_path_clone).expect("connect failed"); + stream.write_all(b"hello").expect("write failed"); + }); + + let (mut stream, _) = listener.accept().expect("accept failed"); + let mut buf = [0; 5]; + stream.read_exact(&mut buf).expect("read failed"); + assert_eq!(&buf, b"hello"); + + tx.join().unwrap(); + + drop(listener); + let _ = std::fs::remove_file(&sock_path); +} + +#[test] +fn win_uds_echo() { + let tmp = std::env::temp_dir(); + let sock_path = tmp.join("rust-test-uds-echo.sock"); + let _ = std::fs::remove_file(&sock_path); + + let listener = UnixListener::bind(&sock_path).expect("bind failed"); + let srv = thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept failed"); + let mut buf = [0u8; 128]; + loop { + let n = match stream.read(&mut buf) { + Ok(0) => break, + Ok(n) => n, + Err(e) => panic!("read error: {}", e), + }; + stream.write_all(&buf[..n]).expect("write_all failed"); + } + }); + + let sock_path_clone = sock_path.clone(); + let cli = thread::spawn(move || { + let mut stream = UnixStream::connect(&sock_path_clone).expect("connect failed"); + let req = b"hello windows uds"; + stream.write_all(req).expect("write failed"); + let mut resp = vec![0u8; req.len()]; + stream.read_exact(&mut resp).expect("read failed"); + assert_eq!(resp, req); + }); + + cli.join().unwrap(); + srv.join().unwrap(); + + let _ = std::fs::remove_file(&sock_path); +} + +#[test] +fn win_uds_path_too_long() { + let tmp = std::env::temp_dir(); + let long_path = tmp.join("a".repeat(200)); + let result = UnixListener::bind(&long_path); + assert!(result.is_err()); + let _ = std::fs::remove_file(&long_path); +} +#[test] +fn win_uds_existing_bind() { + let tmp = std::env::temp_dir(); + let sock_path = tmp.join("rust-test-uds-existing.sock"); + let _ = std::fs::remove_file(&sock_path); + let listener = UnixListener::bind(&sock_path).expect("bind failed"); + let result = UnixListener::bind(&sock_path); + assert!(result.is_err()); + drop(listener); + let _ = std::fs::remove_file(&sock_path); +}