Skip to content

Commit

Permalink
Send any headers already queued in the plug when upgrading websocket
Browse files Browse the repository at this point in the history
Also improve the forever-tedious WebSocket upgrade test to read by
line instead of by byte count
  • Loading branch information
mtrudel committed Jan 14, 2025
1 parent 177b053 commit 497e154
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 42 deletions.
3 changes: 2 additions & 1 deletion lib/bandit/websocket/handshake.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ defmodule Bandit.WebSocket.Handshake do
{:connection, "Upgrade"},
{:"sec-websocket-accept", server_key}
] ++
websocket_extension_header(extensions)
websocket_extension_header(extensions) ++
conn.resp_headers

inform(conn, 101, headers)
end
Expand Down
36 changes: 27 additions & 9 deletions test/bandit/websocket/protocol_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,33 @@ defmodule WebSocketProtocolTest do
end

describe "compressed frames" do
test "negotiates compression if globally configured to", context do
client = SimpleWebSocketClient.tcp_client(context)

assert {:ok,
[
"sec-websocket-extensions: permessage-deflate",
"cache-control: max-age=0, private, must-revalidate"
]} =
SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)

deflated_payload = <<74, 76, 28, 5, 163, 96, 20, 12, 119, 0, 0>>
SimpleWebSocketClient.send_text_frame(client, deflated_payload, 0xC)

assert SimpleWebSocketClient.recv_deflated_text_frame(client) == {:ok, deflated_payload}
end

test "does not negotiate compression if not globally configured to", context do
context = http_server(context, websocket_options: [compress: false])
client = SimpleWebSocketClient.tcp_client(context)

assert {:ok, ["cache-control: max-age=0, private, must-revalidate"]} =
SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)

SimpleWebSocketClient.send_text_frame(client, "OK")
assert SimpleWebSocketClient.recv_text_frame(client) == {:ok, "OK"}
end

test "correctly decompresses text frames and sends compressed frames back", context do
client = SimpleWebSocketClient.tcp_client(context)
SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)
Expand Down Expand Up @@ -212,15 +239,6 @@ defmodule WebSocketProtocolTest do
assert SimpleWebSocketClient.recv_pong_frame(client) == {:ok, "OK"}
assert SimpleWebSocketClient.recv_ping_frame(client) == {:ok, "OK"}
end

test "does not negotiate compression if not globally configured to", context do
context = http_server(context, websocket_options: [compress: false])
client = SimpleWebSocketClient.tcp_client(context)
assert {:ok, false} = SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)

SimpleWebSocketClient.send_text_frame(client, "OK")
assert SimpleWebSocketClient.recv_text_frame(client) == {:ok, "OK"}
end
end

describe "ping frames" do
Expand Down
28 changes: 15 additions & 13 deletions test/bandit/websocket/upgrade_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@ defmodule WebSocketUpgradeTest do
timeout -> [timeout: String.to_integer(timeout)]
end

Plug.Conn.upgrade_adapter(conn, :websocket, {websock, :upgrade, connection_opts})
conn
|> put_resp_header("x-plug-set-header", "itsaheader")
|> Plug.Conn.upgrade_adapter(:websocket, {websock, :upgrade, connection_opts})
end

defmodule UpgradeWebSock do
use NoopWebSock
def init(opts), do: {:ok, opts}
end

defmodule UpgradeSendOnTerminateWebSock do
use NoopWebSock
def init(opts), do: {:ok, [opts, :init]}
def handle_in(_data, state), do: {:push, {:text, inspect(state)}, state}
Expand All @@ -36,7 +43,7 @@ defmodule WebSocketUpgradeTest do
describe "upgrade support" do
test "upgrades to a {websock, websock_opts, conn_opts} tuple, respecting options", context do
client = SimpleWebSocketClient.tcp_client(context)
SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock, timeout: "250")
SimpleWebSocketClient.http1_handshake(client, UpgradeSendOnTerminateWebSock, timeout: "250")

SimpleWebSocketClient.send_text_frame(client, "")
{:ok, result} = SimpleWebSocketClient.recv_text_frame(client)
Expand All @@ -49,19 +56,14 @@ defmodule WebSocketUpgradeTest do
assert_in_delta now, then + 250, 50
end

test "upgrade responses do not include content-encoding headers", context do
test "upgrade responses include headers set from the plug", context do
client = SimpleWebSocketClient.tcp_client(context)
SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock, timeout: "250")

SimpleWebSocketClient.send_text_frame(client, "")
{:ok, result} = SimpleWebSocketClient.recv_text_frame(client)
assert result == inspect([:upgrade, :init])

# Ensure that the passed timeout was recognized
then = System.monotonic_time(:millisecond)
assert_receive :timeout, 500
now = System.monotonic_time(:millisecond)
assert_in_delta now, then + 250, 50
assert {:ok,
[
"cache-control: max-age=0, private, must-revalidate",
"x-plug-set-header: itsaheader"
]} = SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock)
end

defmodule MyNoopWebSock do
Expand Down
45 changes: 26 additions & 19 deletions test/support/simple_websocket_client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,32 @@ defmodule SimpleWebSocketClient do
] ++ extension_header
)

# Because we don't want to consume any more than our headers, we can't use SimpleHTTP1Client
{:ok, response} = Transport.recv(client, 164)

[
"HTTP/1.1 101 Switching Protocols",
"date: " <> _date,
"upgrade: websocket",
"connection: Upgrade",
"sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK\+xOo=",
""
] = String.split(response, "\r\n")

case Transport.recv(client, 2) do
{:ok, "\r\n"} ->
{:ok, false}

{:ok, "se"} when deflate ->
{:ok, "c-websocket-extensions: permessage-deflate\r\n\r\n"} = Transport.recv(client, 46)
{:ok, true}
headers = read_headers(client, [])

["HTTP/1.1 101 Switching Protocols" | headers] = headers
["date: " <> _date | headers] = headers
["upgrade: websocket" | headers] = headers
["connection: Upgrade" | headers] = headers
["sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK\+xOo=" | headers] = headers
{:ok, headers}
end

# Read one line at a time so as to not consume any body
defp read_headers(socket, headers) do
case read_line(socket, <<>>) do
"" -> headers
header -> read_headers(socket, headers ++ [header])
end
end

defp read_line(socket, buffer) do
case Transport.recv(socket, 1) do
{:ok, "\r"} ->
{:ok, "\n"} = Transport.recv(socket, 1)
buffer

{:ok, byte} ->
read_line(socket, buffer <> byte)
end
end

Expand Down

0 comments on commit 497e154

Please sign in to comment.