Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send any headers already queued in the plug when upgrading websocket #458

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lib/bandit/websocket/handshake.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ defmodule Bandit.WebSocket.Handshake do
{:connection, "Upgrade"},
{:"sec-websocket-accept", server_key}
] ++
websocket_extension_header(extensions)
websocket_extension_header(extensions) ++
conn.resp_headers

inform(conn, 101, headers)
end
Expand Down
36 changes: 27 additions & 9 deletions test/bandit/websocket/protocol_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,33 @@ defmodule WebSocketProtocolTest do
end

describe "compressed frames" do
test "negotiates compression if globally configured to", context do
client = SimpleWebSocketClient.tcp_client(context)

assert {:ok,
[
"sec-websocket-extensions: permessage-deflate",
"cache-control: max-age=0, private, must-revalidate"
]} =
SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)

deflated_payload = <<74, 76, 28, 5, 163, 96, 20, 12, 119, 0, 0>>
SimpleWebSocketClient.send_text_frame(client, deflated_payload, 0xC)

assert SimpleWebSocketClient.recv_deflated_text_frame(client) == {:ok, deflated_payload}
end

test "does not negotiate compression if not globally configured to", context do
context = http_server(context, websocket_options: [compress: false])
client = SimpleWebSocketClient.tcp_client(context)

assert {:ok, ["cache-control: max-age=0, private, must-revalidate"]} =
SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)

SimpleWebSocketClient.send_text_frame(client, "OK")
assert SimpleWebSocketClient.recv_text_frame(client) == {:ok, "OK"}
end

test "correctly decompresses text frames and sends compressed frames back", context do
client = SimpleWebSocketClient.tcp_client(context)
SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)
Expand Down Expand Up @@ -212,15 +239,6 @@ defmodule WebSocketProtocolTest do
assert SimpleWebSocketClient.recv_pong_frame(client) == {:ok, "OK"}
assert SimpleWebSocketClient.recv_ping_frame(client) == {:ok, "OK"}
end

test "does not negotiate compression if not globally configured to", context do
context = http_server(context, websocket_options: [compress: false])
client = SimpleWebSocketClient.tcp_client(context)
assert {:ok, false} = SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)

SimpleWebSocketClient.send_text_frame(client, "OK")
assert SimpleWebSocketClient.recv_text_frame(client) == {:ok, "OK"}
end
end

describe "ping frames" do
Expand Down
28 changes: 15 additions & 13 deletions test/bandit/websocket/upgrade_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@ defmodule WebSocketUpgradeTest do
timeout -> [timeout: String.to_integer(timeout)]
end

Plug.Conn.upgrade_adapter(conn, :websocket, {websock, :upgrade, connection_opts})
conn
|> put_resp_header("x-plug-set-header", "itsaheader")
|> Plug.Conn.upgrade_adapter(:websocket, {websock, :upgrade, connection_opts})
end

defmodule UpgradeWebSock do
use NoopWebSock
def init(opts), do: {:ok, opts}
end

defmodule UpgradeSendOnTerminateWebSock do
use NoopWebSock
def init(opts), do: {:ok, [opts, :init]}
def handle_in(_data, state), do: {:push, {:text, inspect(state)}, state}
Expand All @@ -36,7 +43,7 @@ defmodule WebSocketUpgradeTest do
describe "upgrade support" do
test "upgrades to a {websock, websock_opts, conn_opts} tuple, respecting options", context do
client = SimpleWebSocketClient.tcp_client(context)
SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock, timeout: "250")
SimpleWebSocketClient.http1_handshake(client, UpgradeSendOnTerminateWebSock, timeout: "250")

SimpleWebSocketClient.send_text_frame(client, "")
{:ok, result} = SimpleWebSocketClient.recv_text_frame(client)
Expand All @@ -49,19 +56,14 @@ defmodule WebSocketUpgradeTest do
assert_in_delta now, then + 250, 50
end

test "upgrade responses do not include content-encoding headers", context do
test "upgrade responses include headers set from the plug", context do
client = SimpleWebSocketClient.tcp_client(context)
SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock, timeout: "250")

SimpleWebSocketClient.send_text_frame(client, "")
{:ok, result} = SimpleWebSocketClient.recv_text_frame(client)
assert result == inspect([:upgrade, :init])

# Ensure that the passed timeout was recognized
then = System.monotonic_time(:millisecond)
assert_receive :timeout, 500
now = System.monotonic_time(:millisecond)
assert_in_delta now, then + 250, 50
assert {:ok,
[
"cache-control: max-age=0, private, must-revalidate",
"x-plug-set-header: itsaheader"
]} = SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock)
end

defmodule MyNoopWebSock do
Expand Down
45 changes: 26 additions & 19 deletions test/support/simple_websocket_client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,32 @@ defmodule SimpleWebSocketClient do
] ++ extension_header
)

# Because we don't want to consume any more than our headers, we can't use SimpleHTTP1Client
{:ok, response} = Transport.recv(client, 164)

[
"HTTP/1.1 101 Switching Protocols",
"date: " <> _date,
"upgrade: websocket",
"connection: Upgrade",
"sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK\+xOo=",
""
] = String.split(response, "\r\n")

case Transport.recv(client, 2) do
{:ok, "\r\n"} ->
{:ok, false}

{:ok, "se"} when deflate ->
{:ok, "c-websocket-extensions: permessage-deflate\r\n\r\n"} = Transport.recv(client, 46)
{:ok, true}
headers = read_headers(client, [])

["HTTP/1.1 101 Switching Protocols" | headers] = headers
["date: " <> _date | headers] = headers
["upgrade: websocket" | headers] = headers
["connection: Upgrade" | headers] = headers
["sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK\+xOo=" | headers] = headers
{:ok, headers}
end

# Read one line at a time so as to not consume any body
defp read_headers(socket, headers) do
case read_line(socket, <<>>) do
"" -> headers
header -> read_headers(socket, headers ++ [header])
end
end

defp read_line(socket, buffer) do
case Transport.recv(socket, 1) do
{:ok, "\r"} ->
{:ok, "\n"} = Transport.recv(socket, 1)
buffer

{:ok, byte} ->
read_line(socket, buffer <> byte)
end
end

Expand Down
Loading