Skip to content

Commit

Permalink
enhancement(websocket sink): Allow setting HTTP authorization header (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
wez470 authored Aug 9, 2022
1 parent a47dc19 commit 20d69b7
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 27 deletions.
7 changes: 6 additions & 1 deletion src/sinks/websocket/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use vector_config::configurable_component;
use crate::{
codecs::EncodingConfig,
config::{AcknowledgementsConfig, GenerateConfig, Input, SinkConfig, SinkContext},
http::Auth,
sinks::{
websocket::sink::{ConnectSnafu, WebSocketConnector, WebSocketError, WebSocketSink},
Healthcheck, VectorSink,
Expand Down Expand Up @@ -42,6 +43,9 @@ pub struct WebSocketSinkConfig {
skip_serializing_if = "crate::serde::skip_serializing_if_default"
)]
pub acknowledgements: AcknowledgementsConfig,

#[configurable(derived)]
pub auth: Option<Auth>,
}

impl GenerateConfig for WebSocketSinkConfig {
Expand All @@ -53,6 +57,7 @@ impl GenerateConfig for WebSocketSinkConfig {
ping_interval: None,
ping_timeout: None,
acknowledgements: Default::default(),
auth: None,
})
.unwrap()
}
Expand Down Expand Up @@ -87,7 +92,7 @@ impl SinkConfig for WebSocketSinkConfig {
impl WebSocketSinkConfig {
fn build_connector(&self) -> Result<WebSocketConnector, WebSocketError> {
let tls = MaybeTlsSettings::from_config(&self.tls, false).context(ConnectSnafu)?;
WebSocketConnector::new(self.uri.clone(), tls)
WebSocketConnector::new(self.uri.clone(), tls, self.auth.clone())
}
}

Expand Down
126 changes: 100 additions & 26 deletions src/sinks/websocket/sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use crate::{
codecs::{Encoder, Transformer},
dns, emit,
event::{Event, EventStatus, Finalizable},
http::Auth,
internal_events::{
ConnectionOpen, OpenGauge, WsConnectionError, WsConnectionEstablished,
WsConnectionFailedError, WsConnectionShutdown,
Expand Down Expand Up @@ -66,10 +67,15 @@ pub struct WebSocketConnector {
host: String,
port: u16,
tls: MaybeTlsSettings,
auth: Option<Auth>,
}

impl WebSocketConnector {
pub fn new(uri: String, tls: MaybeTlsSettings) -> Result<Self, WebSocketError> {
pub fn new(
uri: String,
tls: MaybeTlsSettings,
auth: Option<Auth>,
) -> Result<Self, WebSocketError> {
let request = (&uri).into_client_request().context(CreateFailedSnafu)?;
let (host, port) = Self::extract_host_and_port(&request).context(CreateFailedSnafu)?;

Expand All @@ -78,6 +84,7 @@ impl WebSocketConnector {
host,
port,
tls,
auth,
})
}

Expand Down Expand Up @@ -118,9 +125,14 @@ impl WebSocketConnector {
}

async fn connect(&self) -> Result<WsStream<MaybeTlsStream<TcpStream>>, WebSocketError> {
let request = (&self.uri)
let mut request = (&self.uri)
.into_client_request()
.context(CreateFailedSnafu)?;

if let Some(auth) = &self.auth {
auth.apply(&mut request);
}

let maybe_tls = self.tls_connect().await?;

let ws_config = WebSocketConfig {
Expand Down Expand Up @@ -372,8 +384,9 @@ mod tests {
use serde_json::Value as JsonValue;
use tokio::time::timeout;
use tokio_tungstenite::{
accept_async,
accept_async, accept_hdr_async,
tungstenite::error::{Error as WsError, ProtocolError},
tungstenite::handshake::server::{Request, Response},
};

use super::*;
Expand All @@ -398,10 +411,34 @@ mod tests {
ping_interval: None,
ping_timeout: None,
acknowledgements: Default::default(),
auth: None,
};
let tls = MaybeTlsSettings::Raw(());

send_events_and_assert(addr, config, tls).await;
send_events_and_assert(addr, config, tls, None).await;
}

#[tokio::test(flavor = "multi_thread")]
async fn test_auth_websocket() {
trace_init();

let auth = Some(Auth::Bearer {
token: "OiJIUzI1NiIsInR5cCI6IkpXVCJ".to_string(),
});
let auth_clone = auth.clone();
let addr = next_addr();
let config = WebSocketSinkConfig {
uri: format!("ws://{}", addr),
tls: None,
encoding: JsonSerializerConfig::new().into(),
ping_interval: None,
ping_timeout: None,
acknowledgements: Default::default(),
auth,
};
let tls = MaybeTlsSettings::Raw(());

send_events_and_assert(addr, config, tls, auth_clone).await;
}

#[tokio::test(flavor = "multi_thread")]
Expand All @@ -427,9 +464,10 @@ mod tests {
ping_timeout: None,
ping_interval: None,
acknowledgements: Default::default(),
auth: None,
};

send_events_and_assert(addr, config, tls).await;
send_events_and_assert(addr, config, tls, None).await;
}

#[tokio::test]
Expand All @@ -444,10 +482,11 @@ mod tests {
ping_interval: None,
ping_timeout: None,
acknowledgements: Default::default(),
auth: None,
};
let tls = MaybeTlsSettings::Raw(());

let mut receiver = create_count_receiver(addr, tls.clone(), true);
let mut receiver = create_count_receiver(addr, tls.clone(), true, None);

let context = SinkContext::new_test();
let (sink, _healthcheck) = config.build(context).await.unwrap();
Expand All @@ -463,7 +502,7 @@ mod tests {
time::sleep(Duration::from_millis(500)).await;
assert!(!receiver.await.is_empty());

let mut receiver = create_count_receiver(addr, tls, false);
let mut receiver = create_count_receiver(addr, tls, false, None);
assert!(timeout(Duration::from_secs(10), receiver.connected())
.await
.is_ok());
Expand All @@ -473,8 +512,9 @@ mod tests {
addr: SocketAddr,
config: WebSocketSinkConfig,
tls: MaybeTlsSettings,
auth: Option<Auth>,
) {
let mut receiver = create_count_receiver(addr, tls, false);
let mut receiver = create_count_receiver(addr, tls, false, auth);

let context = SinkContext::new_test();
let (sink, _healthcheck) = config.build(context).await.unwrap();
Expand All @@ -498,6 +538,7 @@ mod tests {
addr: SocketAddr,
tls: MaybeTlsSettings,
interrupt_stream: bool,
auth: Option<Auth>,
) -> CountReceiver<String> {
CountReceiver::receive_items_stream(move |tripwire, connected| async move {
let listener = tls.bind(&addr).await.unwrap();
Expand All @@ -509,25 +550,58 @@ mod tests {

let stream = stream
.take_until(tripwire)
.filter_map(|maybe_tls_stream| async move {
let maybe_tls_stream = maybe_tls_stream.unwrap();
let ws_stream = accept_async(maybe_tls_stream).await.unwrap();

Some(
ws_stream
.filter_map(|msg| {
future::ready(match msg {
Ok(msg) if msg.is_text() => Some(Ok(msg.into_text().unwrap())),
Err(WsError::Protocol(
ProtocolError::ResetWithoutClosingHandshake,
)) => None,
Err(e) => Some(Err(e)),
_ => None,
.filter_map(move |maybe_tls_stream| {
let au = auth.clone();
async move {
let maybe_tls_stream = maybe_tls_stream.unwrap();
let ws_stream = match au {
Some(a) => {
let auth_callback = |req: &Request, res: Response| {
let hdr = req.headers().get("Authorization");
if let Some(h) = hdr {
match a {
Auth::Bearer { token } => {
if format!("Bearer {}", token)
!= h.to_str().unwrap()
{
return Err(
http::Response::<Option<String>>::new(None),
);
}
}
Auth::Basic {
user: _user,
password: _password,
} => { /* Not needed for tests at the moment */ }
}
}
Ok(res)
};
accept_hdr_async(maybe_tls_stream, auth_callback)
.await
.unwrap()
}
None => accept_async(maybe_tls_stream).await.unwrap(),
};

Some(
ws_stream
.filter_map(|msg| {
future::ready(match msg {
Ok(msg) if msg.is_text() => {
Some(Ok(msg.into_text().unwrap()))
}
Err(WsError::Protocol(
ProtocolError::ResetWithoutClosingHandshake,
)) => None,
Err(e) => Some(Err(e)),
_ => None,
})
})
})
.take_while(|msg| future::ready(msg.is_ok()))
.filter_map(|msg| future::ready(msg.ok())),
)
.take_while(|msg| future::ready(msg.is_ok()))
.filter_map(|msg| future::ready(msg.ok())),
)
}
})
.map(move |ws_stream| {
connected.take().map(|trigger| trigger.send(()));
Expand Down
4 changes: 4 additions & 0 deletions website/cue/reference/components/sinks/websocket.cue
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ components: sinks: websocket: {
}

configuration: {
auth: configuration._http_auth & {_args: {
password_example: "${HTTP_PASSWORD}"
username_example: "${HTTP_USERNAME}"
}}
uri: {
description: """
The WebSocket URI to connect to. This should include the protocol and host,
Expand Down

0 comments on commit 20d69b7

Please sign in to comment.