Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Report wireguard endpoint as a candidate when an endpoint override is in place #305

Merged
merged 7 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hostsfile/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ impl HostsBuilder {

let hosts_file = OpenOptions::new()
.create(true)
.truncate(false)
.read(true)
.write(true)
.open(hosts_path)?;
Expand Down
17 changes: 13 additions & 4 deletions server/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,21 @@ pub mod admin;
pub mod user;

/// Inject the collected endpoints from the WG interface into a list of peers.
/// This is essentially what adds NAT holepunching functionality.
/// This is essentially what adds NAT holepunching functionality. If a peer
/// already has an endpoint specified (by calling the override-endpoint) API,
/// the relatively recent wireguard endpoint will be added to the list of NAT
/// candidates, so other peers have a better chance of connecting.
pub fn inject_endpoints(session: &Session, peers: &mut Vec<Peer>) {
for peer in peers {
if peer.contents.endpoint.is_none() {
if let Some(endpoint) = session.context.endpoints.read().get(&peer.public_key) {
peer.contents.endpoint = Some(endpoint.to_owned().into());
let endpoints = session.context.endpoints.read();
if let Some(wg_endpoint) = endpoints.get(&peer.public_key) {
if peer.contents.endpoint.is_none() {
peer.contents.endpoint = Some(wg_endpoint.to_owned().into());
} else {
// The peer already has an endpoint specified, but it might be stale.
// If there is an endpoint reported from wireguard, we should add it
// to the list of candidates so others can try to connect using it.
peer.contents.candidates.push(wg_endpoint.to_owned().into());
}
}
}
Expand Down
89 changes: 89 additions & 0 deletions server/src/api/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,4 +464,93 @@ mod tests {
assert_eq!(peer.candidates, candidates);
Ok(())
}

#[tokio::test]
async fn test_endpoint_in_candidates() -> Result<(), Error> {
// We want to verify that the current wireguard endpoint always shows up
// either in the peer.endpoint field, or the peer.candidates field (in the
// case that the peer has specified an endpoint override).
let server = test::Server::new()?;

let peer = DatabasePeer::get(&server.db().lock(), test::DEVELOPER1_PEER_ID)?;
assert_eq!(peer.candidates, vec![]);

// Specify one NAT candidate. At this point, we have an unspecified
// endpoint and one NAT candidate specified.
let candidates = vec!["1.1.1.1:51820".parse::<Endpoint>().unwrap()];
assert_eq!(
server
.form_request(
test::DEVELOPER1_PEER_IP,
"PUT",
"/v1/user/candidates",
&candidates
)
.await
.status(),
StatusCode::NO_CONTENT
);

let res = server
.request(test::DEVELOPER1_PEER_IP, "GET", "/v1/user/state")
.await;

assert_eq!(res.status(), StatusCode::OK);

let whole_body = hyper::body::aggregate(res).await?;
let State { peers, .. } = serde_json::from_reader(whole_body.reader())?;

let developer_1 = peers
.into_iter()
.find(|p| p.id == test::DEVELOPER1_PEER_ID)
.unwrap();
assert_eq!(
developer_1.endpoint,
Some(test::DEVELOPER1_PEER_ENDPOINT.parse().unwrap())
);
assert_eq!(developer_1.candidates, candidates);

// Now, explicitly set an endpoint with the override-endpoint API
// and check that the original wireguard endpoint still shows up
// in the list of NAT candidates.
assert_eq!(
server
.form_request(
test::DEVELOPER1_PEER_IP,
"PUT",
"/v1/user/endpoint",
&EndpointContents::Set("1.2.3.4:51820".parse().unwrap())
)
.await
.status(),
StatusCode::NO_CONTENT
);

let res = server
.request(test::DEVELOPER1_PEER_IP, "GET", "/v1/user/state")
.await;

assert_eq!(res.status(), StatusCode::OK);

let whole_body = hyper::body::aggregate(res).await?;
let State { peers, .. } = serde_json::from_reader(whole_body.reader())?;

let developer_1 = peers
.into_iter()
.find(|p| p.id == test::DEVELOPER1_PEER_ID)
.unwrap();

// The peer endpoint should be the one we just specified in the override-endpoint request.
assert_eq!(developer_1.endpoint, Some("1.2.3.4:51820".parse().unwrap()));

// The list of candidates should now contain the one we specified at the beginning of the
// test, and the wireguard-reported endpoint of the peer.
let nat_candidate_1 = "1.1.1.1:51820".parse().unwrap();
let nat_candidate_2 = test::DEVELOPER1_PEER_ENDPOINT.parse().unwrap();
assert_eq!(developer_1.candidates.len(), 2);
assert!(developer_1.candidates.contains(&nat_candidate_1));
assert!(developer_1.candidates.contains(&nat_candidate_2));

Ok(())
}
}
2 changes: 1 addition & 1 deletion server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ pub type Endpoints = Arc<RwLock<HashMap<String, SocketAddr>>>;
#[derive(Clone)]
pub struct Context {
pub db: Db,
pub endpoints: Arc<RwLock<HashMap<String, SocketAddr>>>,
pub endpoints: Endpoints,
pub interface: InterfaceName,
pub backend: Backend,
pub public_key: Key,
Expand Down
37 changes: 25 additions & 12 deletions server/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use parking_lot::{Mutex, RwLock};
use rusqlite::Connection;
use serde::Serialize;
use shared::{Cidr, CidrContents, Error, PeerContents};
use std::{collections::HashMap, net::SocketAddr, path::PathBuf, sync::Arc};
use std::{net::SocketAddr, path::PathBuf, sync::Arc};
use tempfile::TempDir;
use wireguard_control::{Backend, InterfaceName, Key, KeyPair};

Expand All @@ -27,7 +27,9 @@ mod v4 {
pub const ADMIN_PEER_IP: &str = "10.80.1.1";
pub const WG_MANAGE_PEER_IP: &str = ADMIN_PEER_IP;
pub const DEVELOPER1_PEER_IP: &str = "10.80.64.2";
pub const DEVELOPER1_PEER_ENDPOINT: &str = "169.10.26.8:14720";
pub const DEVELOPER2_PEER_IP: &str = "10.80.64.3";
pub const DEVELOPER2_PEER_ENDPOINT: &str = "169.55.140.9:5833";
pub const USER1_PEER_IP: &str = "10.80.128.2";
pub const USER2_PEER_IP: &str = "10.80.129.2";
pub const EXPERIMENT_SUBCIDR_PEER_IP: &str = "10.81.0.1";
Expand All @@ -48,7 +50,9 @@ mod v6 {
pub const ADMIN_PEER_IP: &str = "fd00:1337::1:0:0:1";
pub const WG_MANAGE_PEER_IP: &str = ADMIN_PEER_IP;
pub const DEVELOPER1_PEER_IP: &str = "fd00:1337::2:0:0:1";
pub const DEVELOPER1_PEER_ENDPOINT: &str = "[1001:db8::1]:14720";
pub const DEVELOPER2_PEER_IP: &str = "fd00:1337::2:0:0:2";
pub const DEVELOPER2_PEER_ENDPOINT: &str = "[2001:db8::1]:5833";
pub const USER1_PEER_IP: &str = "fd00:1337::3:0:0:1";
pub const USER2_PEER_IP: &str = "fd00:1337::3:0:0:2";
pub const EXPERIMENT_SUBCIDR_PEER_IP: &str = "fd00:1337::4:0:0:1";
Expand Down Expand Up @@ -114,21 +118,19 @@ impl Server {
DEVELOPER_CIDR_ID,
create_cidr(&db, "developer", DEVELOPER_CIDR)?.id
);

let developer_1 = developer_peer_contents("developer1", DEVELOPER1_PEER_IP)?;
let developer_1_public_key = developer_1.public_key.clone();
assert_eq!(
DEVELOPER1_PEER_ID,
DatabasePeer::create(
&db,
developer_peer_contents("developer1", DEVELOPER1_PEER_IP)?
)?
.id
DatabasePeer::create(&db, developer_1,)?.id
);

let developer_2 = developer_peer_contents("developer2", DEVELOPER2_PEER_IP)?;
let developer_2_public_key = developer_2.public_key.clone();
assert_eq!(
DEVELOPER2_PEER_ID,
DatabasePeer::create(
&db,
developer_peer_contents("developer2", DEVELOPER2_PEER_IP)?
)?
.id
DatabasePeer::create(&db, developer_2)?.id
);
assert_eq!(USER_CIDR_ID, create_cidr(&db, "user", USER_CIDR)?.id);
assert_eq!(
Expand All @@ -141,7 +143,18 @@ impl Server {
);

let db = Arc::new(Mutex::new(db));
let endpoints = Arc::new(RwLock::new(HashMap::new()));

let endpoints = [
(
developer_1_public_key,
DEVELOPER1_PEER_ENDPOINT.parse().unwrap(),
),
(
developer_2_public_key,
DEVELOPER2_PEER_ENDPOINT.parse().unwrap(),
),
];
let endpoints = Arc::new(RwLock::new(endpoints.into()));

Ok(Self {
conf,
Expand Down
Loading