diff --git a/src/handle.rs b/src/handle.rs index 0102789..8d0ce71 100644 --- a/src/handle.rs +++ b/src/handle.rs @@ -12,7 +12,7 @@ use netlink_packet_core::{ }; use netlink_packet_generic::{GenlFamily, GenlHeader, GenlMessage}; use netlink_proto::{sys::SocketAddr, ConnectionHandle}; -use std::{fmt::Debug, sync::Arc}; +use std::{collections::HashMap, fmt::Debug, sync::Arc}; /// The generic netlink connection handle /// @@ -36,13 +36,14 @@ use std::{fmt::Debug, sync::Arc}; /// 2. Query the family id using the builtin resolver. /// 3. If the id is in the cache, returning the id in the cache and skip step 4. /// 4. The resolver sends `CTRL_CMD_GETFAMILY` request to get the id and records -/// it in the cache. 5. fill the family id using -/// [`GenlMessage::set_resolved_family_id()`]. 6. Serialize the payload to -/// [`RawGenlMessage`]. 7. Send it through the connection. -/// - The family id filled into `message_type` field in -/// [`NetlinkMessage::finalize()`]. +/// it in the cache. +/// 5. fill the family id using [`GenlMessage::set_resolved_family_id()`]. +/// 6. Serialize the payload to [`RawGenlMessage`]. +/// 7. Send it through the connection. +/// - The family id filled into `message_type` field in +/// [`NetlinkMessage::finalize()`]. /// 8. In the response stream, deserialize the payload back to -/// [`GenlMessage`]. +/// [`GenlMessage`]. #[derive(Clone, Debug)] pub struct GenetlinkHandle { handle: ConnectionHandle, @@ -69,7 +70,21 @@ impl GenetlinkHandle { .await } - /// Clear the resolver's fanily id cache + /// Resolve the multicast groups of the given [`GenlFamily`]. + pub async fn resolve_mcast_groups( + &self, + ) -> Result, GenetlinkError> + where + F: GenlFamily, + { + self.resolver + .lock() + .await + .query_family_multicast_groups(self, F::family_name()) + .await + } + + /// Clear the resolver's family id cache pub async fn clear_family_id_cache(&self) { self.resolver.lock().await.clear_cache(); } diff --git a/src/resolver.rs b/src/resolver.rs index ee46ae0..90dcfdc 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -4,7 +4,10 @@ use crate::{error::GenetlinkError, GenetlinkHandle}; use futures::{future::Either, StreamExt}; use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_REQUEST}; use netlink_packet_generic::{ - ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, + ctrl::{ + nlas::{GenlCtrlAttrs, McastGrpAttrs}, + GenlCtrl, GenlCtrlCmd, + }, GenlMessage, }; use std::{collections::HashMap, future::Future}; @@ -12,12 +15,14 @@ use std::{collections::HashMap, future::Future}; #[derive(Clone, Debug, Default)] pub struct Resolver { cache: HashMap<&'static str, u16>, + groups_cache: HashMap<&'static str, HashMap>, } impl Resolver { pub fn new() -> Self { Self { cache: HashMap::new(), + groups_cache: HashMap::new(), } } @@ -25,6 +30,13 @@ impl Resolver { self.cache.get(family_name).copied() } + pub fn get_groups_cache_by_name( + &self, + family_name: &str, + ) -> Option> { + self.groups_cache.get(family_name).cloned() + } + pub fn query_family_id( &mut self, handle: &GenetlinkHandle, @@ -85,8 +97,86 @@ impl Resolver { } } + pub fn query_family_multicast_groups( + &mut self, + handle: &GenetlinkHandle, + family_name: &'static str, + ) -> impl Future, GenetlinkError>> + '_ + { + if let Some(groups) = self.get_groups_cache_by_name(family_name) { + Either::Left(futures::future::ready(Ok(groups))) + } else { + let mut handle = handle.clone(); + Either::Right(async move { + // Create the request message to get family details + let mut genlmsg: GenlMessage = + GenlMessage::from_payload(GenlCtrl { + cmd: GenlCtrlCmd::GetFamily, + nlas: vec![GenlCtrlAttrs::FamilyName( + family_name.to_owned(), + )], + }); + genlmsg.finalize(); + let mut nlmsg = NetlinkMessage::from(genlmsg); + nlmsg.header.flags = NLM_F_REQUEST; + nlmsg.finalize(); + + // Send the request + let mut res = handle.send_request(nlmsg)?; + + // Prepare to collect multicast groups + let mut mc_groups = HashMap::new(); + + // Process the response + while let Some(result) = res.next().await { + match result?.payload { + NetlinkPayload::InnerMessage(genlmsg) => { + // One specific family id was requested, it can be + // assumed, that the mcast + // groups are part of that family. + let Some(mcast_groups) = genlmsg + .payload + .nlas + .into_iter() + .filter_map(|attr| match attr { + GenlCtrlAttrs::McastGroups(groups) => { + Some(groups) + } + _ => None, + }) + .next() + else { + continue; + }; + + for group in mcast_groups.into_iter().filter_map(|attrs| { + match attrs.as_slice() { + [McastGrpAttrs::Name(name), McastGrpAttrs::Id(i)] | + [McastGrpAttrs::Id(i), McastGrpAttrs::Name(name)] => Some((name.clone(), *i)), + _ => None + } + }) { + mc_groups.insert(group.0, group.1); + } + } + NetlinkPayload::Error(e) => { + return Err(e.into()); + } + _ => (), + } + } + + // Update the cache + self.groups_cache.insert(family_name, mc_groups.clone()); + + Ok(mc_groups) + }) + } + } + pub fn clear_cache(&mut self) { self.cache.clear(); + self.groups_cache.clear(); } } @@ -143,15 +233,34 @@ mod test { }) .unwrap(); if id == 0 { - log::warn!( - "Generic family \"{name}\" not exist or not loaded \ - in this environment. Ignored." - ); continue; } let cache = resolver.get_cache_by_name(name).unwrap(); assert_eq!(id, cache); + + let mcast_groups = resolver + .query_family_multicast_groups(&handle, name) + .await + .or_else(|e| { + if let GenetlinkError::NetlinkError(io_err) = &e { + if io_err.kind() == ErrorKind::NotFound { + // Ignore non exist entries + Ok(HashMap::new()) + } else { + Err(e) + } + } else { + Err(e) + } + }) + .unwrap(); + if mcast_groups.is_empty() { + continue; + } + + let cache = resolver.get_groups_cache_by_name(name).unwrap(); + assert_eq!(mcast_groups, cache); log::warn!("{:?}", (name, cache)); } }