Skip to content

Commit

Permalink
Add support for running optimized native code for websocket mask (#394)
Browse files Browse the repository at this point in the history
* Add PrimitiveOps behaviour and default impl

* Move into websocket options

* Move primitive_ops_module in Extractor.new/3
  • Loading branch information
alisinabh authored Nov 15, 2024
1 parent cd855ec commit 898afdc
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 106 deletions.
2 changes: 1 addition & 1 deletion lib/bandit.ex
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ defmodule Bandit do
@http_keys ~w(compress deflate_options log_exceptions_with_status_codes log_protocol_errors log_client_closures)a
@http_1_keys ~w(enabled max_request_line_length max_header_length max_header_count max_requests clear_process_dict gc_every_n_keepalive_requests log_unknown_messages)a
@http_2_keys ~w(enabled max_header_block_size max_requests default_local_settings)a
@websocket_keys ~w(enabled max_frame_size validate_text_frames compress)a
@websocket_keys ~w(enabled max_frame_size validate_text_frames compress primitive_ops_module)a
@thousand_island_keys ThousandIsland.ServerConfig.__struct__()
|> Map.from_struct()
|> Map.keys()
Expand Down
17 changes: 10 additions & 7 deletions lib/bandit/extractor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ defmodule Bandit.Extractor do
| {:error, term()}
| :more

@callback deserialize(binary()) :: deserialize_result()
@callback deserialize(binary(), primitive_ops_module :: module()) :: deserialize_result()

@type t :: %__MODULE__{
header: binary(),
Expand All @@ -18,7 +18,8 @@ defmodule Bandit.Extractor do
required_length: non_neg_integer(),
mode: :header_parsing | :payload_parsing,
max_frame_size: non_neg_integer(),
frame_parser: atom()
frame_parser: atom(),
primitive_ops_module: module()
}

defstruct header: <<>>,
Expand All @@ -27,15 +28,17 @@ defmodule Bandit.Extractor do
required_length: 0,
mode: :header_parsing,
max_frame_size: 0,
frame_parser: nil
frame_parser: nil,
primitive_ops_module: nil

@spec new(module(), Keyword.t()) :: t()
def new(frame_parser, opts) do
@spec new(module(), module(), Keyword.t()) :: t()
def new(frame_parser, primitive_ops_module, opts) do
max_frame_size = Keyword.get(opts, :max_frame_size, 0)

%__MODULE__{
max_frame_size: max_frame_size,
frame_parser: frame_parser
frame_parser: frame_parser,
primitive_ops_module: primitive_ops_module
}
end

Expand Down Expand Up @@ -79,7 +82,7 @@ defmodule Bandit.Extractor do
<<payload::binary-size(required_length), rest::binary>> =
IO.iodata_to_binary(state.payload)

frame = state.frame_parser.deserialize(state.header <> payload)
frame = state.frame_parser.deserialize(state.header <> payload, state.primitive_ops_module)
state = transition_to_header_parsing(state, rest)

{state, frame}
Expand Down
34 changes: 34 additions & 0 deletions lib/bandit/primitive_ops/websocket.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
defmodule Bandit.PrimitiveOps.WebSocket do
@moduledoc """
WebSocket primitive operations behaviour and default implementation
"""

@doc """
WebSocket masking according to [RFC6455§5.3](https://www.rfc-editor.org/rfc/rfc6455#section-5.3)
"""
@callback ws_mask(payload :: binary(), mask :: integer()) :: binary()

@behaviour __MODULE__

# Note that masking is an involution, so we don't need a separate unmask function
@impl true
def ws_mask(payload, mask)
when is_binary(payload) and is_integer(mask) and mask >= 0x00000000 and mask <= 0xFFFFFFFF do
ws_mask(<<>>, payload, mask)
end

defp ws_mask(acc, <<h::32, rest::binary>>, mask) do
ws_mask(<<acc::binary, (<<Bitwise.bxor(h, mask)::32>>)>>, rest, mask)
end

for size <- [24, 16, 8] do
defp ws_mask(acc, <<h::unquote(size)>>, mask) do
<<mask::unquote(size), _::binary>> = <<mask::32>>
<<acc::binary, (<<Bitwise.bxor(h, mask)::unquote(size)>>)>>
end
end

defp ws_mask(acc, <<>>, _mask) do
acc
end
end
48 changes: 15 additions & 33 deletions lib/bandit/websocket/frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -73,29 +73,32 @@ defmodule Bandit.WebSocket.Frame do
end

@impl Bandit.Extractor
@spec deserialize(binary()) :: {:ok, frame()} | {:error, term()}
@spec deserialize(binary(), module()) :: {:ok, frame()} | {:error, term()}
def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, 127::7, length::64, mask::32,
payload::binary-size(length)>>
payload::binary-size(length)>>,
primitive_ops_module
) do
to_frame(fin, compressed, rsv, opcode, mask, payload)
to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module)
end

def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, 126::7, length::16, mask::32,
payload::binary-size(length)>>
payload::binary-size(length)>>,
primitive_ops_module
) do
to_frame(fin, compressed, rsv, opcode, mask, payload)
to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module)
end

def deserialize(
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, length::7, mask::32,
payload::binary-size(length)>>
payload::binary-size(length)>>,
primitive_ops_module
) do
to_frame(fin, compressed, rsv, opcode, mask, payload)
to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module)
end

def deserialize(_msg) do
def deserialize(_msg, _primitive_ops_module) do
{:error, :deserialization_failed}
end

Expand Down Expand Up @@ -155,14 +158,15 @@ defmodule Bandit.WebSocket.Frame do
end
end

defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload) when rsv != 0x0 do
defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload, _primitive_ops_module)
when rsv != 0x0 do
{:error, "Received unsupported RSV flags #{rsv}"}
end

defp to_frame(fin, compressed, 0x0, opcode, mask, payload) do
defp to_frame(fin, compressed, 0x0, opcode, mask, payload, primitive_ops_module) do
fin = fin == 0x1
compressed = compressed == 0x1
unmasked_payload = mask(payload, mask)
unmasked_payload = primitive_ops_module.ws_mask(payload, mask)

opcode
|> case do
Expand Down Expand Up @@ -198,26 +202,4 @@ defmodule Bandit.WebSocket.Frame do
defp mask_and_length(length) when length <= 125, do: <<0::1, length::7>>
defp mask_and_length(length) when length <= 65_535, do: <<0::1, 126::7, length::16>>
defp mask_and_length(length), do: <<0::1, 127::7, length::64>>

# Note that masking is an involution, so we don't need a separate unmask function
@spec mask(binary(), integer()) :: binary()
def mask(payload, mask)
when is_binary(payload) and is_integer(mask) and mask >= 0x00000000 and mask <= 0xFFFFFFFF do
mask(<<>>, payload, mask)
end

defp mask(acc, <<h::32, rest::binary>>, mask) do
mask(<<acc::binary, (<<Bitwise.bxor(h, mask)::32>>)>>, rest, mask)
end

for size <- [24, 16, 8] do
defp mask(acc, <<h::unquote(size)>>, mask) do
<<mask::unquote(size), _::binary>> = <<mask::32>>
<<acc::binary, (<<Bitwise.bxor(h, mask)::unquote(size)>>)>>
end
end

defp mask(acc, <<>>, _mask) do
acc
end
end
5 changes: 4 additions & 1 deletion lib/bandit/websocket/handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ defmodule Bandit.WebSocket.Handler do

connection_opts = Keyword.merge(state.opts.websocket, connection_opts)

primitive_ops_module =
Keyword.get(state.opts.websocket, :primitive_ops_module, Bandit.PrimitiveOps.WebSocket)

state =
state
|> Map.take([:handler_module])
|> Map.put(:extractor, Extractor.new(Frame, connection_opts))
|> Map.put(:extractor, Extractor.new(Frame, primitive_ops_module, connection_opts))

case Connection.init(websock, websock_opts, connection_opts, socket) do
{:continue, connection} ->
Expand Down
Loading

0 comments on commit 898afdc

Please sign in to comment.