Skip to content

Commit

Permalink
Send any headers already queued in the plug when upgrading websocket (#…
Browse files Browse the repository at this point in the history
…458)

Also improve the forever-tedious WebSocket upgrade test to read by
line instead of by byte count
  • Loading branch information
mtrudel authored Jan 15, 2025
1 parent 1ddeb34 commit 41bb3d8
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 42 deletions.
3 changes: 2 additions & 1 deletion lib/bandit/websocket/handshake.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ defmodule Bandit.WebSocket.Handshake do
{:connection, "Upgrade"},
{:"sec-websocket-accept", server_key}
] ++
websocket_extension_header(extensions)
websocket_extension_header(extensions) ++
conn.resp_headers

inform(conn, 101, headers)
end
Expand Down
36 changes: 27 additions & 9 deletions test/bandit/websocket/protocol_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,33 @@ defmodule WebSocketProtocolTest do
end

describe "compressed frames" do
test "negotiates compression if globally configured to", context do
client = SimpleWebSocketClient.tcp_client(context)

assert {:ok,
[
"sec-websocket-extensions: permessage-deflate",
"cache-control: max-age=0, private, must-revalidate"
]} =
SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)

deflated_payload = <<74, 76, 28, 5, 163, 96, 20, 12, 119, 0, 0>>
SimpleWebSocketClient.send_text_frame(client, deflated_payload, 0xC)

assert SimpleWebSocketClient.recv_deflated_text_frame(client) == {:ok, deflated_payload}
end

test "does not negotiate compression if not globally configured to", context do
context = http_server(context, websocket_options: [compress: false])
client = SimpleWebSocketClient.tcp_client(context)

assert {:ok, ["cache-control: max-age=0, private, must-revalidate"]} =
SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)

SimpleWebSocketClient.send_text_frame(client, "OK")
assert SimpleWebSocketClient.recv_text_frame(client) == {:ok, "OK"}
end

test "correctly decompresses text frames and sends compressed frames back", context do
client = SimpleWebSocketClient.tcp_client(context)
SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)
Expand Down Expand Up @@ -212,15 +239,6 @@ defmodule WebSocketProtocolTest do
assert SimpleWebSocketClient.recv_pong_frame(client) == {:ok, "OK"}
assert SimpleWebSocketClient.recv_ping_frame(client) == {:ok, "OK"}
end

test "does not negotiate compression if not globally configured to", context do
context = http_server(context, websocket_options: [compress: false])
client = SimpleWebSocketClient.tcp_client(context)
assert {:ok, false} = SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true)

SimpleWebSocketClient.send_text_frame(client, "OK")
assert SimpleWebSocketClient.recv_text_frame(client) == {:ok, "OK"}
end
end

describe "ping frames" do
Expand Down
28 changes: 15 additions & 13 deletions test/bandit/websocket/upgrade_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,17 @@ defmodule WebSocketUpgradeTest do
timeout -> [timeout: String.to_integer(timeout)]
end

Plug.Conn.upgrade_adapter(conn, :websocket, {websock, :upgrade, connection_opts})
conn
|> put_resp_header("x-plug-set-header", "itsaheader")
|> Plug.Conn.upgrade_adapter(:websocket, {websock, :upgrade, connection_opts})
end

defmodule UpgradeWebSock do
use NoopWebSock
def init(opts), do: {:ok, opts}
end

defmodule UpgradeSendOnTerminateWebSock do
use NoopWebSock
def init(opts), do: {:ok, [opts, :init]}
def handle_in(_data, state), do: {:push, {:text, inspect(state)}, state}
Expand All @@ -36,7 +43,7 @@ defmodule WebSocketUpgradeTest do
describe "upgrade support" do
test "upgrades to a {websock, websock_opts, conn_opts} tuple, respecting options", context do
client = SimpleWebSocketClient.tcp_client(context)
SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock, timeout: "250")
SimpleWebSocketClient.http1_handshake(client, UpgradeSendOnTerminateWebSock, timeout: "250")

SimpleWebSocketClient.send_text_frame(client, "")
{:ok, result} = SimpleWebSocketClient.recv_text_frame(client)
Expand All @@ -49,19 +56,14 @@ defmodule WebSocketUpgradeTest do
assert_in_delta now, then + 250, 50
end

test "upgrade responses do not include content-encoding headers", context do
test "upgrade responses include headers set from the plug", context do
client = SimpleWebSocketClient.tcp_client(context)
SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock, timeout: "250")

SimpleWebSocketClient.send_text_frame(client, "")
{:ok, result} = SimpleWebSocketClient.recv_text_frame(client)
assert result == inspect([:upgrade, :init])

# Ensure that the passed timeout was recognized
then = System.monotonic_time(:millisecond)
assert_receive :timeout, 500
now = System.monotonic_time(:millisecond)
assert_in_delta now, then + 250, 50
assert {:ok,
[
"cache-control: max-age=0, private, must-revalidate",
"x-plug-set-header: itsaheader"
]} = SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock)
end

defmodule MyNoopWebSock do
Expand Down
45 changes: 26 additions & 19 deletions test/support/simple_websocket_client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,32 @@ defmodule SimpleWebSocketClient do
] ++ extension_header
)

# Because we don't want to consume any more than our headers, we can't use SimpleHTTP1Client
{:ok, response} = Transport.recv(client, 164)

[
"HTTP/1.1 101 Switching Protocols",
"date: " <> _date,
"upgrade: websocket",
"connection: Upgrade",
"sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK\+xOo=",
""
] = String.split(response, "\r\n")

case Transport.recv(client, 2) do
{:ok, "\r\n"} ->
{:ok, false}

{:ok, "se"} when deflate ->
{:ok, "c-websocket-extensions: permessage-deflate\r\n\r\n"} = Transport.recv(client, 46)
{:ok, true}
headers = read_headers(client, [])

["HTTP/1.1 101 Switching Protocols" | headers] = headers
["date: " <> _date | headers] = headers
["upgrade: websocket" | headers] = headers
["connection: Upgrade" | headers] = headers
["sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK\+xOo=" | headers] = headers
{:ok, headers}
end

# Read one line at a time so as to not consume any body
defp read_headers(socket, headers) do
case read_line(socket, <<>>) do
"" -> headers
header -> read_headers(socket, headers ++ [header])
end
end

defp read_line(socket, buffer) do
case Transport.recv(socket, 1) do
{:ok, "\r"} ->
{:ok, "\n"} = Transport.recv(socket, 1)
buffer

{:ok, byte} ->
read_line(socket, buffer <> byte)
end
end

Expand Down

0 comments on commit 41bb3d8

Please sign in to comment.