diff --git a/lib/bandit/http1/socket.ex b/lib/bandit/http1/socket.ex index e8130c8b..57aec4a5 100644 --- a/lib/bandit/http1/socket.ex +++ b/lib/bandit/http1/socket.ex @@ -12,7 +12,7 @@ defmodule Bandit.HTTP1.Socket do buffer: <<>>, read_state: :unread, write_state: :unsent, - bytes_remaining: nil, + unread_content_length: nil, body_encoding: nil, version: :"HTTP/1.0", send_buffer: nil, @@ -31,7 +31,7 @@ defmodule Bandit.HTTP1.Socket do buffer: iodata(), read_state: read_state(), write_state: write_state(), - bytes_remaining: non_neg_integer() | :chunked | nil, + unread_content_length: non_neg_integer() | :chunked | nil, body_encoding: nil | binary(), version: nil | :"HTTP/1.1" | :"HTTP/1.0", send_buffer: iolist(), @@ -49,20 +49,19 @@ defmodule Bandit.HTTP1.Socket do def read_headers(%@for{read_state: :unread} = socket) do {method, request_target, socket} = do_read_request_line!(socket) {headers, socket} = do_read_headers!(socket) - body_size = get_content_length!(headers) + content_length = get_content_length!(headers) body_encoding = Bandit.Headers.get_header(headers, "transfer-encoding") connection = Bandit.Headers.get_header(headers, "connection") keepalive = should_keepalive?(socket.version, connection) socket = %{socket | keepalive: keepalive} - case {body_size, body_encoding} do + case {content_length, body_encoding} do {nil, nil} -> # No body, so just go straight to 'read' {:ok, method, request_target, headers, %{socket | read_state: :read}} - {body_size, nil} -> - bytes_remaining = body_size - byte_size(socket.buffer) - socket = %{socket | read_state: :headers_read, bytes_remaining: bytes_remaining} + {content_length, nil} -> + socket = %{socket | read_state: :headers_read, unread_content_length: content_length} {:ok, method, request_target, headers, socket} {nil, body_encoding} -> @@ -173,17 +172,19 @@ defmodule Bandit.HTTP1.Socket do defp should_keepalive?(_, _), do: false def read_data( - %@for{read_state: :headers_read, bytes_remaining: bytes_remaining} = socket, + %@for{read_state: :headers_read, unread_content_length: unread_content_length} = socket, opts ) - when is_number(bytes_remaining) do - {to_return, buffer, bytes_remaining} = - do_read_content_length_data!(socket.socket, socket.buffer, bytes_remaining, opts) + when is_number(unread_content_length) do + {to_return, buffer, remaining_unread_content_length} = + do_read_content_length_data!(socket.socket, socket.buffer, unread_content_length, opts) - if byte_size(buffer) == 0 && bytes_remaining == 0 do - {:ok, to_return, %{socket | read_state: :read, buffer: <<>>, bytes_remaining: 0}} + socket = %{socket | buffer: buffer, unread_content_length: remaining_unread_content_length} + + if remaining_unread_content_length == 0 do + {:ok, to_return, %{socket | read_state: :read}} else - {:more, to_return, %{socket | buffer: buffer, bytes_remaining: bytes_remaining}} + {:more, to_return, socket} end end @@ -207,32 +208,35 @@ defmodule Bandit.HTTP1.Socket do def read_data(%@for{} = socket, _opts), do: {:ok, <<>>, socket} @dialyzer {:no_improper_lists, do_read_content_length_data!: 4} - defp do_read_content_length_data!(socket, buffer, bytes_remaining, opts) do - max_desired_bytes = Keyword.get(opts, :length, 8_000_000) + defp do_read_content_length_data!(socket, buffer, unread_content_length, opts) do + max_to_return = min(unread_content_length, Keyword.get(opts, :length, 8_000_000)) cond do - bytes_remaining < 0 -> - # We have read more bytes than content-length suggested should have been sent. This is - # veering into request smuggling territory and should never happen with a well behaved - # client. The safest thing to do is just error - request_error!("Excess body read") + max_to_return == 0 -> + # We have already satisfied our content length + {<<>>, buffer, unread_content_length} - byte_size(buffer) >= max_desired_bytes || bytes_remaining == 0 -> + byte_size(buffer) >= max_to_return -> # We can satisfy the read request entirely from our buffer - bytes_to_return = min(max_desired_bytes, byte_size(buffer)) - <> = buffer - {to_return, rest, bytes_remaining} + <> = buffer + {to_return, rest, unread_content_length - max_to_return} - true -> + byte_size(buffer) < max_to_return -> # We need to read off the wire - bytes_to_read = min(max_desired_bytes - byte_size(buffer), bytes_remaining) read_size = Keyword.get(opts, :read_length, 1_000_000) read_timeout = Keyword.get(opts, :read_timeout) - iolist = read!(socket, bytes_to_read, [], read_size, read_timeout) - to_return = IO.iodata_to_binary([buffer | iolist]) - bytes_remaining = bytes_remaining - (byte_size(to_return) - byte_size(buffer)) - {to_return, <<>>, bytes_remaining} + to_return = + read!(socket, max_to_return - byte_size(buffer), [buffer], read_size, read_timeout) + |> IO.iodata_to_binary() + + # We may have read more than we need to return + if byte_size(to_return) >= max_to_return do + <> = to_return + {to_return, rest, unread_content_length - max_to_return} + else + {to_return, <<>>, unread_content_length - byte_size(to_return)} + end end end diff --git a/test/bandit/http1/request_test.exs b/test/bandit/http1/request_test.exs index cfeeefe5..976c87da 100644 --- a/test/bandit/http1/request_test.exs +++ b/test/bandit/http1/request_test.exs @@ -247,7 +247,7 @@ defmodule HTTP1RequestTest do Transport.send( client, - String.duplicate("GET /hello_world HTTP/1.1\r\nHost: localhost\r\n\r\n", 50) + String.duplicate("GET /send_ok HTTP/1.1\r\nHost: localhost\r\n\r\n", 50) ) for _ <- 1..50 do @@ -258,6 +258,29 @@ defmodule HTTP1RequestTest do end end + test "handles pipeline requests with unread POST bodies", context do + client = SimpleHTTP1Client.tcp_client(context) + + Transport.send( + client, + String.duplicate( + "POST /send_ok HTTP/1.1\r\nHost: localhost\r\nContent-Length:3\r\n\r\nABC", + 50 + ) + ) + + for _ <- 1..50 do + # Need to read the exact size of the expected response because SimpleHTTP1Client + # doesn't track 'rest' bytes and ends up throwing a bunch of responses on the floor + {:ok, bytes} = Transport.recv(client, 152) + assert({:ok, "200 OK", _, _} = SimpleHTTP1Client.parse_response(client, bytes)) + end + end + + def send_ok(conn) do + send_resp(conn, 200, "OK") + end + test "closes connection after max_requests is reached", context do context = http_server(context, http_1_options: [max_requests: 3]) client = SimpleHTTP1Client.tcp_client(context) @@ -1002,29 +1025,6 @@ defmodule HTTP1RequestTest do raise "Shouldn't get here" end - test "handles the case where the declared content length is less than what is sent", - context do - output = - capture_log(fn -> - client = SimpleHTTP1Client.tcp_client(context) - - Transport.send( - client, - "POST /long_body HTTP/1.1\r\nhost: localhost\r\ncontent-length: 3\r\n\r\nABCDE" - ) - - assert {:ok, "400 Bad Request", _, ""} = SimpleHTTP1Client.recv_reply(client) - Process.sleep(100) - end) - - assert output =~ "(Bandit.HTTPError) Excess body read" - end - - def long_body(conn) do - Plug.Conn.read_body(conn) - raise "should not get here" - end - test "reading request body multiple times works as expected", context do response = Req.post!(context.req, url: "/multiple_body_read", body: "OK")