Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/crow/app.h
Original file line number Diff line number Diff line change
Expand Up @@ -627,12 +627,12 @@ namespace crow
}


void add_websocket(crow::websocket::connection* conn)
void add_websocket(std::shared_ptr<websocket::connection> conn)
{
websockets_.push_back(conn);
}

void remove_websocket(crow::websocket::connection* conn)
void remove_websocket(std::shared_ptr<websocket::connection> conn)
{
websockets_.erase(std::remove(websockets_.begin(), websockets_.end(), conn), websockets_.end());
}
Expand Down Expand Up @@ -846,7 +846,7 @@ namespace crow
bool server_started_{false};
std::condition_variable cv_started_;
std::mutex start_mutex_;
std::vector<crow::websocket::connection*> websockets_;
std::vector<std::shared_ptr<websocket::connection>> websockets_;
};

/// \brief Alias of Crow<Middlewares...>. Useful if you want
Expand Down
8 changes: 5 additions & 3 deletions include/crow/routing.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,17 +445,19 @@ namespace crow // NOTE: Already documented in "crow/app.h"
void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override
{
max_payload_ = max_payload_override_ ? max_payload_ : app_->websocket_max_payload();
new crow::websocket::Connection<SocketAdaptor, App>(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
crow::websocket::Connection<SocketAdaptor, App>::create(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
}

void handle_upgrade(const request& req, response&, UnixSocketAdaptor&& adaptor) override
{
max_payload_ = max_payload_override_ ? max_payload_ : app_->websocket_max_payload();
new crow::websocket::Connection<UnixSocketAdaptor, App>(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
crow::websocket::Connection<UnixSocketAdaptor, App>::create(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
}

#ifdef CROW_ENABLE_SSL
void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override
{
new crow::websocket::Connection<SSLAdaptor, App>(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
crow::websocket::Connection<SSLAdaptor, App>::create(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_, mirror_protocols_);
}
#endif

Expand Down
159 changes: 82 additions & 77 deletions include/crow/websocket.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <array>
#include <memory>
#include "crow/logging.h"
#include "crow/socket_adaptors.h"
#include "crow/http_request.h"
Expand Down Expand Up @@ -102,36 +103,34 @@ namespace crow // NOTE: Already documented in "crow/app.h"
/// A websocket connection.

template<typename Adaptor, typename Handler>
class Connection : public connection
class Connection : public connection, public std::enable_shared_from_this<Connection<Adaptor, Handler>>
{
public:
/// Constructor for a connection.

/// Factory for a connection.
///
/// Requires a request with an "Upgrade: websocket" header.<br>
/// Automatically handles the handshake.
Connection(const crow::request& req, Adaptor&& adaptor, Handler* handler,
uint64_t max_payload, const std::vector<std::string>& subprotocols,
std::function<void(crow::websocket::connection&)> open_handler,
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler,
std::function<void(crow::websocket::connection&, const std::string&)> error_handler,
std::function<bool(const crow::request&, void**)> accept_handler,
bool mirror_protocols):
adaptor_(std::move(adaptor)),
handler_(handler),
max_payload_bytes_(max_payload),
open_handler_(std::move(open_handler)),
message_handler_(std::move(message_handler)),
close_handler_(std::move(close_handler)),
error_handler_(std::move(error_handler)),
accept_handler_(std::move(accept_handler))
static void create(const crow::request& req, Adaptor adaptor, Handler* handler,
uint64_t max_payload, const std::vector<std::string>& subprotocols,
std::function<void(crow::websocket::connection&)> open_handler,
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler,
std::function<void(crow::websocket::connection&, const std::string&)> error_handler,
std::function<bool(const crow::request&, void**)> accept_handler,
bool mirror_protocols)
{
auto conn = std::shared_ptr<Connection>(new Connection(std::move(adaptor),
handler, max_payload,
std::move(open_handler),
std::move(message_handler),
std::move(close_handler),
std::move(error_handler),
std::move(accept_handler)));

// Perform handshake validation
if (!utility::string_equals(req.get_header_value("upgrade"), "websocket"))
{
adaptor_.close();
handler_->remove_websocket(this);
delete this;
conn->adaptor_.close();
return;
}

Expand All @@ -142,26 +141,24 @@ namespace crow // NOTE: Already documented in "crow/app.h"
auto subprotocol = utility::find_first_of(subprotocols.begin(), subprotocols.end(), requested_subprotocols.begin(), requested_subprotocols.end());
if (subprotocol != subprotocols.end())
{
subprotocol_ = *subprotocol;
conn->subprotocol_ = *subprotocol;
}
}

if (mirror_protocols & !requested_subprotocols_header.empty())
{
subprotocol_ = requested_subprotocols_header;
conn->subprotocol_ = requested_subprotocols_header;
}

if (accept_handler_)
if (conn->accept_handler_)
{
void* ud = nullptr;
if (!accept_handler_(req, &ud))
if (!conn->accept_handler_(req, &ud))
{
adaptor_.close();
handler_->remove_websocket(this);
delete this;
conn->adaptor_.close();
return;
}
userdata(ud);
conn->userdata(ud);
}

// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
Expand All @@ -172,22 +169,11 @@ namespace crow // NOTE: Already documented in "crow/app.h"
uint8_t digest[20];
s.getDigestBytes(digest);

start(crow::utility::base64encode((unsigned char*)digest, 20));
conn->handler_->add_websocket(conn);
conn->start(crow::utility::base64encode((unsigned char*)digest, 20));
}

~Connection() noexcept override
{
// Do not modify anchor_ here since writing shared_ptr is not atomic.
auto watch = std::weak_ptr<void>{anchor_};

// Wait until all unhandled asynchronous operations to join.
// As the deletion occurs inside 'check_destroy()', which already locks
// anchor, use count can be 1 on valid deletion context.
while (watch.use_count() > 2) // 1 for 'check_destroy() routine', 1 for 'this->anchor_'
{
std::this_thread::yield();
}
}
~Connection() noexcept override = default;

template<typename Callable>
struct WeakWrappedMessage
Expand Down Expand Up @@ -717,38 +703,38 @@ namespace crow // NOTE: Already documented in "crow/app.h"
/// Also destroys the object if the Close flag is set.
void do_write()
{
if (sending_buffers_.empty())
if (write_buffers_.empty()) return;

sending_buffers_.swap(write_buffers_);
std::vector<asio::const_buffer> buffers;
buffers.reserve(sending_buffers_.size());
for (auto& s : sending_buffers_)
{
sending_buffers_.swap(write_buffers_);
std::vector<asio::const_buffer> buffers;
buffers.reserve(sending_buffers_.size());
for (auto& s : sending_buffers_)
{
buffers.emplace_back(asio::buffer(s));
}
auto watch = std::weak_ptr<void>{anchor_};
asio::async_write(
adaptor_.socket(), buffers,
[&, watch](const error_code& ec, std::size_t /*bytes_transferred*/) {
if (!ec && !close_connection_)
{
sending_buffers_.clear();
if (!write_buffers_.empty())
do_write();
if (has_sent_close_)
close_connection_ = true;
}
else
{
auto anchor = watch.lock();
if (anchor == nullptr) { return; }

sending_buffers_.clear();
close_connection_ = true;
check_destroy();
}
});
buffers.emplace_back(asio::buffer(s));
}
auto watch = std::weak_ptr<void>{anchor_};
asio::async_write(
adaptor_.socket(), buffers,
[this, watch](const error_code& ec, std::size_t /*bytes_transferred*/) {
auto anchor = watch.lock();
if (anchor == nullptr)
return;

if (!ec && !close_connection_)
{
sending_buffers_.clear();
if (!write_buffers_.empty())
do_write();
if (has_sent_close_)
close_connection_ = true;
}
else
{
sending_buffers_.clear();
close_connection_ = true;
check_destroy();
}
});
}

/// Destroy the Connection.
Expand All @@ -757,11 +743,14 @@ namespace crow // NOTE: Already documented in "crow/app.h"
// Note that if the close handler was not yet called at this point we did not receive a close packet (or send one)
// and thus we use ClosedAbnormally unless instructed otherwise
if (!is_close_handler_called_)
{
if (close_handler_)
{
close_handler_(*this, "uncleanly", code);
handler_->remove_websocket(this);
if (sending_buffers_.empty() && !is_reading)
delete this;
}
}

handler_->remove_websocket(this->shared_from_this());
}


Expand Down Expand Up @@ -796,6 +785,22 @@ namespace crow // NOTE: Already documented in "crow/app.h"
}

private:
Connection(Adaptor&& adaptor, Handler* handler, uint64_t max_payload,
std::function<void(crow::websocket::connection&)> open_handler,
std::function<void(crow::websocket::connection&, const std::string&, bool)> message_handler,
std::function<void(crow::websocket::connection&, const std::string&, uint16_t)> close_handler,
std::function<void(crow::websocket::connection&, const std::string&)> error_handler,
std::function<bool(const crow::request&, void**)> accept_handler):
adaptor_(std::move(adaptor)),
handler_(handler),
max_payload_bytes_(max_payload),
open_handler_(std::move(open_handler)),
message_handler_(std::move(message_handler)),
close_handler_(std::move(close_handler)),
error_handler_(std::move(error_handler)),
accept_handler_(std::move(accept_handler))
{}

Adaptor adaptor_;
Handler* handler_;

Expand Down
Loading