diff --git a/Cargo.toml b/Cargo.toml index 87be173..9ab6acd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,16 +8,21 @@ repository = "https://github.com/sfackler/rust-unix-socket" documentation = "https://sfackler.github.io/rust-unix-socket/doc/v0.4.5/unix_socket" readme = "README.md" keywords = ["posix", "unix", "socket", "domain"] +build = "build.rs" [dependencies] libc = "0.1" debug-builders = "0.1" +[build-dependencies] +gcc = "0.3" + [dev-dependencies] tempdir = "0.3" [features] -default = ["from_raw_fd"] +default = ["from_raw_fd", "sendmsg"] from_raw_fd = [] socket_timeout = [] +sendmsg = [] diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..0605860 --- /dev/null +++ b/build.rs @@ -0,0 +1,5 @@ +extern crate gcc; + +fn main() { + gcc::compile_library("libcmsg_manip.a", &["src/cmsg_manip/cmsg.c"]); +} diff --git a/examples/socket_send.rs b/examples/socket_send.rs new file mode 100644 index 0000000..5b3478a --- /dev/null +++ b/examples/socket_send.rs @@ -0,0 +1,84 @@ +extern crate libc; +extern crate unix_socket; + +use std::os::unix::io::{AsRawFd, FromRawFd}; +use std::path::Path; + +#[cfg(feature = "sendmsg")] +use unix_socket::{ControlMsg, UCred, UnixDatagram, RecvMsgFlags, SendMsgFlags}; + +#[cfg(feature = "sendmsg")] +fn handle_parent(sock: UnixDatagram) { + let (parent2, child2) = UnixDatagram::pair().unwrap(); + + let cmsg = ControlMsg::Rights(vec![child2.as_raw_fd()]); + let cmsg2 = unsafe { ControlMsg::Credentials(UCred{ + pid: libc::getpid(), + uid: libc::getuid(), + gid: libc::getgid(), + }) }; + println!("cmsg {:?}", cmsg2); + let sent_bytes = sock.sendmsg::<&Path>(None, &[&[]], &[cmsg, cmsg2], SendMsgFlags::new()).unwrap(); + assert_eq!(sent_bytes, 0); + drop(child2); + println!("Parent sent child SCM_RIGHTS fd"); + + let mut buf = &mut [0u8; 4096]; + let read = parent2.recv(buf).unwrap(); + assert_eq!(&buf[..read], "Hello, world!".as_bytes()); + println!("Parent received message from child via SCM_RIGHTS fd"); +} + +#[cfg(feature = "sendmsg")] +fn handle_child(sock: UnixDatagram) { + sock.set_passcred(true).unwrap(); + let flags = RecvMsgFlags::new().cmsg_cloexec(true); + let result = sock.recvmsg(&[&mut[]], flags).unwrap(); + assert_eq!(result.control_msgs.len(), 2); + + let mut new_sock = None; + let mut creds = None; + for cmsg in result.control_msgs { + match cmsg.clone() { + ControlMsg::Rights(fds) => { + assert!(new_sock.is_none()); + assert_eq!(fds.len(), 1); + unsafe { + new_sock = Some(UnixDatagram::from_raw_fd(fds[0])); + } + println!("Child received SCM_RIGHTS fd"); + }, + ControlMsg::Credentials(ucred) => { + assert!(creds.is_none()); + creds = Some(ucred); + println!("Child received SCM_CREDENTIALS"); + }, + _ => unreachable!(), + } + } + + let creds = creds.unwrap(); + unsafe { + assert_eq!(creds.uid, libc::getuid()); + assert_eq!(creds.gid, libc::getgid()); + assert!(creds.pid != 0); + } + let sent = new_sock.unwrap().send("Hello, world!".as_bytes()).unwrap(); + println!("Child sent message to parent via SCM_RIGHTS fd"); + assert_eq!(sent, 13); +} + +#[cfg(feature = "sendmsg")] +fn main() { + let (parent_sock, child_sock) = UnixDatagram::pair().unwrap(); + let pid = unsafe { libc::fork() }; + if pid == 0 { + handle_child(child_sock); + } else { + handle_parent(parent_sock); + } +} + +#[cfg(not(feature = "sendmsg"))] +fn main() { +} diff --git a/src/cmsg_manip/cmsg.c b/src/cmsg_manip/cmsg.c new file mode 100644 index 0000000..3a0221c --- /dev/null +++ b/src/cmsg_manip/cmsg.c @@ -0,0 +1,46 @@ +// Need to use GNU_SOURCE for ucred struct +#define _GNU_SOURCE +#include + +size_t cmsghdr_size = sizeof(struct cmsghdr); +size_t iovec_size = sizeof(struct iovec); +size_t msghdr_size = sizeof(struct msghdr); +size_t ucred_size = sizeof(struct ucred); + +int scm_credentials = SCM_CREDENTIALS; +int scm_rights = SCM_RIGHTS; +int so_passcred = SO_PASSCRED; + +int msg_eor = MSG_EOR; +int msg_trunc = MSG_TRUNC; +int msg_ctrunc = MSG_CTRUNC; +int msg_errqueue = MSG_ERRQUEUE; +int msg_dontwait = MSG_DONTWAIT; +int msg_cmsg_cloexec = MSG_CMSG_CLOEXEC; +int msg_nosignal = MSG_NOSIGNAL; +int msg_peek = MSG_PEEK; +int msg_waitall = MSG_WAITALL; + +struct cmsghdr * cmsg_firsthdr(struct msghdr *msgh) { + return CMSG_FIRSTHDR(msgh); +} + +struct cmsghdr * cmsg_nxthdr(struct msghdr *msgh, struct cmsghdr *cmsg) { + return CMSG_NXTHDR(msgh, cmsg); +} + +size_t cmsg_align(size_t length) { + return CMSG_ALIGN(length); +} + +size_t cmsg_space(size_t length) { + return CMSG_SPACE(length); +} + +size_t cmsg_len(size_t length) { + return CMSG_LEN(length); +} + +unsigned char * cmsg_data(struct cmsghdr *cmsg) { + return CMSG_DATA(cmsg); +} diff --git a/src/lib.rs b/src/lib.rs index 89c4ad8..97b68fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,11 @@ use std::fmt; use std::path::Path; use std::mem::size_of; +#[cfg(feature = "sendmsg")] +mod sendmsg_impl; +#[cfg(feature = "sendmsg")] +pub use sendmsg_impl::{ControlMsg, UCred, SendMsgFlags, RecvMsgFlags, RecvMsgResultFlags}; + extern "C" { fn socketpair(domain: libc::c_int, ty: libc::c_int, @@ -168,6 +173,19 @@ impl Inner { .map(|_| ()) } } + + #[cfg(feature = "sendmsg")] + fn set_passcred(&self, receive_creds: bool) -> io::Result<()> { + unsafe { + let v: libc::c_int = receive_creds as libc::c_int; + cvt(libc::setsockopt(self.0, + libc::SOL_SOCKET, + sendmsg_impl::SO_PASSCRED, + &v as *const libc::c_int as *const libc::c_void, + mem::size_of::() as libc::socklen_t)) + .map(|_| ()) + } + } } unsafe fn sockaddr_un>(path: P) @@ -641,6 +659,19 @@ impl<'a> Iterator for Incoming<'a> { } } +/// The return value from a call to recvmsg +#[cfg(feature = "sendmsg")] +pub struct RecvMsgResult { + /// Number of bytes received + pub data_bytes: usize, + /// Address of the sender + pub sender: SocketAddr, + /// List of all control messages received during this call + pub control_msgs: Vec, + /// Flags returned by recvmsg, see the struct definition for more details + pub flags: RecvMsgResultFlags, +} + /// A Unix datagram socket. /// /// # Examples @@ -771,6 +802,43 @@ impl UnixDatagram { } } + /// Receives data on the socket. + /// + /// This interface allows receiving data into multiple buffers. This acts as if the buffers had been + /// concatenated in the order they were given. + /// + /// On success, returns the number of bytes written. + #[cfg(feature = "sendmsg")] + pub fn recvmsg(&self, buffers: &[&mut[u8]], flags: RecvMsgFlags) -> io::Result { + let mut result = Err(io::Error::new(io::ErrorKind::Other, "programming error")); + let addr = try!(SocketAddr::new(|addr, len| { + const CMSG_BUFFER_SIZE: usize = 4096; + let mut cmsg_buffer = [0u8; CMSG_BUFFER_SIZE]; + unsafe { + result = sendmsg_impl::recvmsg( + self.inner.0, + buffers, + &mut cmsg_buffer, + flags, + addr, + len); + } + if let Err(ref e) = result { + -(e.raw_os_error().unwrap() as libc::c_int) + } else { + 0 + } + })); + + let result = try!(result); + Ok(RecvMsgResult { + data_bytes: result.data_bytes, + sender: addr, + control_msgs: result.control_msgs, + flags: result.flags, + }) + } + /// Sends data on the socket to the specified address. /// /// On success, returns the number of bytes written. @@ -804,6 +872,37 @@ impl UnixDatagram { } } + /// Sends data on the socket to the specified address. + /// + /// If path is None, the peer address set by the `connect` method will be used. If it has not + /// been set, then this method will return an error. + /// + /// This interface allows sending data from multiple buffers. This acts as if the buffers had been + /// concatenated in the order they were given. + /// + /// ctrl_msgs are special ancillary data that can be sent, such as file descriptors and Unix credentials + /// + /// On success, returns the number of bytes written. + #[cfg(feature = "sendmsg")] + pub fn sendmsg>(&self, path: Option

, buffers: &[&[u8]], ctrl_msgs: &[ControlMsg], flags: SendMsgFlags) -> io::Result { + unsafe { + let dst = match path { + None => None, + Some(p) => { + let v = try!(sockaddr_un(p)); + Some(v) + }, + }; + + sendmsg_impl::sendmsg( + self.inner.0, + dst, + buffers, + ctrl_msgs, + flags) + } + } + /// Sets the read timeout for the socket. /// /// If the provided value is `None`, then `recv` and `recv_from` calls will @@ -852,6 +951,12 @@ impl UnixDatagram { pub fn shutdown(&self, how: Shutdown) -> io::Result<()> { self.inner.shutdown(how) } + + /// Enable or disable receiving SCM_CREDENTIALS messages + #[cfg(feature = "sendmsg")] + pub fn set_passcred(&self, receive_creds: bool) -> io::Result<()> { + self.inner.set_passcred(receive_creds) + } } impl AsRawFd for UnixDatagram { @@ -872,11 +977,14 @@ impl std::os::unix::io::FromRawFd for UnixDatagram { #[cfg(test)] mod test { + extern crate libc; extern crate tempdir; use std::thread; use std::io; use std::io::prelude::*; + use std::os::unix::io::{AsRawFd, FromRawFd}; + use std::path::Path; use self::tempdir::TempDir; use {UnixListener, UnixStream, UnixDatagram, AddressKind}; @@ -1215,4 +1323,165 @@ mod test { thread.join().unwrap(); } + + /// Sends "hello" on the data channel and the specified cmsgs on the control channel + #[cfg(feature = "sendmsg")] + fn sendmsg_helper>(s: &UnixDatagram, dst: Option

, cmsgs: &[super::ControlMsg]) { + use SendMsgFlags; + let msg = b"he"; + let msg2 = b"llo"; + let sent_bytes = or_panic!(s.sendmsg(dst, &[&msg[..], &msg2[..]], cmsgs, SendMsgFlags::new())); + assert_eq!(sent_bytes, 5); + } + + /// Expects to receive "hello" on the data channel, and uses the given buf for cmsgs + #[cfg(feature = "sendmsg")] + fn recvmsg_helper(s: &UnixDatagram) -> super::RecvMsgResult { + use RecvMsgFlags; + let mut buf = [0; 3]; + let mut buf2 = [0; 3]; + let result = or_panic!(s.recvmsg(&[&mut buf[..], &mut buf2[..]], RecvMsgFlags::new())); + assert_eq!(result.data_bytes, 5); + assert_eq!(&buf[..], b"hel"); + assert_eq!(&buf2[..2], b"lo"); + result + } + + #[cfg(feature = "sendmsg")] + #[test] + fn test_sendmsg_to() { + let dir = or_panic!(TempDir::new("unix_socket")); + let path1 = dir.path().join("sock1"); + + let sock1 = or_panic!(UnixDatagram::bind(&path1)); + let sock2 = or_panic!(UnixDatagram::unbound()); + + // Make sure the path-specified form of sendmsg works + sendmsg_helper(&sock2, Some(&path1), &[]); + let mut buf = [0; 6]; + let size = or_panic!(sock1.recv(&mut buf)); + assert_eq!(size, 5); + assert_eq!(&buf[..5], b"hello"); + } + + #[cfg(feature = "sendmsg")] + #[test] + fn test_recvmsg_sender() { + let dir = or_panic!(TempDir::new("unix_socket")); + let path1 = dir.path().join("sock1"); + let path2 = dir.path().join("sock2"); + + let sock1 = or_panic!(UnixDatagram::bind(&path1)); + let sock2 = or_panic!(UnixDatagram::bind(&path2)); + + assert_eq!(or_panic!(sock1.send_to(b"hello", &path2)), 5); + let result = recvmsg_helper(&sock2); + match result.sender.address() { + AddressKind::Pathname(p) => assert_eq!(p, path1.as_path()), + _ => unreachable!(), + } + } + + #[cfg(feature = "sendmsg")] + #[test] + fn test_send_credentials_without_passcred() { + use {ControlMsg, UCred}; + + // Without passcred, the ucred should be dropped + let (s1, s2) = or_panic!(UnixDatagram::pair()); + let thread = thread::spawn(move || { + let result = recvmsg_helper(&s1); + assert_eq!(result.control_msgs.len(), 0); + assert!(!result.flags.control_truncated()); + }); + + let cmsg = unsafe { ControlMsg::Credentials(UCred{ + pid: libc::getpid(), + uid: libc::getuid(), + gid: libc::getgid(), + }) }; + sendmsg_helper::<&Path>(&s2, None, &[cmsg]); + drop(s2); + + thread.join().unwrap(); + } + + #[cfg(feature = "sendmsg")] + #[test] + fn test_send_credentials_with_passcred() { + // With passcred, the ucred should be sent. + // Note: SO_PASSCRED will cause a credential to always be sent. Unfortunately, + // without additional capabilities, we cannot properly test the SCM_CREDENTIALS + // message. We pass one through to sendmsg below just to exercise that codepath, + // but it will not demonstrate that we are correctly sending it (other than not + // triggering an EINVAL). + + use {ControlMsg, UCred}; + + let (s1, s2) = or_panic!(UnixDatagram::pair()); + let thread = thread::spawn(move || { + or_panic!(s1.set_passcred(true)); + let result = recvmsg_helper(&s1); + assert_eq!(result.control_msgs.len(), 1); + assert!(!result.flags.control_truncated()); + + for cmsg in result.control_msgs { + match cmsg { + ControlMsg::Credentials(ucred) => { + unsafe { + assert_eq!(ucred.pid, libc::getpid()); + assert_eq!(ucred.uid, libc::getuid()); + assert_eq!(ucred.gid, libc::getgid()); + } + }, + _ => panic!("Unexpected control message"), + } + } + }); + + let cmsg = unsafe { ControlMsg::Credentials(UCred{ + pid: libc::getpid(), + uid: libc::getuid(), + gid: libc::getgid(), + }) }; + sendmsg_helper::<&Path>(&s2, None, &[cmsg]); + drop(s2); + + thread.join().unwrap(); + } + + #[cfg(all(feature = "from_raw_fd", feature = "sendmsg"))] + #[test] + fn test_send_fds() { + use ControlMsg; + let (s1, s2) = or_panic!(UnixDatagram::pair()); + let thread = thread::spawn(move || { + let result = recvmsg_helper(&s1); + assert_eq!(result.control_msgs.len(), 1); + assert!(!result.flags.control_truncated()); + + for cmsg in result.control_msgs { + match cmsg { + ControlMsg::Rights(fds) => { + assert_eq!(fds.len(), 1); + let new_s = unsafe { UnixDatagram::from_raw_fd(fds[0]) }; + let mut buf = [0; 4]; + assert_eq!(or_panic!(new_s.recv(&mut buf[..])), 4); + assert_eq!(&buf[..], b"Test"); + }, + _ => panic!("Unexpected control message"), + } + } + }); + + let (my, theirs) = or_panic!(UnixDatagram::pair()); + let cmsg = ControlMsg::Rights(vec![theirs.as_raw_fd()]); + sendmsg_helper::<&Path>(&s2, None, &[cmsg]); + drop(s2); + drop(theirs); + + assert_eq!(or_panic!(my.send(b"Test")), 4); + + thread.join().unwrap(); + } } diff --git a/src/sendmsg_impl.rs b/src/sendmsg_impl.rs new file mode 100644 index 0000000..f7bb188 --- /dev/null +++ b/src/sendmsg_impl.rs @@ -0,0 +1,430 @@ +use std::io; +use std::ptr; +use std::mem; +use std::os::unix::io::RawFd; +use std::slice; + +use libc; + +mod raw { + use libc; + extern "system" { + pub fn sendmsg(socket: libc::c_int, msg: *const libc::c_void, flags: libc::c_int) -> libc::ssize_t; + pub fn recvmsg(socket: libc::c_int, msg: *mut libc::c_void, flags: libc::c_int) -> libc::ssize_t; + } + + #[allow(dead_code)] + extern { + pub static cmsghdr_size: libc::size_t; + pub static iovec_size: libc::size_t; + pub static msghdr_size: libc::size_t; + pub static ucred_size: libc::size_t; + + pub static scm_credentials: libc::c_int; + pub static scm_rights: libc::c_int; + + pub static so_passcred: libc::c_int; + + pub static msg_eor: libc::c_int; + pub static msg_trunc: libc::c_int; + pub static msg_ctrunc: libc::c_int; + pub static msg_errqueue: libc::c_int; + pub static msg_dontwait: libc::c_int; + pub static msg_cmsg_cloexec: libc::c_int; + pub static msg_nosignal: libc::c_int; + pub static msg_peek: libc::c_int; + pub static msg_waitall: libc::c_int; + + pub fn cmsg_firsthdr(msgh: *const libc::c_void) -> *const libc::c_void; + pub fn cmsg_nxthdr(msgh: *const libc::c_void, cmsg: *const libc::c_void) -> *const libc::c_void; + pub fn cmsg_align(len: libc::size_t) -> libc::size_t; + pub fn cmsg_space(len: libc::size_t) -> libc::size_t; + pub fn cmsg_len(len: libc::size_t) -> libc::size_t; + pub fn cmsg_data(cmsg: *const libc::c_void) -> *const libc::c_void; + } +} + +pub use self::raw::so_passcred as SO_PASSCRED; + +pub use self::raw::scm_credentials as SCM_CREDENTIALS; +pub use self::raw::scm_rights as SCM_RIGHTS; + +use self::raw::msg_eor as MSG_EOR; +use self::raw::msg_trunc as MSG_TRUNC; +use self::raw::msg_ctrunc as MSG_CTRUNC; +use self::raw::msg_dontwait as MSG_DONTWAIT; +use self::raw::msg_cmsg_cloexec as MSG_CMSG_CLOEXEC; +use self::raw::msg_nosignal as MSG_NOSIGNAL; +use self::raw::msg_peek as MSG_PEEK; +use self::raw::msg_waitall as MSG_WAITALL; + +pub unsafe fn sendmsg( + socket: libc::c_int, + dst: Option<(libc::sockaddr_un, libc::socklen_t)>, + buffers: &[&[u8]], + ctrl_msgs: &[ControlMsg], + flags: SendMsgFlags) -> io::Result { + + let mut msg: MsgHdr = mem::zeroed(); + + // Initialize destination field + if let Some((addr, len)) = dst { + msg.msg_name = (&addr as *const libc::sockaddr_un) as *const libc::c_void; + msg.msg_namelen = len; + } + + // Initialize scatter/gather vector + let mut iovecs = Vec::with_capacity(buffers.len()); + for buf in buffers { + iovecs.push(IoVec::new(buf)); + } + msg.msg_iov = iovecs.as_mut_ptr() as *mut libc::c_void; + msg.msg_iovlen = iovecs.len() as libc::size_t; + + // Initialize control message struct + + let mut total_space: usize = 0; + for ctrl_msg in ctrl_msgs.iter().cloned() { + let size = match ctrl_msg { + ControlMsg::Rights(fds) => (mem::size_of::() * fds.len()) as libc::size_t, + ControlMsg::Credentials(..) => mem::size_of::() as libc::size_t, + _ => unimplemented!(), + }; + total_space += raw::cmsg_space(size) as usize; + } + + let mut ctrl_buf = &mut Vec::::with_capacity(total_space)[..]; + msg.msg_control = ctrl_buf.as_mut_ptr() as *mut libc::c_void; + msg.msg_controllen = total_space as libc::size_t; + + let msg_addr = (&msg as *const MsgHdr) as *const libc::c_void; + let mut cur_cmsg = raw::cmsg_firsthdr(msg_addr); + for ctrl_msg in ctrl_msgs.iter().cloned() { + if cur_cmsg == ptr::null() { + panic!("programming error: buffer too small"); + } + + let cmsg = cur_cmsg as *mut CmsgHdr; + match ctrl_msg { + // NOTE: Add handlers for new messages here + ControlMsg::Rights(fds) => { + (*cmsg).cmsg_len = raw::cmsg_len((mem::size_of::() * fds.len()) as libc::size_t) as libc::size_t; + (*cmsg).cmsg_level = libc::SOL_SOCKET; + (*cmsg).cmsg_type = SCM_RIGHTS; + let data = raw::cmsg_data(cur_cmsg) as *mut libc::c_int; + ptr::copy_nonoverlapping(fds.as_ptr(), data, fds.len()); + }, + ControlMsg::Credentials(ucred) => { + (*cmsg).cmsg_len = raw::cmsg_len(mem::size_of::() as libc::size_t) as libc::size_t; + (*cmsg).cmsg_level = libc::SOL_SOCKET; + (*cmsg).cmsg_type = SCM_CREDENTIALS; + let data = raw::cmsg_data(cur_cmsg) as *mut UCred; + ptr::write(data, ucred); + } + _ => unreachable!(), + } + + cur_cmsg = raw::cmsg_nxthdr(msg_addr, cur_cmsg); + } + + let res = raw::sendmsg(socket, msg_addr, flags.as_cint()); + if res < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(res as usize) + } +} + +pub struct InternalRecvMsgResult { + pub data_bytes: usize, + pub control_msgs: Vec, + pub flags: RecvMsgResultFlags, +} + +pub unsafe fn recvmsg( + socket: libc::c_int, + buffers: &[&mut [u8]], + cmsg_buffer: &mut [u8], + flags: RecvMsgFlags, + sender_addr: *mut libc::sockaddr, + sender_len: *mut libc::socklen_t) -> io::Result { + + let mut msg: MsgHdr = mem::zeroed(); + + msg.msg_name = sender_addr as *const libc::c_void; + msg.msg_namelen = *sender_len; + + // Initialize scatter/gather vector + let mut iovecs = Vec::with_capacity(buffers.len()); + for buf in buffers { + iovecs.push(IoVec::new(buf)); + } + msg.msg_iov = iovecs.as_mut_ptr() as *mut libc::c_void; + msg.msg_iovlen = (mem::size_of::() * iovecs.len()) as libc::size_t; + + // Initialize control message struct + msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut libc::c_void; + msg.msg_controllen = cmsg_buffer.len() as libc::size_t; + + let msg_addr = (&mut msg as *mut MsgHdr) as *mut libc::c_void; + let recvmsg_res = raw::recvmsg(socket, msg_addr, flags.as_cint()); + if recvmsg_res < 0 { + return Err(io::Error::last_os_error()); + } + + let mut cmsgs = vec![]; + + let mut cur_cmsg = raw::cmsg_firsthdr(msg_addr); + while cur_cmsg != ptr::null() { + // NOTE: Add handlers for new messages here + let cmsg = cur_cmsg as *mut CmsgHdr; + if (*cmsg).cmsg_level == libc::SOL_SOCKET { + if (*cmsg).cmsg_type == SCM_CREDENTIALS { + let ucred = raw::cmsg_data(cur_cmsg) as *mut UCred; + assert_eq!((ucred as i64) + mem::size_of::() as i64 - cur_cmsg as i64, (*cmsg).cmsg_len as i64); + cmsgs.push(ControlMsg::Credentials((*ucred).clone())); + } else if (*cmsg).cmsg_type == SCM_RIGHTS { + let mut fds = vec![]; + let data = raw::cmsg_data(cur_cmsg) as *mut libc::c_int; + let length = ((*cmsg).cmsg_len as i64 - (data as i64 - cur_cmsg as i64)) as usize; + assert_eq!(length % mem::size_of::(), 0); + let passed_fds = slice::from_raw_parts(data, length / mem::size_of::()); + for &fd in passed_fds { + fds.push(fd); + } + cmsgs.push(ControlMsg::Rights(fds)); + } else { + cmsgs.push(ControlMsg::Unknown{ level: (*cmsg).cmsg_level, typ: (*cmsg).cmsg_type }); + } + } else { + cmsgs.push(ControlMsg::Unknown{ level: (*cmsg).cmsg_level, typ: (*cmsg).cmsg_type }); + } + + cur_cmsg = raw::cmsg_nxthdr(msg_addr, cur_cmsg); + } + + + *sender_len = msg.msg_namelen; + Ok(InternalRecvMsgResult { + data_bytes: recvmsg_res as usize, + control_msgs: cmsgs, + flags: RecvMsgResultFlags::from_cint(msg.msg_flags), + }) +} + +#[derive(Clone, Copy, Debug)] +/// Flags given to sendmsg. See sendmsg(2) for more details. +pub struct SendMsgFlags { + dont_wait: bool, + end_of_record: bool, + no_signal: bool, +} + +#[derive(Clone, Copy, Debug)] +/// Flags given to recvmsg. See recvmsg(2) for more details. +pub struct RecvMsgFlags { + cmsg_cloexec: bool, + dont_wait: bool, + peek: bool, + wait_all: bool, + // TODO: Add support for MSG_ERRQUEUE (need to support more cmsgs) +} + +#[derive(Clone, Copy, Debug)] +/// Flags returned by recvmsg. See recvmsg(2) for more details. +pub struct RecvMsgResultFlags { + end_of_record: bool, + truncated: bool, + control_truncated: bool, +} + +impl SendMsgFlags { + /// Create a default SendMsgFlags + pub fn new() -> SendMsgFlags { + SendMsgFlags { + dont_wait: false, + end_of_record: false, + no_signal: false, + } + } + + /// Do not block (MSG_DONTWAIT) + pub fn dont_wait(mut self, v: bool) -> SendMsgFlags { + self.dont_wait = v; + self + } + + /// Mark this packet as the end of a record (used for SOCK_SEQPACKET connections) (MSG_EOR) + pub fn end_of_record(mut self, v: bool) -> SendMsgFlags { + self.end_of_record = v; + self + } + + /// Do not receive SIGPIPE if the other end breaks the connection (MSG_NOSIGNAL) + pub fn no_signal(mut self, v: bool) -> SendMsgFlags { + self.no_signal = v; + self + } + + fn as_cint(&self) -> libc::c_int { + let mut result = 0; + if self.dont_wait { result |= MSG_DONTWAIT; } + if self.end_of_record { result |= MSG_EOR; } + if self.no_signal { result |= MSG_NOSIGNAL; } + result + } +} + +impl RecvMsgFlags { + /// Create a default RecvMsgFlags + pub fn new() -> RecvMsgFlags { + RecvMsgFlags { + cmsg_cloexec: false, + dont_wait: false, + peek: false, + wait_all: false, + } + } + + /// Sets the close-on-exec flag for any file descriptors received via SCM_RIGHTS (MSG_CMSG_CLOEXEC) + pub fn cmsg_cloexec(mut self, v: bool) -> RecvMsgFlags { + self.cmsg_cloexec = v; + self + } + + /// Do not block (MSG_DONTWAIT) + pub fn dont_wait(mut self, v: bool) -> RecvMsgFlags { + self.dont_wait = v; + self + } + + /// Do not remove the retrieved data from the receive queue (the next call will return the same data) (MSG_PEEK) + pub fn peek(mut self, v: bool) -> RecvMsgFlags { + self.peek = v; + self + } + + /// Wait for the buffers to be filled (may still be interrupted by a signal or the socket hanging up) (MSG_WAITALL) + pub fn wait_all(mut self, v: bool) -> RecvMsgFlags { + self.wait_all = v; + self + } + + fn as_cint(&self) -> libc::c_int { + let mut result = 0; + if self.cmsg_cloexec { result |= MSG_CMSG_CLOEXEC; } + if self.dont_wait { result |= MSG_DONTWAIT; } + if self.peek { result |= MSG_PEEK; } + if self.wait_all { result |= MSG_WAITALL; } + result + } +} + +impl RecvMsgResultFlags { + /// The returned data marks the end of a record (used for SOCK_SEQPACKET) (MSG_EOR) + pub fn end_of_record(&self) -> bool { + self.end_of_record + } + + /// Some data was discarded due to the provided buffers being too short (MSG_TRUNC) + pub fn truncated(&self) -> bool { + self.truncated + } + + /// Some control data was discarded (MSG_CTRUNC) + pub fn control_truncated(&self) -> bool { + self.control_truncated + } + + fn from_cint(flags: libc::c_int) -> RecvMsgResultFlags { + RecvMsgResultFlags { + end_of_record: (flags & MSG_EOR) != 0, + truncated: (flags & MSG_TRUNC) != 0, + control_truncated: (flags & MSG_CTRUNC) != 0, + } + } +} + +#[repr(C)] +struct MsgHdr { + pub msg_name: *const libc::c_void, + pub msg_namelen: libc::socklen_t, + pub msg_iov: *mut libc::c_void, + pub msg_iovlen: libc::size_t, + pub msg_control: *mut libc::c_void, + pub msg_controllen: libc::size_t, + pub msg_flags: libc::c_int, +} + +#[test] +fn msghdr_size_correctness() { + assert_eq!(raw::msghdr_size as usize, mem::size_of::()); +} + +#[repr(C)] +struct IoVec { + base: *const libc::c_void, + len: libc::size_t, +} + +impl IoVec { + fn new(buf: &[u8]) -> IoVec { + IoVec { + base: buf.as_ptr() as *const libc::c_void, + len: buf.len() as libc::size_t, + } + } +} + +#[test] +fn iovec_size_correctness() { + assert_eq!(raw::iovec_size as usize, mem::size_of::()); +} + +#[repr(C)] +struct CmsgHdr { + cmsg_len: libc::size_t, + cmsg_level: libc::c_int, + cmsg_type: libc::c_int, +} + +#[test] +fn cmsghdr_size_correctness() { + assert_eq!(raw::cmsghdr_size as usize, mem::size_of::()); +} + +/// Unix credential that can be sent/received over Unix sockets using `ControlMsg::Credential` +/// +/// This is a Rust version of `struct ucred` from sys/socket.h +#[derive(Clone, Debug)] +pub struct UCred{ + /// The sender's process id + pub pid: libc::pid_t, + /// The sender's user id + pub uid: libc::uid_t, + /// The sender's group id + pub gid: libc::gid_t, +} + +#[test] +fn ucred_size_correctness() { + assert_eq!(raw::ucred_size as usize, mem::size_of::()); +} + +/// Ancillary messages that can be sent/received over Unix sockets using `sendmsg`/`recvmsg`. +#[derive(Clone, Debug)] +pub enum ControlMsg { + /// Message used to transfer file descriptors + Rights(Vec), + /// Message used to provide kernel-verified Unix credentials of the sender + Credentials(UCred), + /// Any unimplemented message + Unknown { + /// cmsg_level of the unimplemented message + level: libc::c_int, + /// cmsg_type of the unimplemented message + typ: libc::c_int, + }, + // To add support for more messages, define the message in ControlMsg, + // and near the relevant NOTE comments above. +}