From b469cc6a00aac82a3d4ee412c0387bf32d706fff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ant=C3=B4nio=20Cardoso?= Date: Wed, 10 Jan 2024 16:37:50 -0300 Subject: [PATCH] src: connection: Fix panic when DNS lookup fails --- src/connection/mod.rs | 16 ++++++++++++++++ src/connection/tcp.rs | 19 +++++++------------ src/connection/udp.rs | 26 ++++++++------------------ 3 files changed, 31 insertions(+), 30 deletions(-) diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 10dacbe976..7aa4ab78ec 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -102,3 +102,19 @@ pub fn connect(address: &str) -> io::Result protocol_err } } + +/// Returns the socket address for the given address. +pub(crate) fn get_socket_addr( + address: T, +) -> Result { + let addr = match address.to_socket_addrs()?.next() { + Some(addr) => addr, + None => { + return Err(io::Error::new( + io::ErrorKind::Other, + "Host address lookup failed", + )); + } + }; + Ok(addr) +} diff --git a/src/connection/tcp.rs b/src/connection/tcp.rs index 89a109e86e..1d21da12d3 100644 --- a/src/connection/tcp.rs +++ b/src/connection/tcp.rs @@ -6,6 +6,8 @@ use std::net::{TcpListener, TcpStream}; use std::sync::Mutex; use std::time::Duration; +use super::get_socket_addr; + /// TCP MAVLink connection pub fn select_protocol( @@ -24,12 +26,9 @@ pub fn select_protocol( } pub fn tcpout(address: T) -> io::Result { - let addr = address - .to_socket_addrs() - .unwrap() - .next() - .expect("Host address lookup failed."); - let socket = TcpStream::connect(&addr)?; + let addr = get_socket_addr(address)?; + + let socket = TcpStream::connect(addr)?; socket.set_read_timeout(Some(Duration::from_millis(100)))?; Ok(TcpConnection { @@ -43,12 +42,8 @@ pub fn tcpout(address: T) -> io::Result { } pub fn tcpin(address: T) -> io::Result { - let addr = address - .to_socket_addrs() - .unwrap() - .next() - .expect("Invalid address"); - let listener = TcpListener::bind(&addr)?; + let addr = get_socket_addr(address)?; + let listener = TcpListener::bind(addr)?; //For now we only accept one incoming stream: this blocks until we get one for incoming in listener.incoming() { diff --git a/src/connection/udp.rs b/src/connection/udp.rs index e9418e9a06..31ffaf1c11 100644 --- a/src/connection/udp.rs +++ b/src/connection/udp.rs @@ -7,6 +7,8 @@ use std::net::{SocketAddr, UdpSocket}; use std::str::FromStr; use std::sync::Mutex; +use super::get_socket_addr; + /// UDP MAVLink connection pub fn select_protocol( @@ -27,12 +29,8 @@ pub fn select_protocol( } pub fn udpbcast(address: T) -> io::Result { - let addr = address - .to_socket_addrs() - .unwrap() - .next() - .expect("Invalid address"); - let socket = UdpSocket::bind(&SocketAddr::from_str("0.0.0.0:0").unwrap()).unwrap(); + let addr = get_socket_addr(address)?; + let socket = UdpSocket::bind("0.0.0.0:0")?; socket .set_broadcast(true) .expect("Couldn't bind to broadcast address."); @@ -40,22 +38,14 @@ pub fn udpbcast(address: T) -> io::Result { } pub fn udpout(address: T) -> io::Result { - let addr = address - .to_socket_addrs() - .unwrap() - .next() - .expect("Invalid address"); - let socket = UdpSocket::bind(&SocketAddr::from_str("0.0.0.0:0").unwrap())?; + let addr = get_socket_addr(address)?; + let socket = UdpSocket::bind("0.0.0.0:0")?; UdpConnection::new(socket, false, Some(addr)) } pub fn udpin(address: T) -> io::Result { - let addr = address - .to_socket_addrs() - .unwrap() - .next() - .expect("Invalid address"); - let socket = UdpSocket::bind(&addr)?; + let addr = get_socket_addr(address)?; + let socket = UdpSocket::bind(addr)?; UdpConnection::new(socket, true, None) }