Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions src/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use netlink_packet_core::{
};
use netlink_packet_generic::{GenlFamily, GenlHeader, GenlMessage};
use netlink_proto::{sys::SocketAddr, ConnectionHandle};
use std::{fmt::Debug, sync::Arc};
use std::{collections::HashMap, fmt::Debug, sync::Arc};

/// The generic netlink connection handle
///
Expand All @@ -36,13 +36,14 @@ use std::{fmt::Debug, sync::Arc};
/// 2. Query the family id using the builtin resolver.
/// 3. If the id is in the cache, returning the id in the cache and skip step 4.
/// 4. The resolver sends `CTRL_CMD_GETFAMILY` request to get the id and records
/// it in the cache. 5. fill the family id using
/// [`GenlMessage::set_resolved_family_id()`]. 6. Serialize the payload to
/// [`RawGenlMessage`]. 7. Send it through the connection.
/// - The family id filled into `message_type` field in
/// [`NetlinkMessage::finalize()`].
/// it in the cache.
/// 5. fill the family id using [`GenlMessage::set_resolved_family_id()`].
/// 6. Serialize the payload to [`RawGenlMessage`].
/// 7. Send it through the connection.
/// - The family id filled into `message_type` field in
/// [`NetlinkMessage::finalize()`].
/// 8. In the response stream, deserialize the payload back to
/// [`GenlMessage<F>`].
/// [`GenlMessage<F>`].
#[derive(Clone, Debug)]
pub struct GenetlinkHandle {
handle: ConnectionHandle<RawGenlMessage>,
Expand All @@ -69,7 +70,21 @@ impl GenetlinkHandle {
.await
}

/// Clear the resolver's fanily id cache
/// Resolve the multicast groups of the given [`GenlFamily`].
pub async fn resolve_mcast_groups<F>(
&self,
) -> Result<HashMap<String, u32>, GenetlinkError>
where
F: GenlFamily,
{
self.resolver
.lock()
.await
.query_family_multicast_groups(self, F::family_name())
.await
}

/// Clear the resolver's family id cache
pub async fn clear_family_id_cache(&self) {
self.resolver.lock().await.clear_cache();
}
Expand Down
119 changes: 114 additions & 5 deletions src/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,39 @@ use crate::{error::GenetlinkError, GenetlinkHandle};
use futures::{future::Either, StreamExt};
use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_REQUEST};
use netlink_packet_generic::{
ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd},
ctrl::{
nlas::{GenlCtrlAttrs, McastGrpAttrs},
GenlCtrl, GenlCtrlCmd,
},
GenlMessage,
};
use std::{collections::HashMap, future::Future};

#[derive(Clone, Debug, Default)]
pub struct Resolver {
cache: HashMap<&'static str, u16>,
groups_cache: HashMap<&'static str, HashMap<String, u32>>,
}

impl Resolver {
pub fn new() -> Self {
Self {
cache: HashMap::new(),
groups_cache: HashMap::new(),
}
}

pub fn get_cache_by_name(&self, family_name: &str) -> Option<u16> {
self.cache.get(family_name).copied()
}

pub fn get_groups_cache_by_name(
&self,
family_name: &str,
) -> Option<HashMap<String, u32>> {
self.groups_cache.get(family_name).cloned()
}

pub fn query_family_id(
&mut self,
handle: &GenetlinkHandle,
Expand Down Expand Up @@ -85,8 +97,86 @@ impl Resolver {
}
}

pub fn query_family_multicast_groups(
&mut self,
handle: &GenetlinkHandle,
family_name: &'static str,
) -> impl Future<Output = Result<HashMap<String, u32>, GenetlinkError>> + '_
{
if let Some(groups) = self.get_groups_cache_by_name(family_name) {
Either::Left(futures::future::ready(Ok(groups)))
} else {
let mut handle = handle.clone();
Either::Right(async move {
// Create the request message to get family details
let mut genlmsg: GenlMessage<GenlCtrl> =
GenlMessage::from_payload(GenlCtrl {
cmd: GenlCtrlCmd::GetFamily,
nlas: vec![GenlCtrlAttrs::FamilyName(
family_name.to_owned(),
)],
});
genlmsg.finalize();
let mut nlmsg = NetlinkMessage::from(genlmsg);
nlmsg.header.flags = NLM_F_REQUEST;
nlmsg.finalize();

// Send the request
let mut res = handle.send_request(nlmsg)?;

// Prepare to collect multicast groups
let mut mc_groups = HashMap::new();

// Process the response
while let Some(result) = res.next().await {
match result?.payload {
NetlinkPayload::InnerMessage(genlmsg) => {
// One specific family id was requested, it can be
// assumed, that the mcast
// groups are part of that family.
let Some(mcast_groups) = genlmsg
.payload
.nlas
.into_iter()
.filter_map(|attr| match attr {
GenlCtrlAttrs::McastGroups(groups) => {
Some(groups)
}
_ => None,
})
.next()
else {
continue;
};

for group in mcast_groups.into_iter().filter_map(|attrs| {
match attrs.as_slice() {
[McastGrpAttrs::Name(name), McastGrpAttrs::Id(i)] |
[McastGrpAttrs::Id(i), McastGrpAttrs::Name(name)] => Some((name.clone(), *i)),
_ => None
}
}) {
mc_groups.insert(group.0, group.1);
}
}
NetlinkPayload::Error(e) => {
return Err(e.into());
}
_ => (),
}
}

// Update the cache
self.groups_cache.insert(family_name, mc_groups.clone());

Ok(mc_groups)
})
}
}

pub fn clear_cache(&mut self) {
self.cache.clear();
self.groups_cache.clear();
}
}

Expand Down Expand Up @@ -143,15 +233,34 @@ mod test {
})
.unwrap();
if id == 0 {
log::warn!(
"Generic family \"{name}\" not exist or not loaded \
in this environment. Ignored."
);
continue;
}

let cache = resolver.get_cache_by_name(name).unwrap();
assert_eq!(id, cache);

let mcast_groups = resolver
.query_family_multicast_groups(&handle, name)
.await
.or_else(|e| {
if let GenetlinkError::NetlinkError(io_err) = &e {
if io_err.kind() == ErrorKind::NotFound {
// Ignore non exist entries
Ok(HashMap::new())
} else {
Err(e)
}
} else {
Err(e)
}
})
.unwrap();
if mcast_groups.is_empty() {
continue;
}

let cache = resolver.get_groups_cache_by_name(name).unwrap();
assert_eq!(mcast_groups, cache);
log::warn!("{:?}", (name, cache));
}
}
Expand Down