Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhancement(websocket sink): Allow setting HTTP authorization header #13632

Merged
merged 8 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/sinks/websocket/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use vector_config::configurable_component;
use crate::{
codecs::EncodingConfig,
config::{AcknowledgementsConfig, GenerateConfig, Input, SinkConfig, SinkContext},
http::Auth,
sinks::{
websocket::sink::{ConnectSnafu, WebSocketConnector, WebSocketError, WebSocketSink},
Healthcheck, VectorSink,
Expand Down Expand Up @@ -42,6 +43,9 @@ pub struct WebSocketSinkConfig {
skip_serializing_if = "crate::serde::skip_serializing_if_default"
)]
pub acknowledgements: AcknowledgementsConfig,

#[configurable(derived)]
pub auth: Option<Auth>,
}

impl GenerateConfig for WebSocketSinkConfig {
Expand All @@ -53,6 +57,7 @@ impl GenerateConfig for WebSocketSinkConfig {
ping_interval: None,
ping_timeout: None,
acknowledgements: Default::default(),
auth: None,
})
.unwrap()
}
Expand Down Expand Up @@ -87,7 +92,7 @@ impl SinkConfig for WebSocketSinkConfig {
impl WebSocketSinkConfig {
fn build_connector(&self) -> Result<WebSocketConnector, WebSocketError> {
let tls = MaybeTlsSettings::from_config(&self.tls, false).context(ConnectSnafu)?;
WebSocketConnector::new(self.uri.clone(), tls)
WebSocketConnector::new(self.uri.clone(), tls, self.auth.clone())
}
}

Expand Down
126 changes: 100 additions & 26 deletions src/sinks/websocket/sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use crate::{
codecs::{Encoder, Transformer},
dns, emit,
event::{Event, EventStatus, Finalizable},
http::Auth,
internal_events::{
ConnectionOpen, OpenGauge, WsConnectionError, WsConnectionEstablished,
WsConnectionFailedError, WsConnectionShutdown,
Expand Down Expand Up @@ -66,10 +67,15 @@ pub struct WebSocketConnector {
host: String,
port: u16,
tls: MaybeTlsSettings,
auth: Option<Auth>,
}

impl WebSocketConnector {
pub fn new(uri: String, tls: MaybeTlsSettings) -> Result<Self, WebSocketError> {
pub fn new(
uri: String,
tls: MaybeTlsSettings,
auth: Option<Auth>,
) -> Result<Self, WebSocketError> {
let request = (&uri).into_client_request().context(CreateFailedSnafu)?;
let (host, port) = Self::extract_host_and_port(&request).context(CreateFailedSnafu)?;

Expand All @@ -78,6 +84,7 @@ impl WebSocketConnector {
host,
port,
tls,
auth,
})
}

Expand Down Expand Up @@ -118,9 +125,14 @@ impl WebSocketConnector {
}

async fn connect(&self) -> Result<WsStream<MaybeTlsStream<TcpStream>>, WebSocketError> {
let request = (&self.uri)
let mut request = (&self.uri)
.into_client_request()
.context(CreateFailedSnafu)?;

if let Some(auth) = &self.auth {
auth.apply(&mut request);
}

let maybe_tls = self.tls_connect().await?;

let ws_config = WebSocketConfig {
Expand Down Expand Up @@ -372,8 +384,9 @@ mod tests {
use serde_json::Value as JsonValue;
use tokio::time::timeout;
use tokio_tungstenite::{
accept_async,
accept_async, accept_hdr_async,
tungstenite::error::{Error as WsError, ProtocolError},
tungstenite::handshake::server::{Request, Response},
};

use super::*;
Expand All @@ -398,10 +411,34 @@ mod tests {
ping_interval: None,
ping_timeout: None,
acknowledgements: Default::default(),
auth: None,
};
let tls = MaybeTlsSettings::Raw(());

send_events_and_assert(addr, config, tls).await;
send_events_and_assert(addr, config, tls, None).await;
}

#[tokio::test(flavor = "multi_thread")]
async fn test_auth_websocket() {
trace_init();

let auth = Some(Auth::Bearer {
token: "OiJIUzI1NiIsInR5cCI6IkpXVCJ".to_string(),
});
let auth_clone = auth.clone();
let addr = next_addr();
let config = WebSocketSinkConfig {
uri: format!("ws://{}", addr),
tls: None,
encoding: JsonSerializerConfig::new().into(),
ping_interval: None,
ping_timeout: None,
acknowledgements: Default::default(),
auth,
};
let tls = MaybeTlsSettings::Raw(());

send_events_and_assert(addr, config, tls, auth_clone).await;
}

#[tokio::test(flavor = "multi_thread")]
Expand All @@ -427,9 +464,10 @@ mod tests {
ping_timeout: None,
ping_interval: None,
acknowledgements: Default::default(),
auth: None,
};

send_events_and_assert(addr, config, tls).await;
send_events_and_assert(addr, config, tls, None).await;
}

#[tokio::test]
Expand All @@ -444,10 +482,11 @@ mod tests {
ping_interval: None,
ping_timeout: None,
acknowledgements: Default::default(),
auth: None,
};
let tls = MaybeTlsSettings::Raw(());

let mut receiver = create_count_receiver(addr, tls.clone(), true);
let mut receiver = create_count_receiver(addr, tls.clone(), true, None);

let context = SinkContext::new_test();
let (sink, _healthcheck) = config.build(context).await.unwrap();
Expand All @@ -463,7 +502,7 @@ mod tests {
time::sleep(Duration::from_millis(500)).await;
assert!(!receiver.await.is_empty());

let mut receiver = create_count_receiver(addr, tls, false);
let mut receiver = create_count_receiver(addr, tls, false, None);
assert!(timeout(Duration::from_secs(10), receiver.connected())
.await
.is_ok());
Expand All @@ -473,8 +512,9 @@ mod tests {
addr: SocketAddr,
config: WebSocketSinkConfig,
tls: MaybeTlsSettings,
auth: Option<Auth>,
) {
let mut receiver = create_count_receiver(addr, tls, false);
let mut receiver = create_count_receiver(addr, tls, false, auth);

let context = SinkContext::new_test();
let (sink, _healthcheck) = config.build(context).await.unwrap();
Expand All @@ -498,6 +538,7 @@ mod tests {
addr: SocketAddr,
tls: MaybeTlsSettings,
interrupt_stream: bool,
auth: Option<Auth>,
) -> CountReceiver<String> {
CountReceiver::receive_items_stream(move |tripwire, connected| async move {
let listener = tls.bind(&addr).await.unwrap();
Expand All @@ -509,25 +550,58 @@ mod tests {

let stream = stream
.take_until(tripwire)
.filter_map(|maybe_tls_stream| async move {
let maybe_tls_stream = maybe_tls_stream.unwrap();
let ws_stream = accept_async(maybe_tls_stream).await.unwrap();

Some(
ws_stream
.filter_map(|msg| {
future::ready(match msg {
Ok(msg) if msg.is_text() => Some(Ok(msg.into_text().unwrap())),
Err(WsError::Protocol(
ProtocolError::ResetWithoutClosingHandshake,
)) => None,
Err(e) => Some(Err(e)),
_ => None,
.filter_map(move |maybe_tls_stream| {
let au = auth.clone();
async move {
let maybe_tls_stream = maybe_tls_stream.unwrap();
let ws_stream = match au {
Some(a) => {
let auth_callback = |req: &Request, res: Response| {
let hdr = req.headers().get("Authorization");
if let Some(h) = hdr {
match a {
Auth::Bearer { token } => {
if format!("Bearer {}", token)
!= h.to_str().unwrap()
{
return Err(
http::Response::<Option<String>>::new(None),
);
}
}
Auth::Basic {
user: _user,
password: _password,
} => { /* Not needed for tests at the moment */ }
}
}
Ok(res)
};
accept_hdr_async(maybe_tls_stream, auth_callback)
.await
.unwrap()
}
None => accept_async(maybe_tls_stream).await.unwrap(),
};

Some(
ws_stream
.filter_map(|msg| {
future::ready(match msg {
Ok(msg) if msg.is_text() => {
Some(Ok(msg.into_text().unwrap()))
}
Err(WsError::Protocol(
ProtocolError::ResetWithoutClosingHandshake,
)) => None,
Err(e) => Some(Err(e)),
_ => None,
})
})
})
.take_while(|msg| future::ready(msg.is_ok()))
.filter_map(|msg| future::ready(msg.ok())),
)
.take_while(|msg| future::ready(msg.is_ok()))
.filter_map(|msg| future::ready(msg.ok())),
)
}
})
.map(move |ws_stream| {
connected.take().map(|trigger| trigger.send(()));
Expand Down
4 changes: 4 additions & 0 deletions website/cue/reference/components/sinks/websocket.cue
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ components: sinks: websocket: {
}

configuration: {
auth: configuration._http_auth & {_args: {
password_example: "${HTTP_PASSWORD}"
username_example: "${HTTP_USERNAME}"
}}
uri: {
description: """
The WebSocket URI to connect to. This should include the protocol and host,
Expand Down