Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for running optimized native code for websocket mask #394

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/bandit.ex
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ defmodule Bandit do
@http_keys ~w(compress deflate_options log_exceptions_with_status_codes log_protocol_errors log_client_closures)a
@http_1_keys ~w(enabled max_request_line_length max_header_length max_header_count max_requests clear_process_dict gc_every_n_keepalive_requests log_unknown_messages)a
@http_2_keys ~w(enabled max_header_block_size max_requests default_local_settings)a
@websocket_keys ~w(enabled max_frame_size validate_text_frames compress)a
@websocket_keys ~w(enabled max_frame_size validate_text_frames compress primitive_ops_module)a
@thousand_island_keys ThousandIsland.ServerConfig.__struct__()
|> Map.from_struct()
|> Map.keys()
Expand Down
17 changes: 10 additions & 7 deletions lib/bandit/extractor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ defmodule Bandit.Extractor do
| {:error, term()}
| :more

@callback deserialize(binary()) :: deserialize_result()
@callback deserialize(binary(), primitive_ops_module :: module()) :: deserialize_result()

@type t :: %__MODULE__{
header: binary(),
Expand All @@ -18,7 +18,8 @@ defmodule Bandit.Extractor do
required_length: non_neg_integer(),
mode: :header_parsing | :payload_parsing,
max_frame_size: non_neg_integer(),
frame_parser: atom()
frame_parser: atom(),
primitive_ops_module: module()
}

defstruct header: <<>>,
Expand All @@ -27,15 +28,17 @@ defmodule Bandit.Extractor do
required_length: 0,
mode: :header_parsing,
max_frame_size: 0,
frame_parser: nil
frame_parser: nil,
primitive_ops_module: nil

@spec new(module(), Keyword.t()) :: t()
def new(frame_parser, opts) do
@spec new(module(), module(), Keyword.t()) :: t()
def new(frame_parser, primitive_ops_module, opts) do
max_frame_size = Keyword.get(opts, :max_frame_size, 0)

%__MODULE__{
max_frame_size: max_frame_size,
frame_parser: frame_parser
frame_parser: frame_parser,
primitive_ops_module: primitive_ops_module
}
end

Expand Down Expand Up @@ -79,7 +82,7 @@ defmodule Bandit.Extractor do
<<payload::binary-size(required_length), rest::binary>> =
IO.iodata_to_binary(state.payload)

frame = state.frame_parser.deserialize(state.header <> payload)
frame = state.frame_parser.deserialize(state.header <> payload, state.primitive_ops_module)
state = transition_to_header_parsing(state, rest)

{state, frame}
Expand Down
34 changes: 34 additions & 0 deletions lib/bandit/primitive_ops/websocket.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
defmodule Bandit.PrimitiveOps.WebSocket do
@moduledoc """
WebSocket primitive operations behaviour and default implementation
"""

@doc """
WebSocket masking according to [RFC6455§5.3](https://www.rfc-editor.org/rfc/rfc6455#section-5.3)
"""
@callback ws_mask(payload :: binary(), mask :: integer()) :: binary()

@behaviour __MODULE__

# Note that masking is an involution, so we don't need a separate unmask function
@impl true
def ws_mask(payload, mask)
when is_binary(payload) and is_integer(mask) and mask >= 0x00000000 and mask <= 0xFFFFFFFF do
ws_mask(<<>>, payload, mask)
end

defp ws_mask(acc, <<h::32, rest::binary>>, mask) do
ws_mask(<<acc::binary, (<<Bitwise.bxor(h, mask)::32>>)>>, rest, mask)
end

for size <- [24, 16, 8] do
defp ws_mask(acc, <<h::unquote(size)>>, mask) do
<<mask::unquote(size), _::binary>> = <<mask::32>>
<<acc::binary, (<<Bitwise.bxor(h, mask)::unquote(size)>>)>>
end
end

defp ws_mask(acc, <<>>, _mask) do
acc
end
end
48 changes: 15 additions & 33 deletions lib/bandit/websocket/frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -73,29 +73,32 @@ defmodule Bandit.WebSocket.Frame do
end

@impl Bandit.Extractor
@spec deserialize(binary()) :: {:ok, frame()} | {:error, term()}
@spec deserialize(binary(), module()) :: {:ok, frame()} | {:error, term()}
def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, 127::7, length::64, mask::32,
payload::binary-size(length)>>
payload::binary-size(length)>>,
primitive_ops_module
) do
to_frame(fin, compressed, rsv, opcode, mask, payload)
to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module)
end

def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, 126::7, length::16, mask::32,
payload::binary-size(length)>>
payload::binary-size(length)>>,
primitive_ops_module
) do
to_frame(fin, compressed, rsv, opcode, mask, payload)
to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module)
end

def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, length::7, mask::32,
payload::binary-size(length)>>
payload::binary-size(length)>>,
primitive_ops_module
) do
to_frame(fin, compressed, rsv, opcode, mask, payload)
to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module)
end

def deserialize(_msg) do
def deserialize(_msg, _primitive_ops_module) do
{:error, :deserialization_failed}
end

Expand Down Expand Up @@ -155,14 +158,15 @@ defmodule Bandit.WebSocket.Frame do
end
end

defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload) when rsv != 0x0 do
defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload, _primitive_ops_module)
when rsv != 0x0 do
{:error, "Received unsupported RSV flags #{rsv}"}
end

defp to_frame(fin, compressed, 0x0, opcode, mask, payload) do
defp to_frame(fin, compressed, 0x0, opcode, mask, payload, primitive_ops_module) do
fin = fin == 0x1
compressed = compressed == 0x1
unmasked_payload = mask(payload, mask)
unmasked_payload = primitive_ops_module.ws_mask(payload, mask)

opcode
|> case do
Expand Down Expand Up @@ -198,26 +202,4 @@ defmodule Bandit.WebSocket.Frame do
defp mask_and_length(length) when length <= 125, do: <<0::1, length::7>>
defp mask_and_length(length) when length <= 65_535, do: <<0::1, 126::7, length::16>>
defp mask_and_length(length), do: <<0::1, 127::7, length::64>>

# Note that masking is an involution, so we don't need a separate unmask function
@spec mask(binary(), integer()) :: binary()
def mask(payload, mask)
when is_binary(payload) and is_integer(mask) and mask >= 0x00000000 and mask <= 0xFFFFFFFF do
mask(<<>>, payload, mask)
end

defp mask(acc, <<h::32, rest::binary>>, mask) do
mask(<<acc::binary, (<<Bitwise.bxor(h, mask)::32>>)>>, rest, mask)
end

for size <- [24, 16, 8] do
defp mask(acc, <<h::unquote(size)>>, mask) do
<<mask::unquote(size), _::binary>> = <<mask::32>>
<<acc::binary, (<<Bitwise.bxor(h, mask)::unquote(size)>>)>>
end
end

defp mask(acc, <<>>, _mask) do
acc
end
end
5 changes: 4 additions & 1 deletion lib/bandit/websocket/handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ defmodule Bandit.WebSocket.Handler do

connection_opts = Keyword.merge(state.opts.websocket, connection_opts)

primitive_ops_module =
Keyword.get(state.opts.websocket, :primitive_ops_module, Bandit.PrimitiveOps.WebSocket)

state =
state
|> Map.take([:handler_module])
|> Map.put(:extractor, Extractor.new(Frame, connection_opts))
|> Map.put(:extractor, Extractor.new(Frame, primitive_ops_module, connection_opts))

case Connection.init(websock, websock_opts, connection_opts, socket) do
{:continue, connection} ->
Expand Down
Loading
Loading