Skip to content

Commit

Permalink
Attribute wireguard listening ports
Browse files Browse the repository at this point in the history
  • Loading branch information
jcaesar committed Aug 12, 2024
1 parent cf893cd commit 5668adc
Show file tree
Hide file tree
Showing 11 changed files with 222 additions and 38 deletions.
28 changes: 28 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ anstyle = "1.0.8"
anyhow = "1.0.86"
itertools = "0.13.0"
netlink-packet-core = "=0.7.0"
netlink-packet-generic = "0.3.3"
netlink-packet-route = "=0.17.1"
netlink-packet-sock-diag = "=0.4.2"
netlink-packet-wireguard = "0.2.3"
netlink-sys = "=0.8.5"
procfs = "0.16.0"
terminal_size = "0.3.0"
Expand Down
69 changes: 58 additions & 11 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ mod termtree;

use anyhow::Result;
use itertools::Itertools;
use netlink::sock::{Family, SockInfo};
use netlink::{
sock::{Family, SockInfo},
wg::wireguards,
};
use procfs::process::all_processes;
use std::{
collections::BTreeMap,
collections::{BTreeMap, HashMap},
env::var_os,
io::{stdout, BufWriter, Write},
net::{IpAddr, Ipv4Addr, Ipv6Addr},
Expand All @@ -21,14 +24,14 @@ pub type Ino = u64;

fn main() -> Result<()> {
let users_cache = UsersCache::new();
let (interfaces, local_routes) = interfaces_routes();
let iface_info = interfaces_routes();

let filters = options::parse_args(&interfaces, &local_routes, &users_cache)?;
let filters = options::parse_args(&iface_info, &users_cache)?;

let socks = netlink::sock::all_sockets(&interfaces, &local_routes); // TODO no clone, pass filters
let socks = netlink::sock::all_sockets(&iface_info); // TODO no clone, pass filters
let mut socks = match socks {
Ok(socks) => socks,
Err(netlink_err) => match sockets_procfs::all_sockets(&interfaces, &local_routes) {
Err(netlink_err) => match sockets_procfs::all_sockets(&iface_info) {
Ok(socks) => socks,
Err(proc_err) => {
eprintln!(
Expand All @@ -42,6 +45,8 @@ fn main() -> Result<()> {
};
let mut output = termtree::Tree::new();
let self_user_ns = procs::get_user_ns(&procs::ourself()?).ok();

// output known processes/sockets
let mut lps = all_processes()?
.filter_map(|p| procs::ProcDesc::inspect_ps(p, &mut socks, &users_cache, self_user_ns).ok())
.filter(|p| !p.sockets.is_empty())
Expand All @@ -60,6 +65,31 @@ fn main() -> Result<()> {
);
}
}

// output wireguards
let mut wireguard_sockets = HashMap::<_, Vec<_>>::new();
socks.retain(|_sockid, sockinfo| {
if let Some(if_id) = iface_info.wireguard_ports.get(&sockinfo.port) {
wireguard_sockets
.entry(if_id)
.or_default()
.push(sockinfo.to_owned());
false
} else {
true
}
});
for (if_id, socks) in &wireguard_sockets {
if filters.accept_wg() {
let name = match iface_info.id2name.get(if_id) {
Some(ifname) => format!("[wireguard {ifname}]"),
None => format!("wireguard, index {if_id}"),
};
output.node(name, sockets_tree(socks, &filters));
}
}

// output unknown sockets
let mut socks = socks
.values()
.into_group_map_by(|s| s.uid)
Expand All @@ -76,8 +106,10 @@ fn main() -> Result<()> {
}
}
false => {
eprintln!("WARNING: Some listening sockets hidden:");
eprintln!("Not all sockets could not be matched to a process, process-based filtering not fully possible.");
if !socks.is_empty() {
eprintln!("WARNING: Some listening sockets hidden:");
eprintln!("Not all sockets could not be matched to a process, process-based filtering not fully possible.");
}
}
}

Expand All @@ -91,13 +123,28 @@ fn main() -> Result<()> {
Ok(())
}

fn interfaces_routes() -> (std::collections::HashMap<u32, String>, netlink::route::Rtbl) {
#[derive(Default)]
struct IfaceInfo {
id2name: HashMap<u32, String>,
wireguard_ports: HashMap<u16, u32>,
local_routes: netlink::route::Rtbl,
}

fn interfaces_routes() -> IfaceInfo {
let Ok(ref route_socket) = netlink::route::socket() else {
return Default::default();
};
let interfaces = netlink::route::interface_names(route_socket).unwrap_or_default();
let netlink::route::Interfaces {
id2name,
wireguard_ids,
} = netlink::route::interface_names(route_socket).unwrap_or_default();
let local_routes = netlink::route::local_routes(route_socket).unwrap_or_default();
(interfaces, local_routes)
let wireguard_ports = wireguards(&wireguard_ids).unwrap_or_default();
IfaceInfo {
id2name,
wireguard_ports,
local_routes,
}
}

fn sockets_tree<'a>(
Expand Down
6 changes: 5 additions & 1 deletion src/netlink/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod route;
pub mod sock;
pub mod wg;

use anyhow::{Context, Result};
use netlink_packet_core::{
Expand Down Expand Up @@ -32,7 +33,10 @@ where
match rx_packet.payload {
NetlinkPayload::Done(_) => return Ok(()),
NetlinkPayload::InnerMessage(inner) => recv(inner),
NetlinkPayload::Error(err) => return Err(err.to_io()).context("Netlink error"),
NetlinkPayload::Error(err) => match err.code {
Some(_) => return Err(err.to_io()).context("Netlink error"),
None => return Ok(()),
},
p => todo!("Unexpected netlink payload {:?}", p.message_type()),
}

Expand Down
40 changes: 31 additions & 9 deletions src/netlink/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@ use netlink_packet_core::{
NetlinkHeader, NetlinkMessage, NetlinkPayload, NLM_F_DUMP, NLM_F_REQUEST,
};
use netlink_packet_route::{
constants::*, link::nlas::Nla as LinkNla, route::nlas::Nla as RouteNla, LinkMessage,
RouteMessage, RtnlMessage,
constants::*,
link::nlas::Nla as LinkNla,
nlas::link::{Info, InfoKind},
route::nlas::Nla as RouteNla,
LinkMessage, RouteMessage, RtnlMessage,
};
use netlink_sys::{protocols::NETLINK_ROUTE, Socket, SocketAddr};
use std::{cmp::Reverse, collections::HashMap, net::IpAddr};

pub fn interface_names(socket: &Socket) -> Result<HashMap<u32, String>> {
#[derive(Default)]
pub struct Interfaces {
pub id2name: HashMap<u32, String>,
pub wireguard_ids: Vec<u32>,
}

pub fn interface_names(socket: &Socket) -> Result<Interfaces> {
let mut packet = NetlinkMessage::new(
NetlinkHeader::default(),
NetlinkPayload::from(RtnlMessage::GetLink(LinkMessage::default())),
Expand All @@ -19,19 +28,32 @@ pub fn interface_names(socket: &Socket) -> Result<HashMap<u32, String>> {
packet.header.sequence_number = 1;

let mut map = HashMap::new();
let mut wg_ids = Vec::new();
drive_req(packet, socket, |inner| {
if let RtnlMessage::NewLink(nl) = inner {
if let Some(name) = nl.nlas.into_iter().find_map(|s| match s {
LinkNla::IfName(n) => Some(n),
_ => None,
}) {
map.insert(nl.header.index, name);
for nla in nl.nlas {
match nla {
LinkNla::IfName(name) => {
map.insert(nl.header.index, name);
}
LinkNla::Info(infos) => {
for info in infos {
if info == Info::Kind(InfoKind::Wireguard) {
wg_ids.push(nl.header.index);
}
}
}
_ => (),
}
}
}
})
.context("Get interface names")?;

Ok(map)
Ok(Interfaces {
id2name: map,
wireguard_ids: wg_ids,
})
}

pub fn socket() -> Result<Socket> {
Expand Down
11 changes: 7 additions & 4 deletions src/netlink/sock.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{drive_req, nl_hdr_flags, route::Rtbl};
use crate::Ino;
use crate::{IfaceInfo, Ino};
use anyhow::{Context, Result};
use netlink_packet_core::{NetlinkMessage, NLM_F_DUMP, NLM_F_REQUEST};
use netlink_packet_sock_diag::{
Expand All @@ -11,8 +11,11 @@ use netlink_sys::{protocols::NETLINK_SOCK_DIAG, Socket, SocketAddr};
use std::{collections::HashMap, fmt::Display, net::IpAddr};

pub fn all_sockets<'i>(
interfaces: &'i HashMap<u32, String>,
local_routes: &Rtbl,
IfaceInfo {
id2name: interfaces,
local_routes,
..
}: &'i IfaceInfo,
) -> Result<HashMap<Ino, SockInfo<'i>>> {
let mut socket =
Socket::new(NETLINK_SOCK_DIAG).context("Construct netlink socket information socket")?;
Expand Down Expand Up @@ -149,7 +152,7 @@ impl std::str::FromStr for Protocol {
}
}

#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct SockInfo<'a> {
pub family: Family,
pub protocol: Protocol,
Expand Down
70 changes: 70 additions & 0 deletions src/netlink/wg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use super::drive_req;
use anyhow::{Context, Result};
use netlink_packet_core::{NetlinkHeader, NetlinkMessage, NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST};
use netlink_packet_generic::{
ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd},
GenlMessage,
};
use netlink_packet_wireguard::{nlas::WgDeviceAttrs, Wireguard, WireguardCmd};
use netlink_sys::{protocols::NETLINK_GENERIC, Socket, SocketAddr};
use std::collections::HashMap;

pub fn wireguards(interface_ids: &[u32]) -> Result<HashMap<u16, u32>> {
if interface_ids.is_empty() {
return Ok(Default::default());
}

let mut socket = Socket::new(NETLINK_GENERIC).context("Construct netlink generic socket")?;
socket.bind_auto().context("Bind netlink generic socket")?;
socket
.connect(&SocketAddr::new(0, 0))
.context("Connect netlink generic socket")?;

// Resolve wireguard family id.
// genetlink can do this for me, but it's all async and tokio based.
let mut packet = NetlinkMessage::new(
NetlinkHeader::default(),
GenlMessage::from_payload(GenlCtrl {
cmd: GenlCtrlCmd::GetFamily,
nlas: vec![GenlCtrlAttrs::FamilyName("wireguard".into())],
})
.into(),
);
packet.header.flags = NLM_F_REQUEST | NLM_F_ACK;
packet.header.sequence_number = 1;
let mut family_id: Option<u16> = None;
drive_req(packet, &socket, |inner| {
for nla in inner.payload.nlas {
if let GenlCtrlAttrs::FamilyId(id) = nla {
family_id = Some(id);
}
}
})
.context("Get wireguard family")?;
let family_id = family_id.context("Netlink wireguard family not found")?;

let mut ret = HashMap::new();
for &if_id in interface_ids {
let mut payload = GenlMessage::from_payload(Wireguard {
cmd: WireguardCmd::GetDevice,
nlas: vec![WgDeviceAttrs::IfIndex(if_id)],
});
payload.set_resolved_family_id(family_id);
let mut packet = NetlinkMessage::new(NetlinkHeader::default(), payload.into());
packet.header.flags = NLM_F_DUMP | NLM_F_REQUEST | NLM_F_ACK;
packet.header.sequence_number = 2;

drive_req(packet, &socket, |inner| {
for nla in inner.payload.nlas {
if let WgDeviceAttrs::ListenPort(port) = nla {
if let Some(other_id) = ret.insert(port, if_id) {
eprintln!("WARNING: Wireguard interfaces {if_id} and {other_id} seem to be listening on the same port {port}. Output may be inaccurate");
}
}
}
})
.context("Get wireguard ")?;
}

Ok(ret)
}
Loading

0 comments on commit 5668adc

Please sign in to comment.