Skip to content

Commit 898afdc

Browse files
authored
Add support for running optimized native code for websocket mask (#394)
* Add PrimitiveOps behaviour and default impl * Move into websocket options * Move primitive_ops_module in Extractor.new/3
1 parent cd855ec commit 898afdc

File tree

7 files changed

+132
-106
lines changed

7 files changed

+132
-106
lines changed

lib/bandit.ex

+1-1
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ defmodule Bandit do
222222
@http_keys ~w(compress deflate_options log_exceptions_with_status_codes log_protocol_errors log_client_closures)a
223223
@http_1_keys ~w(enabled max_request_line_length max_header_length max_header_count max_requests clear_process_dict gc_every_n_keepalive_requests log_unknown_messages)a
224224
@http_2_keys ~w(enabled max_header_block_size max_requests default_local_settings)a
225-
@websocket_keys ~w(enabled max_frame_size validate_text_frames compress)a
225+
@websocket_keys ~w(enabled max_frame_size validate_text_frames compress primitive_ops_module)a
226226
@thousand_island_keys ThousandIsland.ServerConfig.__struct__()
227227
|> Map.from_struct()
228228
|> Map.keys()

lib/bandit/extractor.ex

+10-7
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ defmodule Bandit.Extractor do
99
| {:error, term()}
1010
| :more
1111

12-
@callback deserialize(binary()) :: deserialize_result()
12+
@callback deserialize(binary(), primitive_ops_module :: module()) :: deserialize_result()
1313

1414
@type t :: %__MODULE__{
1515
header: binary(),
@@ -18,7 +18,8 @@ defmodule Bandit.Extractor do
1818
required_length: non_neg_integer(),
1919
mode: :header_parsing | :payload_parsing,
2020
max_frame_size: non_neg_integer(),
21-
frame_parser: atom()
21+
frame_parser: atom(),
22+
primitive_ops_module: module()
2223
}
2324

2425
defstruct header: <<>>,
@@ -27,15 +28,17 @@ defmodule Bandit.Extractor do
2728
required_length: 0,
2829
mode: :header_parsing,
2930
max_frame_size: 0,
30-
frame_parser: nil
31+
frame_parser: nil,
32+
primitive_ops_module: nil
3133

32-
@spec new(module(), Keyword.t()) :: t()
33-
def new(frame_parser, opts) do
34+
@spec new(module(), module(), Keyword.t()) :: t()
35+
def new(frame_parser, primitive_ops_module, opts) do
3436
max_frame_size = Keyword.get(opts, :max_frame_size, 0)
3537

3638
%__MODULE__{
3739
max_frame_size: max_frame_size,
38-
frame_parser: frame_parser
40+
frame_parser: frame_parser,
41+
primitive_ops_module: primitive_ops_module
3942
}
4043
end
4144

@@ -79,7 +82,7 @@ defmodule Bandit.Extractor do
7982
<<payload::binary-size(required_length), rest::binary>> =
8083
IO.iodata_to_binary(state.payload)
8184

82-
frame = state.frame_parser.deserialize(state.header <> payload)
85+
frame = state.frame_parser.deserialize(state.header <> payload, state.primitive_ops_module)
8386
state = transition_to_header_parsing(state, rest)
8487

8588
{state, frame}

lib/bandit/primitive_ops/websocket.ex

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
defmodule Bandit.PrimitiveOps.WebSocket do
2+
@moduledoc """
3+
WebSocket primitive operations behaviour and default implementation
4+
"""
5+
6+
@doc """
7+
WebSocket masking according to [RFC6455§5.3](https://www.rfc-editor.org/rfc/rfc6455#section-5.3)
8+
"""
9+
@callback ws_mask(payload :: binary(), mask :: integer()) :: binary()
10+
11+
@behaviour __MODULE__
12+
13+
# Note that masking is an involution, so we don't need a separate unmask function
14+
@impl true
15+
def ws_mask(payload, mask)
16+
when is_binary(payload) and is_integer(mask) and mask >= 0x00000000 and mask <= 0xFFFFFFFF do
17+
ws_mask(<<>>, payload, mask)
18+
end
19+
20+
defp ws_mask(acc, <<h::32, rest::binary>>, mask) do
21+
ws_mask(<<acc::binary, (<<Bitwise.bxor(h, mask)::32>>)>>, rest, mask)
22+
end
23+
24+
for size <- [24, 16, 8] do
25+
defp ws_mask(acc, <<h::unquote(size)>>, mask) do
26+
<<mask::unquote(size), _::binary>> = <<mask::32>>
27+
<<acc::binary, (<<Bitwise.bxor(h, mask)::unquote(size)>>)>>
28+
end
29+
end
30+
31+
defp ws_mask(acc, <<>>, _mask) do
32+
acc
33+
end
34+
end

lib/bandit/websocket/frame.ex

+15-33
Original file line numberDiff line numberDiff line change
@@ -73,29 +73,32 @@ defmodule Bandit.WebSocket.Frame do
7373
end
7474

7575
@impl Bandit.Extractor
76-
@spec deserialize(binary()) :: {:ok, frame()} | {:error, term()}
76+
@spec deserialize(binary(), module()) :: {:ok, frame()} | {:error, term()}
7777
def deserialize(
7878
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, 127::7, length::64, mask::32,
79-
payload::binary-size(length)>>
79+
payload::binary-size(length)>>,
80+
primitive_ops_module
8081
) do
81-
to_frame(fin, compressed, rsv, opcode, mask, payload)
82+
to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module)
8283
end
8384

8485
def deserialize(
8586
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, 126::7, length::16, mask::32,
86-
payload::binary-size(length)>>
87+
payload::binary-size(length)>>,
88+
primitive_ops_module
8789
) do
88-
to_frame(fin, compressed, rsv, opcode, mask, payload)
90+
to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module)
8991
end
9092

9193
def deserialize(
9294
<<fin::1, compressed::1, rsv::2, opcode::4, 1::1, length::7, mask::32,
93-
payload::binary-size(length)>>
95+
payload::binary-size(length)>>,
96+
primitive_ops_module
9497
) do
95-
to_frame(fin, compressed, rsv, opcode, mask, payload)
98+
to_frame(fin, compressed, rsv, opcode, mask, payload, primitive_ops_module)
9699
end
97100

98-
def deserialize(_msg) do
101+
def deserialize(_msg, _primitive_ops_module) do
99102
{:error, :deserialization_failed}
100103
end
101104

@@ -155,14 +158,15 @@ defmodule Bandit.WebSocket.Frame do
155158
end
156159
end
157160

158-
defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload) when rsv != 0x0 do
161+
defp to_frame(_fin, _compressed, rsv, _opcode, _mask, _payload, _primitive_ops_module)
162+
when rsv != 0x0 do
159163
{:error, "Received unsupported RSV flags #{rsv}"}
160164
end
161165

162-
defp to_frame(fin, compressed, 0x0, opcode, mask, payload) do
166+
defp to_frame(fin, compressed, 0x0, opcode, mask, payload, primitive_ops_module) do
163167
fin = fin == 0x1
164168
compressed = compressed == 0x1
165-
unmasked_payload = mask(payload, mask)
169+
unmasked_payload = primitive_ops_module.ws_mask(payload, mask)
166170

167171
opcode
168172
|> case do
@@ -198,26 +202,4 @@ defmodule Bandit.WebSocket.Frame do
198202
defp mask_and_length(length) when length <= 125, do: <<0::1, length::7>>
199203
defp mask_and_length(length) when length <= 65_535, do: <<0::1, 126::7, length::16>>
200204
defp mask_and_length(length), do: <<0::1, 127::7, length::64>>
201-
202-
# Note that masking is an involution, so we don't need a separate unmask function
203-
@spec mask(binary(), integer()) :: binary()
204-
def mask(payload, mask)
205-
when is_binary(payload) and is_integer(mask) and mask >= 0x00000000 and mask <= 0xFFFFFFFF do
206-
mask(<<>>, payload, mask)
207-
end
208-
209-
defp mask(acc, <<h::32, rest::binary>>, mask) do
210-
mask(<<acc::binary, (<<Bitwise.bxor(h, mask)::32>>)>>, rest, mask)
211-
end
212-
213-
for size <- [24, 16, 8] do
214-
defp mask(acc, <<h::unquote(size)>>, mask) do
215-
<<mask::unquote(size), _::binary>> = <<mask::32>>
216-
<<acc::binary, (<<Bitwise.bxor(h, mask)::unquote(size)>>)>>
217-
end
218-
end
219-
220-
defp mask(acc, <<>>, _mask) do
221-
acc
222-
end
223205
end

lib/bandit/websocket/handler.ex

+4-1
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@ defmodule Bandit.WebSocket.Handler do
1717

1818
connection_opts = Keyword.merge(state.opts.websocket, connection_opts)
1919

20+
primitive_ops_module =
21+
Keyword.get(state.opts.websocket, :primitive_ops_module, Bandit.PrimitiveOps.WebSocket)
22+
2023
state =
2124
state
2225
|> Map.take([:handler_module])
23-
|> Map.put(:extractor, Extractor.new(Frame, connection_opts))
26+
|> Map.put(:extractor, Extractor.new(Frame, primitive_ops_module, connection_opts))
2427

2528
case Connection.init(websock, websock_opts, connection_opts, socket) do
2629
{:continue, connection} ->

0 commit comments

Comments
 (0)