diff --git a/Cargo.lock b/Cargo.lock index cee10b1..a87315b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -209,8 +209,10 @@ dependencies = [ "anyhow", "itertools", "netlink-packet-core", + "netlink-packet-generic", "netlink-packet-route", "netlink-packet-sock-diag", + "netlink-packet-wireguard", "netlink-sys", "procfs", "terminal_size", @@ -244,6 +246,18 @@ dependencies = [ "netlink-packet-utils", ] +[[package]] +name = "netlink-packet-generic" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cd7eb8ad331c84c6b8cb7f685b448133e5ad82e1ffd5acafac374af4a5a308b" +dependencies = [ + "anyhow", + "byteorder", + "netlink-packet-core", + "netlink-packet-utils", +] + [[package]] name = "netlink-packet-route" version = "0.17.1" @@ -285,6 +299,20 @@ dependencies = [ "thiserror", ] +[[package]] +name = "netlink-packet-wireguard" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b25b050ff1f6a1e23c6777b72db22790fe5b6b5ccfd3858672587a79876c8f" +dependencies = [ + "anyhow", + "byteorder", + "libc", + "log", + "netlink-packet-generic", + "netlink-packet-utils", +] + [[package]] name = "netlink-sys" version = "0.8.5" diff --git a/Cargo.toml b/Cargo.toml index 33d79ee..8618448 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,8 +15,10 @@ anstyle = "1.0.8" anyhow = "1.0.86" itertools = "0.13.0" netlink-packet-core = "=0.7.0" +netlink-packet-generic = "0.3.3" netlink-packet-route = "=0.17.1" netlink-packet-sock-diag = "=0.4.2" +netlink-packet-wireguard = "0.2.3" netlink-sys = "=0.8.5" procfs = "0.16.0" terminal_size = "0.3.0" diff --git a/src/main.rs b/src/main.rs index 3968d77..7431634 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,10 +6,13 @@ mod termtree; use anyhow::Result; use itertools::Itertools; -use netlink::sock::{Family, SockInfo}; +use netlink::{ + sock::{Family, SockInfo}, + wg::wireguards, +}; use procfs::process::all_processes; use std::{ - collections::BTreeMap, + collections::{BTreeMap, HashMap}, env::var_os, io::{stdout, BufWriter, Write}, net::{IpAddr, Ipv4Addr, Ipv6Addr}, @@ -21,14 +24,14 @@ pub type Ino = u64; fn main() -> Result<()> { let users_cache = UsersCache::new(); - let (interfaces, local_routes) = interfaces_routes(); + let iface_info = interfaces_routes(); - let filters = options::parse_args(&interfaces, &local_routes, &users_cache)?; + let filters = options::parse_args(&iface_info, &users_cache)?; - let socks = netlink::sock::all_sockets(&interfaces, &local_routes); // TODO no clone, pass filters + let socks = netlink::sock::all_sockets(&iface_info); // TODO no clone, pass filters let mut socks = match socks { Ok(socks) => socks, - Err(netlink_err) => match sockets_procfs::all_sockets(&interfaces, &local_routes) { + Err(netlink_err) => match sockets_procfs::all_sockets(&iface_info) { Ok(socks) => socks, Err(proc_err) => { eprintln!( @@ -42,6 +45,8 @@ fn main() -> Result<()> { }; let mut output = termtree::Tree::new(); let self_user_ns = procs::get_user_ns(&procs::ourself()?).ok(); + + // output known processes/sockets let mut lps = all_processes()? .filter_map(|p| procs::ProcDesc::inspect_ps(p, &mut socks, &users_cache, self_user_ns).ok()) .filter(|p| !p.sockets.is_empty()) @@ -60,6 +65,31 @@ fn main() -> Result<()> { ); } } + + // output wireguards + let mut wireguard_sockets = HashMap::<_, Vec<_>>::new(); + socks.retain(|_sockid, sockinfo| { + if let Some(if_id) = iface_info.wireguard_ports.get(&sockinfo.port) { + wireguard_sockets + .entry(if_id) + .or_default() + .push(sockinfo.to_owned()); + false + } else { + true + } + }); + for (if_id, socks) in &wireguard_sockets { + if filters.accept_wg() { + let name = match iface_info.id2name.get(if_id) { + Some(ifname) => format!("[wireguard {ifname}]"), + None => format!("wireguard, index {if_id}"), + }; + output.node(name, sockets_tree(socks, &filters)); + } + } + + // output unknown sockets let mut socks = socks .values() .into_group_map_by(|s| s.uid) @@ -76,8 +106,10 @@ fn main() -> Result<()> { } } false => { - eprintln!("WARNING: Some listening sockets hidden:"); - eprintln!("Not all sockets could not be matched to a process, process-based filtering not fully possible."); + if !socks.is_empty() { + eprintln!("WARNING: Some listening sockets hidden:"); + eprintln!("Not all sockets could not be matched to a process, process-based filtering not fully possible."); + } } } @@ -91,13 +123,28 @@ fn main() -> Result<()> { Ok(()) } -fn interfaces_routes() -> (std::collections::HashMap, netlink::route::Rtbl) { +#[derive(Default)] +struct IfaceInfo { + id2name: HashMap, + wireguard_ports: HashMap, + local_routes: netlink::route::Rtbl, +} + +fn interfaces_routes() -> IfaceInfo { let Ok(ref route_socket) = netlink::route::socket() else { return Default::default(); }; - let interfaces = netlink::route::interface_names(route_socket).unwrap_or_default(); + let netlink::route::Interfaces { + id2name, + wireguard_ids, + } = netlink::route::interface_names(route_socket).unwrap_or_default(); let local_routes = netlink::route::local_routes(route_socket).unwrap_or_default(); - (interfaces, local_routes) + let wireguard_ports = wireguards(&wireguard_ids).unwrap_or_default(); + IfaceInfo { + id2name, + wireguard_ports, + local_routes, + } } fn sockets_tree<'a>( diff --git a/src/netlink/mod.rs b/src/netlink/mod.rs index c2b8c40..d01dcd5 100644 --- a/src/netlink/mod.rs +++ b/src/netlink/mod.rs @@ -1,5 +1,6 @@ pub mod route; pub mod sock; +pub mod wg; use anyhow::{Context, Result}; use netlink_packet_core::{ @@ -32,7 +33,10 @@ where match rx_packet.payload { NetlinkPayload::Done(_) => return Ok(()), NetlinkPayload::InnerMessage(inner) => recv(inner), - NetlinkPayload::Error(err) => return Err(err.to_io()).context("Netlink error"), + NetlinkPayload::Error(err) => match err.code { + Some(_) => return Err(err.to_io()).context("Netlink error"), + None => return Ok(()), + }, p => todo!("Unexpected netlink payload {:?}", p.message_type()), } diff --git a/src/netlink/route.rs b/src/netlink/route.rs index b6184ee..18b9010 100644 --- a/src/netlink/route.rs +++ b/src/netlink/route.rs @@ -4,13 +4,22 @@ use netlink_packet_core::{ NetlinkHeader, NetlinkMessage, NetlinkPayload, NLM_F_DUMP, NLM_F_REQUEST, }; use netlink_packet_route::{ - constants::*, link::nlas::Nla as LinkNla, route::nlas::Nla as RouteNla, LinkMessage, - RouteMessage, RtnlMessage, + constants::*, + link::nlas::Nla as LinkNla, + nlas::link::{Info, InfoKind}, + route::nlas::Nla as RouteNla, + LinkMessage, RouteMessage, RtnlMessage, }; use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr}; use std::{cmp::Reverse, collections::HashMap, net::IpAddr}; -pub fn interface_names(socket: &Socket) -> Result> { +#[derive(Default)] +pub struct Interfaces { + pub id2name: HashMap, + pub wireguard_ids: Vec, +} + +pub fn interface_names(socket: &Socket) -> Result { let mut packet = NetlinkMessage::new( NetlinkHeader::default(), NetlinkPayload::from(RtnlMessage::GetLink(LinkMessage::default())), @@ -19,19 +28,32 @@ pub fn interface_names(socket: &Socket) -> Result> { packet.header.sequence_number = 1; let mut map = HashMap::new(); + let mut wg_ids = Vec::new(); drive_req(packet, socket, |inner| { if let RtnlMessage::NewLink(nl) = inner { - if let Some(name) = nl.nlas.into_iter().find_map(|s| match s { - LinkNla::IfName(n) => Some(n), - _ => None, - }) { - map.insert(nl.header.index, name); + for nla in nl.nlas { + match nla { + LinkNla::IfName(name) => { + map.insert(nl.header.index, name); + } + LinkNla::Info(infos) => { + for info in infos { + if info == Info::Kind(InfoKind::Wireguard) { + wg_ids.push(nl.header.index); + } + } + } + _ => (), + } } } }) .context("Get interface names")?; - Ok(map) + Ok(Interfaces { + id2name: map, + wireguard_ids: wg_ids, + }) } pub fn socket() -> Result { diff --git a/src/netlink/sock.rs b/src/netlink/sock.rs index 518f670..91d53ff 100644 --- a/src/netlink/sock.rs +++ b/src/netlink/sock.rs @@ -1,5 +1,5 @@ use super::{drive_req, nl_hdr_flags, route::Rtbl}; -use crate::Ino; +use crate::{IfaceInfo, Ino}; use anyhow::{Context, Result}; use netlink_packet_core::{NetlinkMessage, NLM_F_DUMP, NLM_F_REQUEST}; use netlink_packet_sock_diag::{ @@ -11,8 +11,11 @@ use netlink_sys::{protocols::NETLINK_SOCK_DIAG, Socket, SocketAddr}; use std::{collections::HashMap, fmt::Display, net::IpAddr}; pub fn all_sockets<'i>( - interfaces: &'i HashMap, - local_routes: &Rtbl, + IfaceInfo { + id2name: interfaces, + local_routes, + .. + }: &'i IfaceInfo, ) -> Result>> { let mut socket = Socket::new(NETLINK_SOCK_DIAG).context("Construct netlink socket information socket")?; @@ -149,7 +152,7 @@ impl std::str::FromStr for Protocol { } } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct SockInfo<'a> { pub family: Family, pub protocol: Protocol, diff --git a/src/netlink/wg.rs b/src/netlink/wg.rs new file mode 100644 index 0000000..617e841 --- /dev/null +++ b/src/netlink/wg.rs @@ -0,0 +1,70 @@ +use super::drive_req; +use anyhow::{Context, Result}; +use netlink_packet_core::{NetlinkHeader, NetlinkMessage, NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST}; +use netlink_packet_generic::{ + ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd}, + GenlMessage, +}; +use netlink_packet_wireguard::{nlas::WgDeviceAttrs, Wireguard, WireguardCmd}; +use netlink_sys::{protocols::NETLINK_GENERIC, Socket, SocketAddr}; +use std::collections::HashMap; + +pub fn wireguards(interface_ids: &[u32]) -> Result> { + if interface_ids.is_empty() { + return Ok(Default::default()); + } + + let mut socket = Socket::new(NETLINK_GENERIC).context("Construct netlink generic socket")?; + socket.bind_auto().context("Bind netlink generic socket")?; + socket + .connect(&SocketAddr::new(0, 0)) + .context("Connect netlink generic socket")?; + + // Resolve wireguard family id. + // genetlink can do this for me, but it's all async and tokio based. + let mut packet = NetlinkMessage::new( + NetlinkHeader::default(), + GenlMessage::from_payload(GenlCtrl { + cmd: GenlCtrlCmd::GetFamily, + nlas: vec![GenlCtrlAttrs::FamilyName("wireguard".into())], + }) + .into(), + ); + packet.header.flags = NLM_F_REQUEST | NLM_F_ACK; + packet.header.sequence_number = 1; + let mut family_id: Option = None; + drive_req(packet, &socket, |inner| { + for nla in inner.payload.nlas { + if let GenlCtrlAttrs::FamilyId(id) = nla { + family_id = Some(id); + } + } + }) + .context("Get wireguard family")?; + let family_id = family_id.context("Netlink wireguard family not found")?; + + let mut ret = HashMap::new(); + for &if_id in interface_ids { + let mut payload = GenlMessage::from_payload(Wireguard { + cmd: WireguardCmd::GetDevice, + nlas: vec![WgDeviceAttrs::IfIndex(if_id)], + }); + payload.set_resolved_family_id(family_id); + let mut packet = NetlinkMessage::new(NetlinkHeader::default(), payload.into()); + packet.header.flags = NLM_F_DUMP | NLM_F_REQUEST | NLM_F_ACK; + packet.header.sequence_number = 2; + + drive_req(packet, &socket, |inner| { + for nla in inner.payload.nlas { + if let WgDeviceAttrs::ListenPort(port) = nla { + if let Some(other_id) = ret.insert(port, if_id) { + eprintln!("WARNING: Wireguard interfaces {if_id} and {other_id} seem to be listening on the same port {port}. Output may be inaccurate"); + } + } + } + }) + .context("Get wireguard ")?; + } + + Ok(ret) +} diff --git a/src/options.rs b/src/options.rs index 12de316..9493f8e 100644 --- a/src/options.rs +++ b/src/options.rs @@ -1,7 +1,7 @@ use crate::netlink::route::Prefix; -use crate::netlink::route::Rtbl; use crate::netlink::sock::Protocol; use crate::procs; +use crate::IfaceInfo; use anyhow::bail; use anyhow::Context; use anyhow::Result; @@ -83,6 +83,10 @@ impl Filters { || self.pfxs.iter().any(|pfx| pfx.matches(addr)) || addr.is_unspecified() } + + pub(crate) fn accept_wg(&self) -> bool { + self.cmd.is_empty() && self.pid.is_empty() + } } pub fn match_arg(arg: &str, args: &mut std::env::Args) -> Result> { @@ -115,8 +119,11 @@ pub fn match_arg(arg: &str, args: &mut std::env::Args) -> Result, - local_routes: &Rtbl, + IfaceInfo { + id2name: ifaces, + local_routes, + .. + }: &IfaceInfo, users: &UsersCache, ) -> Result { let ifaces = ifaces diff --git a/src/procs.rs b/src/procs.rs index 03b533a..ba1b24c 100644 --- a/src/procs.rs +++ b/src/procs.rs @@ -351,7 +351,8 @@ impl Ord for ProcDesc<'_> { pub fn get_user_ns(p: &Process) -> Result { Ok(p.namespaces() .context("Namespaces inaccessible")? - .0.get(&OsString::from_vec(b"user".to_vec())) + .0 + .get(&OsString::from_vec(b"user".to_vec())) .context("No user ns")? .identifier) } diff --git a/src/sockets_procfs.rs b/src/sockets_procfs.rs index e703600..01918a7 100644 --- a/src/sockets_procfs.rs +++ b/src/sockets_procfs.rs @@ -1,14 +1,17 @@ use super::Ino; -use crate::netlink::{ - route::Rtbl, - sock::{Family, Protocol, SockInfo}, +use crate::{ + netlink::sock::{Family, Protocol, SockInfo}, + IfaceInfo, }; use anyhow::{Context, Result}; use std::collections::HashMap; pub fn all_sockets<'i>( - interfaces: &'i HashMap, - local_routes: &Rtbl, + IfaceInfo { + id2name: interfaces, + local_routes, + .. + }: &'i IfaceInfo, ) -> Result>> { eprintln!("WARNING: Falling back to parsing info from procfs, limited to TCP and UDP"); let mut ret = HashMap::new(); diff --git a/src/termtree.rs b/src/termtree.rs index 0ab9269..0159021 100644 --- a/src/termtree.rs +++ b/src/termtree.rs @@ -91,10 +91,7 @@ fn render_entry( ret(b"\n"); if collapsed.is_none() { for (pos, child) in tree.children.0.iter().with_position() { - let last = matches!( - pos, - itertools::Position::Last | itertools::Position::Only - ); + let last = matches!(pos, itertools::Position::Last | itertools::Position::Only); let prefix = Prefix { last, prefix }; render_entry(child, mw, color, ret, Some(&prefix)); }