Skip to content

WIP: Add support for mbind, get_mempolicy and set_mempolicy #938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ once_cell = { version = "1.5.2", optional = true }
# libc backend can be selected via adding `--cfg=rustix_use_libc` to
# `RUSTFLAGS` or enabling the `use-libc` cargo feature.
[target.'cfg(all(not(rustix_use_libc), not(miri), target_os = "linux", target_endian = "little", any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64"))))'.dependencies]
linux-raw-sys = { version = "0.4.11", default-features = false, features = ["general", "errno", "ioctl", "no_std", "elf"] }
linux-raw-sys = { version = "0.4.12", default-features = false, features = ["general", "errno", "ioctl", "mempolicy", "no_std", "elf"] }
libc_errno = { package = "errno", version = "0.3.8", default-features = false, optional = true }
libc = { version = "0.2.150", default-features = false, features = ["extra_traits"], optional = true }

Expand All @@ -53,7 +53,7 @@ libc = { version = "0.2.150", default-features = false, features = ["extra_trait
# Some syscalls do not have libc wrappers, such as in `io_uring`. For these,
# the libc backend uses the linux-raw-sys ABI and `libc::syscall`.
[target.'cfg(all(any(target_os = "android", target_os = "linux"), any(rustix_use_libc, miri, not(all(target_os = "linux", target_endian = "little", any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64")))))))'.dependencies]
linux-raw-sys = { version = "0.4.11", default-features = false, features = ["general", "ioctl", "no_std"] }
linux-raw-sys = { version = "0.4.12", default-features = false, features = ["general", "ioctl", "no_std"] }

# For the libc backend on Windows, use the Winsock2 API in windows-sys.
[target.'cfg(windows)'.dependencies.windows-sys]
Expand Down Expand Up @@ -170,6 +170,9 @@ termios = []
# Enable `rustix::mm::*`.
mm = []

# Enable `rustix::numa::*`.
numa = []

# Enable `rustix::pipe::*`.
pipe = []

Expand All @@ -194,6 +197,7 @@ all-apis = [
"mm",
"mount",
"net",
"numa",
"param",
"pipe",
"process",
Expand Down
16 changes: 16 additions & 0 deletions src/backend/linux_raw/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,22 @@ impl<'a, Num: ArgNumber> From<Option<crate::net::Protocol>> for ArgReg<'a, Num>
}
}

// Lets a NUMA policy `Mode` be passed directly as a syscall argument; the
// raw bits travel in the register as an unsigned 32-bit integer.
#[cfg(feature = "numa")]
impl<'a, Num: ArgNumber> From<crate::numa::Mode> for ArgReg<'a, Num> {
    #[inline]
    fn from(mode: crate::numa::Mode) -> Self {
        let raw_bits = mode.bits();
        c_uint(raw_bits)
    }
}

// Lets `ModeFlags` be passed directly as a syscall argument; the raw bits
// travel in the register as an unsigned 32-bit integer.
#[cfg(feature = "numa")]
impl<'a, Num: ArgNumber> From<crate::numa::ModeFlags> for ArgReg<'a, Num> {
    #[inline]
    fn from(mode_flags: crate::numa::ModeFlags) -> Self {
        let raw_bits = mode_flags.bits();
        c_uint(raw_bits)
    }
}

impl<'a, Num: ArgNumber, T> From<&'a mut MaybeUninit<T>> for ArgReg<'a, Num> {
#[inline]
fn from(t: &'a mut MaybeUninit<T>) -> Self {
Expand Down
2 changes: 2 additions & 0 deletions src/backend/linux_raw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub(crate) mod mount;
pub(crate) mod mount; // for deprecated mount functions in "fs"
#[cfg(feature = "net")]
pub(crate) mod net;
#[cfg(feature = "numa")]
pub(crate) mod numa;
#[cfg(any(
feature = "param",
feature = "process",
Expand Down
2 changes: 2 additions & 0 deletions src/backend/linux_raw/numa/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub(crate) mod syscalls;
pub(crate) mod types;
92 changes: 92 additions & 0 deletions src/backend/linux_raw/numa/syscalls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
//! linux_raw syscalls supporting `rustix::numa`.
//!
//! # Safety
//!
//! See the `rustix::backend` module documentation for details.

#![allow(unsafe_code)]
#![allow(clippy::undocumented_unsafe_blocks)]

use super::types::{Mode, ModeFlags};

use crate::backend::c;
use crate::backend::conv::{c_uint, pass_usize, ret, zero};
use crate::io;
use core::mem::MaybeUninit;

/// Invoke the raw `mbind` syscall for the `length` bytes starting at `addr`.
///
/// `nodemask` is handed to the kernel as a bit mask of node IDs, 64 bits per
/// slice element. An empty slice denotes an empty node set.
///
/// # Safety
///
/// `mbind` is primarily unsafe due to the `addr` parameter, as anything
/// working with memory pointed to by raw pointers is unsafe.
#[inline]
pub(crate) unsafe fn mbind(
    addr: *mut c::c_void,
    length: usize,
    mode: Mode,
    nodemask: &[u64],
    flags: ModeFlags,
) -> io::Result<()> {
    // The kernel decrements `maxnode` before treating it as the number of
    // bits to read from `nodemask`. Passing the exact bit count would
    // therefore drop the highest bit of the mask and turn an empty mask
    // (`maxnode == 0`) into `EINVAL`, so pass one more than the bit count.
    let maxnode = nodemask.len() * u64::BITS as usize + 1;

    ret(syscall!(
        __NR_mbind,
        addr,
        pass_usize(length),
        mode,
        nodemask.as_ptr(),
        pass_usize(maxnode),
        flags
    ))
}

/// Invoke the raw `set_mempolicy` syscall with the given policy `mode` and
/// node mask.
///
/// `nodemask` is handed to the kernel as a bit mask of node IDs, 64 bits per
/// slice element. An empty slice denotes an empty node set.
///
/// # Safety
///
/// This function takes no raw pointers from the caller; the node mask is
/// read by the kernel from the borrowed `nodemask` slice. It is declared
/// `unsafe` to match the signature expected by `rustix::numa::set_mempolicy`.
#[inline]
pub(crate) unsafe fn set_mempolicy(mode: Mode, nodemask: &[u64]) -> io::Result<()> {
    // The kernel decrements `maxnode` before treating it as the number of
    // bits to read from `nodemask`. Passing the exact bit count would
    // therefore drop the highest bit of the mask and turn an empty mask
    // (`maxnode == 0`) into `EINVAL`, so pass one more than the bit count.
    let maxnode = nodemask.len() * u64::BITS as usize + 1;

    ret(syscall!(
        __NR_set_mempolicy,
        mode,
        nodemask.as_ptr(),
        pass_usize(maxnode)
    ))
}

/// Invoke `get_mempolicy` with `MPOL_F_NODE | MPOL_F_ADDR` and return the
/// node ID the kernel reports for `addr`.
///
/// # Safety
///
/// `get_mempolicy` is primarily unsafe due to the `addr` parameter,
/// as anything working with memory pointed to by raw pointers is
/// unsafe.
#[inline]
pub(crate) unsafe fn get_mempolicy_node(addr: *mut c::c_void) -> io::Result<usize> {
    // The kernel stores its result through an `int *`, i.e. it writes exactly
    // 32 bits. The out slot must therefore be 32 bits wide: reading it back
    // as a `usize` would include uninitialized bytes on 64-bit targets.
    let mut node = MaybeUninit::<u32>::uninit();

    ret(syscall!(
        __NR_get_mempolicy,
        &mut node,
        zero(),
        zero(),
        addr,
        c_uint(linux_raw_sys::mempolicy::MPOL_F_NODE | linux_raw_sys::mempolicy::MPOL_F_ADDR)
    ))?;

    // SAFETY: The syscall succeeded, so the kernel has written all 32 bits
    // of `node`.
    Ok(node.assume_init() as usize)
}

/// Invoke `get_mempolicy` with `MPOL_F_NODE` only (no address) and return
/// the node ID the kernel reports.
#[inline]
pub(crate) fn get_mempolicy_next_node() -> io::Result<usize> {
    // The kernel stores its result through an `int *`, i.e. it writes exactly
    // 32 bits. The out slot must therefore be 32 bits wide: reading it back
    // as a `usize` would include uninitialized bytes on 64-bit targets.
    let mut node = MaybeUninit::<u32>::uninit();

    // SAFETY: The only pointer handed to the kernel is the address of the
    // local `node`, which is valid for a 32-bit write, and `node` is only
    // read back after the syscall has reported success.
    unsafe {
        ret(syscall!(
            __NR_get_mempolicy,
            &mut node,
            zero(),
            zero(),
            zero(),
            c_uint(linux_raw_sys::mempolicy::MPOL_F_NODE)
        ))?;

        Ok(node.assume_init() as usize)
    }
}
52 changes: 52 additions & 0 deletions src/backend/linux_raw/numa/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use bitflags::bitflags;

bitflags! {
/// `MPOL_*` policy modes and `MPOL_F_*` mode flags for use with [`mbind`].
///
/// NOTE(review): the `MPOL_*` policy constants below are cast from another
/// integer type and are presumably small enumerated values rather than
/// independent bits; if so, `contains`-style queries on them are not
/// meaningful. Confirm against `linux_raw_sys::mempolicy`.
///
/// [`mbind`]: crate::numa::mbind
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct Mode: u32 {
/// `MPOL_F_STATIC_NODES`
const STATIC_NODES = linux_raw_sys::mempolicy::MPOL_F_STATIC_NODES;
/// `MPOL_F_RELATIVE_NODES`
const RELATIVE_NODES = linux_raw_sys::mempolicy::MPOL_F_RELATIVE_NODES;
/// `MPOL_F_NUMA_BALANCING`
const NUMA_BALANCING = linux_raw_sys::mempolicy::MPOL_F_NUMA_BALANCING;

/// `MPOL_DEFAULT`
const DEFAULT = linux_raw_sys::mempolicy::MPOL_DEFAULT as u32;
/// `MPOL_PREFERRED`
const PREFERRED = linux_raw_sys::mempolicy::MPOL_PREFERRED as u32;
/// `MPOL_BIND`
const BIND = linux_raw_sys::mempolicy::MPOL_BIND as u32;
/// `MPOL_INTERLEAVE`
const INTERLEAVE = linux_raw_sys::mempolicy::MPOL_INTERLEAVE as u32;
/// `MPOL_LOCAL`
const LOCAL = linux_raw_sys::mempolicy::MPOL_LOCAL as u32;
/// `MPOL_PREFERRED_MANY`
const PREFERRED_MANY = linux_raw_sys::mempolicy::MPOL_PREFERRED_MANY as u32;

/// Accept every bit pattern, so bits not named above are preserved.
/// See <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
const _ = !0;
}
}

bitflags! {
/// `MPOL_MF_*` flags for use with [`mbind`].
///
/// [`mbind`]: crate::numa::mbind
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub struct ModeFlags: u32 {
/// `MPOL_MF_STRICT`
const STRICT = linux_raw_sys::mempolicy::MPOL_MF_STRICT;
/// `MPOL_MF_MOVE`
const MOVE = linux_raw_sys::mempolicy::MPOL_MF_MOVE;
/// `MPOL_MF_MOVE_ALL`
const MOVE_ALL = linux_raw_sys::mempolicy::MPOL_MF_MOVE_ALL;

/// Accept every bit pattern, so bits not named above are preserved.
/// See <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
const _ = !0;
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,10 @@ pub mod mount;
#[cfg(feature = "net")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "net")))]
pub mod net;
#[cfg(linux_kernel)]
#[cfg(feature = "numa")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "numa")))]
pub mod numa;
#[cfg(not(any(windows, target_os = "espidf")))]
#[cfg(feature = "param")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "param")))]
Expand Down
108 changes: 108 additions & 0 deletions src/numa/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
//! The `numa` API.
//!
//! # Safety
//!
//! `mbind` and the related functions in this module are `unsafe`: some of
//! them operate on memory through caller-supplied raw pointers.
#![allow(unsafe_code)]

use crate::{backend, io};
use core::ffi::c_void;

pub use backend::numa::types::{Mode, ModeFlags};

/// `mbind(addr, len, mode, nodemask, flags)`—Set memory policy for a memory
/// range.
///
/// `addr` and `len` describe the memory range, `mode` selects the policy,
/// `nodemask` is a bit mask of node IDs packed into 64-bit words, and
/// `flags` carries the `MPOL_MF_*` flags. All arguments are forwarded
/// unchanged to the backend implementation.
///
/// # Safety
///
/// This function operates on raw pointers, but it should only be used
/// on memory which the caller owns.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/mbind.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn mbind(
addr: *mut c_void,
len: usize,
mode: Mode,
nodemask: &[u64],
flags: ModeFlags,
) -> io::Result<()> {
backend::numa::syscalls::mbind(addr, len, mode, nodemask, flags)
}

/// `set_mempolicy(mode, nodemask)`—Set default NUMA memory policy for
/// a thread and its children.
///
/// `nodemask` is a bit mask of node IDs packed into 64-bit words. Both
/// arguments are forwarded unchanged to the backend implementation.
///
/// # Safety
///
/// NOTE(review): this function takes no raw pointer arguments, so the
/// reason it is declared `unsafe` is not evident from its signature.
/// Confirm whether the `unsafe` marker is required.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/set_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn set_mempolicy(mode: Mode, nodemask: &[u64]) -> io::Result<()> {
backend::numa::syscalls::set_mempolicy(mode, nodemask)
}

/// `get_mempolicy_node(addr)`—Return the node ID of the node on which
/// the address `addr` is allocated.
///
/// This corresponds to calling `get_mempolicy` with both `MPOL_F_NODE` and
/// `MPOL_F_ADDR` set. Per the Linux documentation, if no page has yet been
/// allocated for the specified address, `get_mempolicy` will allocate a
/// page as if the thread had performed a read (load) access to that
/// address, and return the ID of the node where that page was allocated.
///
/// # Safety
///
/// This function operates on raw pointers, but it should only be used
/// on memory which the caller owns.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn get_mempolicy_node(addr: *mut c_void) -> io::Result<usize> {
backend::numa::syscalls::get_mempolicy_node(addr)
}

/// `get_mempolicy_next_node()`—Return the node ID of the next node
/// that will be used for interleaving of internal kernel pages
/// allocated on behalf of the thread.
///
/// This corresponds to calling `get_mempolicy` with `MPOL_F_NODE` set but
/// not `MPOL_F_ADDR`. Per the Linux documentation, if the thread's current
/// policy is `MPOL_INTERLEAVE`, `get_mempolicy` returns the node ID of the
/// next node that will be used for interleaving of internal kernel pages
/// allocated on behalf of the thread. These allocations include pages for
/// memory-mapped files in process memory ranges mapped using the `mmap(2)`
/// call with the `MAP_PRIVATE` flag for read accesses, and in memory ranges
/// mapped with the `MAP_SHARED` flag for all accesses.
///
/// # Safety
///
/// NOTE(review): this function takes no arguments and the backend
/// function it calls is not `unsafe`, so the `unsafe` marker here looks
/// unnecessary. Confirm before relaxing it.
///
/// # References
/// - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn get_mempolicy_next_node() -> io::Result<usize> {
backend::numa::syscalls::get_mempolicy_next_node()
}
40 changes: 40 additions & 0 deletions tests/numa/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/// Exercise `mbind`, `get_mempolicy_node`, `get_mempolicy_next_node` and
/// `set_mempolicy` against an anonymous private mapping.
// The test uses `rustix::mm` and `rustix::numa`, so gate it on exactly those
// two features.
#[cfg(all(feature = "mm", feature = "numa"))]
#[test]
fn test_mbind() {
    let size = 8192;
    // Node mask with only bit 0 (node 0) set.
    let mask: &[u64] = &[1];

    unsafe {
        let vaddr = rustix::mm::mmap_anonymous(
            std::ptr::null_mut(),
            size,
            rustix::mm::ProtFlags::READ | rustix::mm::ProtFlags::WRITE,
            rustix::mm::MapFlags::PRIVATE,
        )
        .unwrap();

        // Write to the mapping before binding it (presumably to make sure a
        // page is actually backing the address).
        vaddr.cast::<usize>().write(100);

        rustix::numa::mbind(
            vaddr,
            size,
            rustix::numa::Mode::BIND | rustix::numa::Mode::STATIC_NODES,
            mask,
            rustix::numa::ModeFlags::empty(),
        )
        .unwrap();

        rustix::numa::get_mempolicy_node(vaddr).unwrap();

        // The thread policy has not been changed yet, so querying the next
        // interleave node is expected to fail.
        assert!(
            matches!(
                rustix::numa::get_mempolicy_next_node(),
                Err(rustix::io::Errno::INVAL)
            ),
            "rustix::numa::get_mempolicy_next_node() should return EINVAL for MPOL_DEFAULT"
        );

        rustix::numa::set_mempolicy(rustix::numa::Mode::INTERLEAVE, mask).unwrap();

        rustix::numa::get_mempolicy_next_node().unwrap();

        // Release the mapping so the test does not leak it.
        rustix::mm::munmap(vaddr, size).unwrap();
    }
}