Skip to content

Commit f321b15

Browse files
committed
Add support for mbind, get_mempolicy and set_mempolicy (#937)
This adds support for the `mbind`, `set_mempolicy` and `get_mempolicy` NUMA syscalls. The `get_mempolicy` syscall has a few different modes of operation, depending on the flags, which are demultiplexed into `get_mempolicy_node` and `get_mempolicy_next_node` for now. There are a couple of other modes that write into the variable-length bit array, which aren't implemented for now.
1 parent 496792e commit f321b15

File tree

11 files changed

+377
-0
lines changed

11 files changed

+377
-0
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ termios = []
170170
# Enable `rustix::mm::*`.
171171
mm = []
172172

173+
# Enable `rustix::numa::*`.
174+
numa = []
175+
173176
# Enable `rustix::pipe::*`.
174177
pipe = []
175178

@@ -194,6 +197,7 @@ all-apis = [
194197
"mm",
195198
"mount",
196199
"net",
200+
"numa",
197201
"param",
198202
"pipe",
199203
"process",

src/backend/linux_raw/conv.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,22 @@ impl<'a, Num: ArgNumber> From<Option<crate::net::Protocol>> for ArgReg<'a, Num>
818818
}
819819
}
820820

821+
#[cfg(feature = "numa")]
impl<'a, Num: ArgNumber> From<crate::numa::Mode> for ArgReg<'a, Num> {
    /// Converts a NUMA policy `Mode` into a syscall argument by passing its
    /// raw bits through `c_uint`.
    #[inline]
    fn from(mode: crate::numa::Mode) -> Self {
        let raw_bits = mode.bits();
        c_uint(raw_bits)
    }
}
828+
829+
#[cfg(feature = "numa")]
impl<'a, Num: ArgNumber> From<crate::numa::ModeFlags> for ArgReg<'a, Num> {
    /// Converts `ModeFlags` into a syscall argument by passing its raw bits
    /// through `c_uint`.
    #[inline]
    fn from(mode_flags: crate::numa::ModeFlags) -> Self {
        let raw_bits = mode_flags.bits();
        c_uint(raw_bits)
    }
}
836+
821837
impl<'a, Num: ArgNumber, T> From<&'a mut MaybeUninit<T>> for ArgReg<'a, Num> {
822838
#[inline]
823839
fn from(t: &'a mut MaybeUninit<T>) -> Self {

src/backend/linux_raw/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ pub(crate) mod mount;
5151
pub(crate) mod mount; // for deprecated mount functions in "fs"
5252
#[cfg(feature = "net")]
5353
pub(crate) mod net;
54+
#[cfg(feature = "numa")]
55+
pub(crate) mod numa;
5456
#[cfg(any(
5557
feature = "param",
5658
feature = "process",

src/backend/linux_raw/numa/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub(crate) mod syscalls;
2+
pub(crate) mod types;
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
//! linux_raw syscalls supporting `rustix::numa`.
2+
//!
3+
//! # Safety
4+
//!
5+
//! See the `rustix::backend` module documentation for details.
6+
7+
#![allow(unsafe_code)]
8+
#![allow(clippy::undocumented_unsafe_blocks)]
9+
10+
use super::types::{Mode, ModeFlags};
11+
12+
use core::ptr::null_mut;
13+
use core::mem::MaybeUninit;
14+
use crate::backend::c;
15+
use crate::backend::conv::{c_uint, no_fd, pass_usize, ret, ret_owned_fd, ret_void_star, zero};
16+
use crate::io;
17+
18+
/// # Safety
19+
///
20+
/// `mbind` is primarily unsafe due to the `addr` parameter, as anything
21+
/// working with memory pointed to by raw pointers is unsafe.
22+
#[inline]
23+
pub(crate) unsafe fn mbind(addr: *mut c::c_void, length: usize, mode: Mode, nodemask: &[u64], flags: ModeFlags) -> io::Result<()> {
24+
ret(syscall!(
25+
__NR_mbind,
26+
addr,
27+
pass_usize(length),
28+
mode,
29+
nodemask.as_ptr(),
30+
pass_usize(nodemask.len() * u64::BITS as usize),
31+
flags
32+
))
33+
}
34+
35+
/// # Safety
36+
///
37+
/// `set_mempolicy` is primarily unsafe due to the `addr` parameter,
38+
/// as anything working with memory pointed to by raw pointers is
39+
/// unsafe.
40+
#[inline]
41+
pub(crate) unsafe fn set_mempolicy(mode: Mode, nodemask: &[u64]) -> io::Result<()> {
42+
ret(syscall!(
43+
__NR_set_mempolicy,
44+
mode,
45+
nodemask.as_ptr(),
46+
pass_usize(nodemask.len() * u64::BITS as usize)
47+
))
48+
}
49+
50+
/// # Safety
51+
///
52+
/// `get_mempolicy` is primarily unsafe due to the `addr` parameter,
53+
/// as anything working with memory pointed to by raw pointers is
54+
/// unsafe.
55+
#[inline]
56+
pub(crate) unsafe fn get_mempolicy_node(addr: *mut c::c_void) -> io::Result<usize> {
57+
let mut mode = MaybeUninit::<usize>::uninit();
58+
59+
ret(syscall!(
60+
__NR_get_mempolicy,
61+
&mut mode,
62+
zero(),
63+
zero(),
64+
addr,
65+
c_uint(linux_raw_sys::general::MPOL_F_NODE | linux_raw_sys::general::MPOL_F_ADDR)
66+
))?;
67+
68+
Ok(mode.assume_init())
69+
}
70+
71+
#[inline]
72+
pub(crate) fn get_mempolicy_next_node() -> io::Result<usize> {
73+
let mut mode = MaybeUninit::<usize>::uninit();
74+
75+
unsafe {
76+
ret(syscall!(
77+
__NR_get_mempolicy,
78+
&mut mode,
79+
zero(),
80+
zero(),
81+
zero(),
82+
c_uint(linux_raw_sys::general::MPOL_F_NODE)
83+
))?;
84+
85+
Ok(mode.assume_init())
86+
}
87+
}

src/backend/linux_raw/numa/types.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
use bitflags::bitflags;
2+
3+
bitflags! {
4+
/// `MPOL_*` policy modes and `MPOL_F_*` mode flags for use with [`mbind`]
/// and [`set_mempolicy`].
5+
///
6+
/// [`mbind`]: crate::numa::mbind
/// [`set_mempolicy`]: crate::numa::set_mempolicy
7+
#[repr(transparent)]
8+
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
9+
pub struct Mode: u32 {
10+
/// `MPOL_F_STATIC_NODES`
11+
const STATIC_NODES = linux_raw_sys::general::MPOL_F_STATIC_NODES;
12+
/// `MPOL_F_RELATIVE_NODES`
13+
const RELATIVE_NODES = linux_raw_sys::general::MPOL_F_RELATIVE_NODES;
14+
/// `MPOL_F_NUMA_BALANCING`
15+
const NUMA_BALANCING = linux_raw_sys::general::MPOL_F_NUMA_BALANCING;
16+
17+
// NOTE(review): the `MPOL_*` policy values below are presumably small
// enumerated integers rather than independent bits — confirm that
// combining or testing them with bitflags set operations behaves as
// intended.
/// `MPOL_DEFAULT`
18+
const DEFAULT = linux_raw_sys::general::MPOL_DEFAULT as u32;
19+
/// `MPOL_PREFERRED`
20+
const PREFERRED = linux_raw_sys::general::MPOL_PREFERRED as u32;
21+
/// `MPOL_BIND`
22+
const BIND = linux_raw_sys::general::MPOL_BIND as u32;
23+
/// `MPOL_INTERLEAVE`
24+
const INTERLEAVE = linux_raw_sys::general::MPOL_INTERLEAVE as u32;
25+
/// `MPOL_LOCAL`
26+
const LOCAL = linux_raw_sys::general::MPOL_LOCAL as u32;
27+
/// `MPOL_PREFERRED_MANY`
28+
const PREFERRED_MANY = linux_raw_sys::general::MPOL_PREFERRED_MANY as u32;
29+
30+
/// Accept every bit so values not named above are preserved; see
/// <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
31+
const _ = !0;
32+
}
33+
}
34+
35+
bitflags! {
36+
/// `MPOL_MF_*` flags for use with [`mbind`].
37+
///
38+
/// [`mbind`]: crate::numa::mbind
39+
#[repr(transparent)]
40+
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
41+
pub struct ModeFlags: u32 {
42+
/// `MPOL_MF_STRICT`
43+
const STRICT = linux_raw_sys::general::MPOL_MF_STRICT;
44+
/// `MPOL_MF_MOVE`
45+
const MOVE = linux_raw_sys::general::MPOL_MF_MOVE;
46+
/// `MPOL_MF_MOVE_ALL`
47+
const MOVE_ALL = linux_raw_sys::general::MPOL_MF_MOVE_ALL;
48+
49+
/// Accept every bit so values not named above are preserved; see
/// <https://docs.rs/bitflags/*/bitflags/#externally-defined-flags>
50+
const _ = !0;
51+
}
52+
}

src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ pub mod mount;
218218
#[cfg(feature = "net")]
219219
#[cfg_attr(doc_cfg, doc(cfg(feature = "net")))]
220220
pub mod net;
221+
#[cfg(linux_kernel)]
222+
#[cfg(feature = "numa")]
223+
#[cfg_attr(doc_cfg, doc(cfg(feature = "numa")))]
224+
pub mod numa;
221225
#[cfg(not(any(windows, target_os = "espidf")))]
222226
#[cfg(feature = "param")]
223227
#[cfg_attr(doc_cfg, doc(cfg(feature = "param")))]

src/numa/mod.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
//! The `numa` API.
2+
//!
3+
//! # Safety
4+
//!
5+
//! `mbind` and related functions manipulate raw pointers and have special
6+
//! semantics and are wildly unsafe.
7+
#![allow(unsafe_code)]
8+
9+
use crate::{backend, io};
10+
use core::ffi::c_void;
11+
12+
pub use backend::numa::types::{Mode, ModeFlags};
13+
14+
/// `mbind(addr, len, mode, nodemask, flags)`—Sets the memory policy for a
/// memory range.
///
/// `nodemask` is a bit array of node IDs stored in `u64` words.
///
/// # Safety
///
/// `addr` and `len` are handed to the kernel as a raw address range. This
/// function should only be used on memory which the caller owns.
///
/// # References
///  - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/mbind.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn mbind(
    addr: *mut c_void,
    len: usize,
    mode: Mode,
    nodemask: &[u64],
    flags: ModeFlags,
) -> io::Result<()> {
    backend::numa::syscalls::mbind(addr, len, mode, nodemask, flags)
}
30+
31+
32+
/// `set_mempolicy(mode, nodemask)`—Sets the default NUMA memory policy for
/// a thread and its children.
///
/// `nodemask` is a bit array of node IDs stored in `u64` words.
///
/// # Safety
///
/// This function takes no raw pointers from the caller.
/// NOTE(review): confirm which invariant, if any, callers must uphold here,
/// or whether this could be a safe function.
///
/// # References
///  - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/set_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn set_mempolicy(
    mode: Mode,
    nodemask: &[u64],
) -> io::Result<()> {
    backend::numa::syscalls::set_mempolicy(mode, nodemask)
}
49+
50+
/// `get_mempolicy_node(addr)`—Returns the ID of the node on which the
/// address `addr` is allocated.
///
/// This is the `get_mempolicy` mode selected by passing both `MPOL_F_NODE`
/// and `MPOL_F_ADDR`. If no page has yet been allocated for the specified
/// address, `get_mempolicy` will allocate a page as if the thread had
/// performed a read (load) access to that address, and return the ID of the
/// node where that page was allocated.
///
/// # Safety
///
/// `addr` is handed to the kernel as a raw address. This function should
/// only be used on memory which the caller owns.
///
/// # References
///  - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn get_mempolicy_node(
    addr: *mut c_void,
) -> io::Result<usize> {
    backend::numa::syscalls::get_mempolicy_node(addr)
}
75+
76+
/// `get_mempolicy_next_node()`—Returns the ID of the next node that will be
/// used for interleaving of internal kernel pages allocated on behalf of
/// the thread.
///
/// This is the `get_mempolicy` mode selected by passing `MPOL_F_NODE`
/// without `MPOL_F_ADDR`; it applies when the thread's current policy is
/// `MPOL_INTERLEAVE`. These allocations include pages for memory-mapped
/// files in process memory ranges mapped using the `mmap(2)` call with the
/// `MAP_PRIVATE` flag for read accesses, and in memory ranges mapped with
/// the `MAP_SHARED` flag for all accesses.
///
/// # Safety
///
/// This function takes no arguments and therefore no raw pointers.
/// NOTE(review): confirm which invariant, if any, callers must uphold here,
/// or whether this could be a safe function.
///
/// # References
///  - [Linux]
///
/// [Linux]: https://man7.org/linux/man-pages/man2/get_mempolicy.2.html
#[cfg(linux_kernel)]
#[inline]
pub unsafe fn get_mempolicy_next_node(
) -> io::Result<usize> {
    backend::numa::syscalls::get_mempolicy_next_node()
}
104+

tests/numa/main.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/// Exercises `mbind`, `get_mempolicy_node`, `set_mempolicy` and
/// `get_mempolicy_next_node` on a freshly mapped anonymous region.
// Gate on the features this test actually uses: `mm` for the mapping and
// `numa` for the calls under test. `fs` is not used here.
#[cfg(all(feature = "mm", feature = "numa"))]
#[test]
fn test_mbind() {
    let size = 8192;

    unsafe {
        let vaddr = rustix::mm::mmap_anonymous(
            std::ptr::null_mut(),
            size,
            rustix::mm::ProtFlags::READ | rustix::mm::ProtFlags::WRITE,
            rustix::mm::MapFlags::PRIVATE,
        )
        .unwrap();

        // Touch the mapping before querying which node it lives on.
        vaddr.cast::<usize>().write(100);

        // Bit 0 set: node 0 only.
        let mask: &[u64] = &[1];
        rustix::numa::mbind(
            vaddr,
            size,
            rustix::numa::Mode::BIND | rustix::numa::Mode::STATIC_NODES,
            mask,
            rustix::numa::ModeFlags::empty(),
        )
        .unwrap();

        rustix::numa::get_mempolicy_node(vaddr).unwrap();

        match rustix::numa::get_mempolicy_next_node() {
            Err(rustix::io::Errno::INVAL) => (),
            _ => panic!("rustix::numa::get_mempolicy_next_node() should return EINVAL for MPOL_DEFAULT"),
        }

        rustix::numa::set_mempolicy(rustix::numa::Mode::INTERLEAVE, mask).unwrap();

        rustix::numa::get_mempolicy_next_node().unwrap();

        // Release the mapping so the test does not leak it.
        rustix::mm::munmap(vaddr, size).unwrap();
    }
}

tests/numa/main.rs~

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// NOTE(review): this file (`main.rs~`) looks like an editor backup of
// `tests/numa/main.rs` that was committed by accident — consider deleting it.
// The body below is repaired only so that it is at least valid Rust.

/// Binds a `memfd`-backed shared mapping to node 0 with `mbind`.
#[cfg(all(feature = "fs", feature = "mm", feature = "numa"))]
#[test]
fn test_mbind() {
    let size = 8192;
    let fd = rustix::fs::memfd_create(
        "memfd",
        rustix::fs::MemfdFlags::CLOEXEC | rustix::fs::MemfdFlags::ALLOW_SEALING,
    )
    .unwrap();

    rustix::fs::ftruncate(&fd, size as u64).unwrap();

    // `mmap` and `mbind` are `unsafe fn`s, so they must be called inside an
    // `unsafe` block.
    unsafe {
        // The test returns `()`, so failures are unwrapped rather than
        // propagated with `?`.
        let vaddr = rustix::mm::mmap(
            std::ptr::null_mut(),
            size,
            rustix::mm::ProtFlags::empty(),
            rustix::mm::MapFlags::SHARED,
            &fd,
            0,
        )
        .unwrap();

        // Bit 0 set: node 0 only. `mbind` takes the mask as `&[u64]`.
        let mask: &[u64] = &[1];
        rustix::numa::mbind(
            vaddr,
            size,
            rustix::numa::Mode::BIND | rustix::numa::Mode::STATIC_NODES,
            mask,
            rustix::numa::ModeFlags::empty(),
        )
        .unwrap();
    }
}

0 commit comments

Comments
 (0)