diff --git a/lib/bandit/websocket/handshake.ex b/lib/bandit/websocket/handshake.ex index ac6794a5..a736162a 100644 --- a/lib/bandit/websocket/handshake.ex +++ b/lib/bandit/websocket/handshake.ex @@ -75,7 +75,8 @@ defmodule Bandit.WebSocket.Handshake do {:connection, "Upgrade"}, {:"sec-websocket-accept", server_key} ] ++ - websocket_extension_header(extensions) + websocket_extension_header(extensions) ++ + conn.resp_headers inform(conn, 101, headers) end diff --git a/test/bandit/websocket/protocol_test.exs b/test/bandit/websocket/protocol_test.exs index 774a3e75..939e2b47 100644 --- a/test/bandit/websocket/protocol_test.exs +++ b/test/bandit/websocket/protocol_test.exs @@ -153,6 +153,33 @@ defmodule WebSocketProtocolTest do end describe "compressed frames" do + test "negotiates compression if globally configured to", context do + client = SimpleWebSocketClient.tcp_client(context) + + assert {:ok, + [ + "sec-websocket-extensions: permessage-deflate", + "cache-control: max-age=0, private, must-revalidate" + ]} = + SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true) + + deflated_payload = <<74, 76, 28, 5, 163, 96, 20, 12, 119, 0, 0>> + SimpleWebSocketClient.send_text_frame(client, deflated_payload, 0xC) + + assert SimpleWebSocketClient.recv_deflated_text_frame(client) == {:ok, deflated_payload} + end + + test "does not negotiate compression if not globally configured to", context do + context = http_server(context, websocket_options: [compress: false]) + client = SimpleWebSocketClient.tcp_client(context) + + assert {:ok, ["cache-control: max-age=0, private, must-revalidate"]} = + SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true) + + SimpleWebSocketClient.send_text_frame(client, "OK") + assert SimpleWebSocketClient.recv_text_frame(client) == {:ok, "OK"} + end + test "correctly decompresses text frames and sends compressed frames back", context do client = SimpleWebSocketClient.tcp_client(context) SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true) @@ -212,15 +239,6 @@ defmodule WebSocketProtocolTest do assert SimpleWebSocketClient.recv_pong_frame(client) == {:ok, "OK"} assert SimpleWebSocketClient.recv_ping_frame(client) == {:ok, "OK"} end - - test "does not negotiate compression if not globally configured to", context do - context = http_server(context, websocket_options: [compress: false]) - client = SimpleWebSocketClient.tcp_client(context) - assert {:ok, false} = SimpleWebSocketClient.http1_handshake(client, EchoWebSock, [], true) - - SimpleWebSocketClient.send_text_frame(client, "OK") - assert SimpleWebSocketClient.recv_text_frame(client) == {:ok, "OK"} - end end describe "ping frames" do diff --git a/test/bandit/websocket/upgrade_test.exs b/test/bandit/websocket/upgrade_test.exs index b65f715b..898aa1c1 100644 --- a/test/bandit/websocket/upgrade_test.exs +++ b/test/bandit/websocket/upgrade_test.exs @@ -16,10 +16,17 @@ defmodule WebSocketUpgradeTest do timeout -> [timeout: String.to_integer(timeout)] end - Plug.Conn.upgrade_adapter(conn, :websocket, {websock, :upgrade, connection_opts}) + conn + |> put_resp_header("x-plug-set-header", "itsaheader") + |> Plug.Conn.upgrade_adapter(:websocket, {websock, :upgrade, connection_opts}) end defmodule UpgradeWebSock do + use NoopWebSock + def init(opts), do: {:ok, opts} + end + + defmodule UpgradeSendOnTerminateWebSock do use NoopWebSock def init(opts), do: {:ok, [opts, :init]} def handle_in(_data, state), do: {:push, {:text, inspect(state)}, state} @@ -36,7 +43,7 @@ defmodule WebSocketUpgradeTest do describe "upgrade support" do test "upgrades to a {websock, websock_opts, conn_opts} tuple, respecting options", context do client = SimpleWebSocketClient.tcp_client(context) - SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock, timeout: "250") + SimpleWebSocketClient.http1_handshake(client, UpgradeSendOnTerminateWebSock, timeout: "250") SimpleWebSocketClient.send_text_frame(client, "") {:ok, result} = SimpleWebSocketClient.recv_text_frame(client) @@ -49,19 +56,14 @@ defmodule WebSocketUpgradeTest do assert_in_delta now, then + 250, 50 end - test "upgrade responses do not include content-encoding headers", context do + test "upgrade responses include headers set from the plug", context do client = SimpleWebSocketClient.tcp_client(context) - SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock, timeout: "250") - - SimpleWebSocketClient.send_text_frame(client, "") - {:ok, result} = SimpleWebSocketClient.recv_text_frame(client) - assert result == inspect([:upgrade, :init]) - # Ensure that the passed timeout was recognized - then = System.monotonic_time(:millisecond) - assert_receive :timeout, 500 - now = System.monotonic_time(:millisecond) - assert_in_delta now, then + 250, 50 + assert {:ok, + [ + "cache-control: max-age=0, private, must-revalidate", + "x-plug-set-header: itsaheader" + ]} = SimpleWebSocketClient.http1_handshake(client, UpgradeWebSock) end defmodule MyNoopWebSock do diff --git a/test/support/simple_websocket_client.ex b/test/support/simple_websocket_client.ex index 8b01a0a5..fb443772 100644 --- a/test/support/simple_websocket_client.ex +++ b/test/support/simple_websocket_client.ex @@ -23,25 +23,32 @@ defmodule SimpleWebSocketClient do ] ++ extension_header ) - # Because we don't want to consume any more than our headers, we can't use SimpleHTTP1Client - {:ok, response} = Transport.recv(client, 164) - - [ - "HTTP/1.1 101 Switching Protocols", - "date: " <> _date, - "upgrade: websocket", - "connection: Upgrade", - "sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK\+xOo=", - "" - ] = String.split(response, "\r\n") - - case Transport.recv(client, 2) do - {:ok, "\r\n"} -> - {:ok, false} - - {:ok, "se"} when deflate -> - {:ok, "c-websocket-extensions: permessage-deflate\r\n\r\n"} = Transport.recv(client, 46) - {:ok, true} + headers = read_headers(client, []) + + ["HTTP/1.1 101 Switching Protocols" | headers] = headers + ["date: " <> _date | headers] = headers + ["upgrade: websocket" | headers] = headers + ["connection: Upgrade" | headers] = headers + ["sec-websocket-accept: s3pPLMBiTxaQ9kYGzzhZRbK\+xOo=" | headers] = headers + {:ok, headers} + end + + # Read one line at a time so as to not consume any body + defp read_headers(socket, headers) do + case read_line(socket, <<>>) do + "" -> headers + header -> read_headers(socket, headers ++ [header]) + end + end + + defp read_line(socket, buffer) do + case Transport.recv(socket, 1) do + {:ok, "\r"} -> + {:ok, "\n"} = Transport.recv(socket, 1) + buffer + + {:ok, byte} -> + read_line(socket, buffer <> byte) end end