Skip to content

Commit 4c5f6ce

Browse files
committed
Add support for mbind, get_mempolicy and set_mempolicy (#937)
This adds support for the `mbind`, `set_mempolicy` and `get_mempolicy` NUMA syscalls. The `get_mempolicy` syscall has a few different modes of operation, depending on the flags; these are demultiplexed into `get_mempolicy_node` and `get_mempolicy_next_node` for now. There are a couple of other modes that write into the variable-length bit array; these aren't implemented for now.
1 parent 3056dec commit 4c5f6ce

File tree

9 files changed

+321
-1
lines changed

9 files changed

+321
-1
lines changed

Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ once_cell = { version = "1.5.2", optional = true }
3636
# libc backend can be selected via adding `--cfg=rustix_use_libc` to
3737
# `RUSTFLAGS` or enabling the `use-libc` cargo feature.
3838
[target.'cfg(all(not(rustix_use_libc), not(miri), target_os = "linux", target_endian = "little", any(target_arch = "arm", all(target_arch = "aarch64", target_pointer_width = "64"), target_arch = "riscv64", all(rustix_use_experimental_asm, target_arch = "powerpc64"), all(rustix_use_experimental_asm, target_arch = "mips"), all(rustix_use_experimental_asm, target_arch = "mips32r6"), all(rustix_use_experimental_asm, target_arch = "mips64"), all(rustix_use_experimental_asm, target_arch = "mips64r6"), target_arch = "x86", all(target_arch = "x86_64", target_pointer_width = "64"))))'.dependencies]
39-
linux-raw-sys = { version = "0.4.11", default-features = false, features = ["general", "errno", "ioctl", "no_std", "elf"] }
39+
linux-raw-sys = { version = "0.6.2", default-features = false, features = ["general", "errno", "ioctl", "mempolicy", "no_std", "elf"] }
4040
libc_errno = { package = "errno", version = "0.3.8", default-features = false, optional = true }
4141
libc = { version = "0.2.150", default-features = false, features = ["extra_traits"], optional = true }
4242

@@ -170,6 +170,9 @@ termios = []
170170
# Enable `rustix::mm::*`.
171171
mm = []
172172

173+
# Enable `rustix::numa::*`.
174+
numa = []
175+
173176
# Enable `rustix::pipe::*`.
174177
pipe = []
175178

@@ -194,6 +197,7 @@ all-apis = [
194197
"mm",
195198
"mount",
196199
"net",
200+
"numa",
197201
"param",
198202
"pipe",
199203
"process",

src/backend/linux_raw/conv.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,22 @@ impl<'a, Num: ArgNumber> From<Option<crate::net::Protocol>> for ArgReg<'a, Num>
818818
}
819819
}
820820

821+
/// Passes a NUMA policy `Mode` to a syscall as its raw bit pattern.
#[cfg(feature = "numa")]
impl<'a, Num: ArgNumber> From<crate::numa::Mode> for ArgReg<'a, Num> {
    #[inline]
    fn from(mode: crate::numa::Mode) -> Self {
        let raw = mode.bits();
        c_uint(raw)
    }
}
828+
829+
/// Passes NUMA `ModeFlags` to a syscall as their raw bit pattern.
#[cfg(feature = "numa")]
impl<'a, Num: ArgNumber> From<crate::numa::ModeFlags> for ArgReg<'a, Num> {
    #[inline]
    fn from(mode_flags: crate::numa::ModeFlags) -> Self {
        let raw = mode_flags.bits();
        c_uint(raw)
    }
}
836+
821837
impl<'a, Num: ArgNumber, T> From<&'a mut MaybeUninit<T>> for ArgReg<'a, Num> {
822838
#[inline]
823839
fn from(t: &'a mut MaybeUninit<T>) -> Self {

src/backend/linux_raw/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ pub(crate) mod mount;
5151
pub(crate) mod mount; // for deprecated mount functions in "fs"
5252
#[cfg(feature = "net")]
5353
pub(crate) mod net;
54+
#[cfg(feature = "numa")]
55+
pub(crate) mod numa;
5456
#[cfg(any(
5557
feature = "param",
5658
feature = "process",

src/backend/linux_raw/numa/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub(crate) mod syscalls;
2+
pub(crate) mod types;
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
//! linux_raw syscalls supporting `rustix::numa`.
2+
//!
3+
//! # Safety
4+
//!
5+
//! See the `rustix::backend` module documentation for details.
6+
7+
#![allow(unsafe_code)]
8+
#![allow(clippy::undocumented_unsafe_blocks)]
9+
10+
use super::types::{Mode, ModeFlags};
11+
12+
use crate::backend::c;
13+
use crate::backend::conv::{c_uint, pass_usize, ret, zero};
14+
use crate::io;
15+
use core::mem::MaybeUninit;
16+
17+
/// # Safety
18+
///
19+
/// `mbind` is primarily unsafe due to the `addr` parameter, as anything
20+
/// working with memory pointed to by raw pointers is unsafe.
21+
#[inline]
22+
pub(crate) unsafe fn mbind(
23+
addr: *mut c::c_void,
24+
length: usize,
25+
mode: Mode,
26+
nodemask: &[u64],
27+
flags: ModeFlags,
28+
) -> io::Result<()> {
29+
ret(syscall!(
30+
__NR_mbind,
31+
addr,
32+
pass_usize(length),
33+
mode,
34+
nodemask.as_ptr(),
35+
pass_usize(nodemask.len() * u64::BITS as usize),
36+
flags
37+
))
38+
}
39+
40+
/// # Safety
41+
///
42+
/// `set_mempolicy` is primarily unsafe due to the `addr` parameter,
43+
/// as anything working with memory pointed to by raw pointers is
44+
/// unsafe.
45+
#[inline]
46+
pub(crate) unsafe fn set_mempolicy(mode: Mode, nodemask: &[u64]) -> io::Result<()> {
47+
ret(syscall!(
48+
__NR_set_mempolicy,
49+
mode,
50+
nodemask.as_ptr(),
51+
pass_usize(nodemask.len() * u64::BITS as usize)
52+
))
53+
}
54+
55+
/// # Safety
56+
///
57+
/// `get_mempolicy` is primarily unsafe due to the `addr` parameter,
58+
/// as anything working with memory pointed to by raw pointers is
59+
/// unsafe.
60+
#[inline]
61+
pub(crate) unsafe fn get_mempolicy_node(addr: *mut c::c_void) -> io::Result<usize> {
62+
let mut mode = MaybeUninit::<usize>::uninit();
63+
64+
ret(syscall!(
65+
__NR_get_mempolicy,
66+
&mut mode,
67+
zero(),
68+
zero(),
69+
addr,
70+
c_uint(linux_raw_sys::mempolicy::MPOL_F_NODE | linux_raw_sys::mempolicy::MPOL_F_ADDR)
71+
))?;
72+
73+
Ok(mode.assume_init())
74+
}
75+
76+
#[inline]
77+
pub(crate) fn get_mempolicy_next_node() -> io::Result<usize> {
78+
let mut mode = MaybeUninit::<usize>::uninit();
79+
80+
unsafe {
81+
ret(syscall!(
82+
__NR_get_mempolicy,
83+
&mut mode,
84+
zero(),
85+
zero(),
86+
zero(),
87+
c_uint(linux_raw_sys::mempolicy::MPOL_F_NODE)
88+
))?;
89+
90+
Ok(mode.assume_init())
91+
}
92+
}

src/backend/linux_raw/numa/types.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use bitflags::bitflags;
2+
3+
bitflags! {
    /// `MPOL_*` and `MPOL_F_*` flags for use with [`mbind`].
    ///
    /// [`mbind`]: crate::numa::mbind
    #[repr(transparent)]
    #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
    pub struct Mode: u32 {
        /// `MPOL_F_STATIC_NODES`
        const STATIC_NODES = linux_raw_sys::mempolicy::MPOL_F_STATIC_NODES;
        /// `MPOL_F_RELATIVE_NODES`
        const RELATIVE_NODES = linux_raw_sys::mempolicy::MPOL_F_RELATIVE_NODES;
        /// `MPOL_F_NUMA_BALANCING`
        const NUMA_BALANCING = linux_raw_sys::mempolicy::MPOL_F_NUMA_BALANCING;

        // NOTE(review): the `MPOL_*` policy constants below are believed to
        // be consecutive enumeration values rather than independent bits, so
        // set-style queries such as `contains` may not be meaningful for
        // them; verify against the kernel headers.

        /// `MPOL_DEFAULT`
        const DEFAULT = linux_raw_sys::mempolicy::MPOL_DEFAULT as u32;
        /// `MPOL_PREFERRED`
        const PREFERRED = linux_raw_sys::mempolicy::MPOL_PREFERRED as u32;
        /// `MPOL_BIND`
        const BIND = linux_raw_sys::mempolicy::MPOL_BIND as u32;
        /// `MPOL_INTERLEAVE`
        const INTERLEAVE = linux_raw_sys::mempolicy::MPOL_INTERLEAVE as u32;
        /// `MPOL_LOCAL`
        const LOCAL = linux_raw_sys::mempolicy::MPOL_LOCAL as u32;
        /// `MPOL_PREFERRED_MANY`
        const PREFERRED_MANY = linux_raw_sys::mempolicy::MPOL_PREFERRED_MANY as u32;

        /// <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
        const _ = !0;
    }
}
34+
35+
bitflags! {
    /// `MPOL_MF_*` flags for use with [`mbind`].
    ///
    /// [`mbind`]: crate::numa::mbind
    #[repr(transparent)]
    #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
    pub struct ModeFlags: u32 {
        /// `MPOL_MF_STRICT`
        const STRICT = linux_raw_sys::mempolicy::MPOL_MF_STRICT;
        /// `MPOL_MF_MOVE`
        const MOVE = linux_raw_sys::mempolicy::MPOL_MF_MOVE;
        /// `MPOL_MF_MOVE_ALL`
        const MOVE_ALL = linux_raw_sys::mempolicy::MPOL_MF_MOVE_ALL;

        /// <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
        const _ = !0;
    }
}

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ pub mod mount;
218218
#[cfg(feature = "net")]
219219
#[cfg_attr(doc_cfg, doc(cfg(feature = "net")))]
220220
pub mod net;
221+
#[cfg(linux_kernel)]
222+
#[cfg(feature = "numa")]
223+
#[cfg_attr(doc_cfg, doc(cfg(feature = "numa")))]
224+
pub mod numa;
221225
#[cfg(not(any(windows, target_os = "espidf")))]
222226
#[cfg(feature = "param")]
223227
#[cfg_attr(doc_cfg, doc(cfg(feature = "param")))]

src/numa/mod.rs

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
//! The `numa` API.
2+
//!
3+
//! # Safety
4+
//!
5+
//! `mbind` and related functions manipulate raw pointers and have special
6+
//! semantics and are wildly unsafe.
7+
#![allow(unsafe_code)]
8+
9+
use crate::{backend, io};
10+
use core::ffi::c_void;
11+
12+
pub use backend::numa::types::{Mode, ModeFlags};
13+
14+
/// `mbind(addr, len, mode, nodemask, flags)`—Set memory policy for a memory
/// range.
///
/// `nodemask` is a bit array of node IDs, 64 bits per element.
///
/// # Safety
///
/// This function operates on raw pointers, but it should only be used
/// on memory which the caller owns.
///
/// # References
///  - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/mbind.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn mbind(
    addr: *mut c_void,
    len: usize,
    mode: Mode,
    nodemask: &[u64],
    flags: ModeFlags,
) -> io::Result<()> {
    backend::numa::syscalls::mbind(addr, len, mode, nodemask, flags)
}
36+
37+
/// `set_mempolicy(mode, nodemask)`—Set default NUMA memory policy for
/// a thread and its children.
///
/// `nodemask` is a bit array of node IDs, 64 bits per element.
///
/// # Safety
///
/// This function takes no raw pointers.
/// NOTE(review): no caller-side invariant is visible from this wrapper; it
/// is presumably `unsafe` for consistency with the other functions in this
/// module. Confirm before relying on it.
///
/// # References
///  - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/set_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn set_mempolicy(mode: Mode, nodemask: &[u64]) -> io::Result<()> {
    backend::numa::syscalls::set_mempolicy(mode, nodemask)
}
54+
55+
/// `get_mempolicy_node(addr)`—Return the ID of the node on which the
/// address `addr` is allocated.
///
/// This is `get_mempolicy` with both `MPOL_F_NODE` and `MPOL_F_ADDR` set.
/// If no page has yet been allocated for the specified address, the kernel
/// will allocate a page as if the thread had performed a read (load) access
/// to that address, and return the ID of the node where that page was
/// allocated.
///
/// # Safety
///
/// This function operates on raw pointers, but it should only be used
/// on memory which the caller owns.
///
/// # References
///  - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn get_mempolicy_node(addr: *mut c_void) -> io::Result<usize> {
    use crate::backend::numa::syscalls;

    syscalls::get_mempolicy_node(addr)
}
80+
81+
/// `get_mempolicy_next_node()`—Return the ID of the next node that will be
/// used for interleaving of internal kernel pages allocated on behalf of
/// the thread.
///
/// This is `get_mempolicy` with `MPOL_F_NODE` set but not `MPOL_F_ADDR`.
/// If the thread's current policy is `MPOL_INTERLEAVE`, the kernel reports
/// the node ID of the next node that will be used for interleaving of
/// internal kernel pages allocated on behalf of the thread. These
/// allocations include pages for memory-mapped files in process memory
/// ranges mapped using the `mmap(2)` call with the `MAP_PRIVATE` flag for
/// read accesses, and in memory ranges mapped with the `MAP_SHARED` flag
/// for all accesses.
///
/// # Safety
///
/// This function takes no pointer arguments.
/// NOTE(review): no caller-side invariant is visible from this wrapper; it
/// is presumably `unsafe` for consistency with the other functions in this
/// module. Confirm before relying on it.
///
/// # References
///  - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn get_mempolicy_next_node() -> io::Result<usize> {
    backend::numa::syscalls::get_mempolicy_next_node()
}

tests/numa/main.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/// Exercises `mbind`, `set_mempolicy` and both `get_mempolicy` wrappers on a
/// fresh anonymous mapping.
// The test calls into `rustix::numa`, so it must also be gated on the `numa`
// feature; otherwise it fails to compile when only `mm` and `fs` are enabled.
#[cfg(all(feature = "mm", feature = "fs", feature = "numa"))]
#[test]
fn test_mbind() {
    let size = 8192;

    unsafe {
        let vaddr = rustix::mm::mmap_anonymous(
            std::ptr::null_mut(),
            size,
            rustix::mm::ProtFlags::READ | rustix::mm::ProtFlags::WRITE,
            rustix::mm::MapFlags::PRIVATE,
        )
        .unwrap();

        // Write to the mapping before querying which node backs it.
        vaddr.cast::<usize>().write(100);

        // Only bit 0 is set, i.e. node 0.
        let mask = &[1];
        rustix::numa::mbind(
            vaddr,
            size,
            rustix::numa::Mode::BIND | rustix::numa::Mode::STATIC_NODES,
            mask,
            rustix::numa::ModeFlags::empty(),
        )
        .unwrap();

        rustix::numa::get_mempolicy_node(vaddr).unwrap();

        // No thread policy has been set yet, so asking for the next
        // interleave node is expected to fail.
        match rustix::numa::get_mempolicy_next_node() {
            Err(rustix::io::Errno::INVAL) => (),
            _ => panic!(
                "rustix::numa::get_mempolicy_next_node() should return EINVAL for MPOL_DEFAULT"
            ),
        }

        rustix::numa::set_mempolicy(rustix::numa::Mode::INTERLEAVE, mask).unwrap();

        rustix::numa::get_mempolicy_next_node().unwrap();

        // Release the mapping instead of leaking it for the rest of the
        // test process.
        rustix::mm::munmap(vaddr, size).unwrap();
    }
}

0 commit comments

Comments
 (0)