Skip to content

Commit

Permalink
refactor: websocket; allow passing the frame type up the stack
Browse files Browse the repository at this point in the history
  • Loading branch information
braindigitalis committed Sep 25, 2024
1 parent 27ae746 commit 8ec080b
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 36 deletions.
2 changes: 1 addition & 1 deletion include/dpp/discordclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ class DPP_EXPORT discord_client : public websocket_client
* @param buffer The entire buffer content from the websocket client
* @returns True if a frame has been handled
*/
virtual bool handle_frame(const std::string &buffer);
virtual bool handle_frame(const std::string &buffer, ws_opcode opcode);

/**
* @brief Handle a websocket error.
Expand Down
2 changes: 1 addition & 1 deletion include/dpp/discordvoiceclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ class DPP_EXPORT discord_voice_client : public websocket_client
* @return bool True if a frame has been handled
* @throw dpp::exception If there was an error processing the frame, or connection to UDP socket failed
*/
virtual bool handle_frame(const std::string &buffer);
virtual bool handle_frame(const std::string &buffer, ws_opcode opcode);

/**
* @brief Handle a websocket error.
Expand Down
2 changes: 1 addition & 1 deletion include/dpp/sslclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class DPP_EXPORT ssl_client
* @param data Data to be written to the buffer.
* @note The data may not be written immediately and may be written at a later time to the socket.
*/
virtual void write(const std::string_view data);
void socket_write(const std::string_view data);

/**
* @brief Close socket connection
Expand Down
12 changes: 9 additions & 3 deletions include/dpp/wsclient.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,12 @@ enum ws_opcode : uint8_t {
/**
* @brief Low level pong.
*/
OP_PONG = 0x0a
OP_PONG = 0x0a,

/**
* @brief Automatic selection of type
*/
OP_AUTO = 0xff,
};

/**
Expand Down Expand Up @@ -189,7 +194,7 @@ class DPP_EXPORT websocket_client : public ssl_client {
* @brief Write to websocket. Encapsulates data in frames if the status is CONNECTED.
* @param data The data to send.
*/
virtual void write(const std::string_view data);
virtual void write(const std::string_view data, ws_opcode _opcode = OP_AUTO);

/**
* @brief Processes incoming frames from the SSL socket input buffer.
Expand All @@ -206,9 +211,10 @@ class DPP_EXPORT websocket_client : public ssl_client {
* @brief Receives raw frame content only without headers
*
* @param buffer The buffer contents
* @param opcode Frame type, e.g. OP_TEXT, OP_BINARY
* @return True if the frame was successfully handled. False if no valid frame is in the buffer.
*/
virtual bool handle_frame(const std::string& buffer);
virtual bool handle_frame(const std::string& buffer, ws_opcode opcode);

/**
* @brief Called upon error frame.
Expand Down
8 changes: 4 additions & 4 deletions src/dpp/discordclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ void discord_client::run()
this->thread_id = runner->native_handle();
}

bool discord_client::handle_frame(const std::string &buffer)
bool discord_client::handle_frame(const std::string &buffer, ws_opcode opcode)
{
std::string& data = (std::string&)buffer;

Expand Down Expand Up @@ -340,7 +340,7 @@ bool discord_client::handle_frame(const std::string &buffer)
}
}
};
this->write(jsonobj_to_string(obj));
this->write(jsonobj_to_string(obj), protocol == ws_etf ? OP_BINARY : OP_TEXT);
resumes++;
} else {
/* Full connect */
Expand Down Expand Up @@ -369,7 +369,7 @@ bool discord_client::handle_frame(const std::string &buffer)
}
}
};
this->write(jsonobj_to_string(obj));
this->write(jsonobj_to_string(obj), protocol == ws_etf ? OP_BINARY : OP_TEXT);
this->connect_time = creator->last_identify = time(nullptr);
reconnects++;
}
Expand Down Expand Up @@ -539,7 +539,7 @@ void discord_client::one_second_timer()
ping_start = utility::time_f();
last_ping_message.clear();
}
this->write(message);
this->write(message, protocol == ws_etf ? OP_BINARY : OP_TEXT);
}
}

Expand Down
16 changes: 7 additions & 9 deletions src/dpp/discordvoiceclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,17 +486,15 @@ int discord_voice_client::udp_recv(char* data, size_t max_length)
return (int) recv(this->fd, data, (int)max_length, 0);
}

bool discord_voice_client::handle_frame(const std::string &data)
bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcode)
{
log(dpp::ll_trace, std::string("R: ") + data);
json j;

/**
* Because all discord JSON must be valid UTF-8, if we see a packet with the 2nd character
* being less than 32 (' '), then we know it is a binary MLS frame, as all the binary frame
* opcodes are purposefully less than 32. We then try and parse it as MLS binary.
* MLS frames come in as type OP_BINARY, we can also reply to them as type OP_BINARY.
*/
if (data.size() >= sizeof(dave_binary_header_t) && data[2] <= voice_client_dave_mls_invalid_commit_welcome) {
if (opcode == OP_BINARY && data.size() >= sizeof(dave_binary_header_t)) {

/* Debug, remove once this is working */
std::cout << dpp::utility::debug_dump((uint8_t*)(data.data()), data.length()) << "\n";
Expand Down Expand Up @@ -627,7 +625,7 @@ bool discord_voice_client::handle_frame(const std::string &data)
}
}
};
this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace));
this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT);
} else {
log(dpp::ll_debug, "Connecting new voice session (DAVE: " + std::string(dave_version == dave_version_1 ? "Enabled" : "Disabled") + ")...");
json obj = {
Expand All @@ -643,7 +641,7 @@ bool discord_voice_client::handle_frame(const std::string &data)
}
}
};
this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace));
this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT);
}
this->connect_time = time(nullptr);
}
Expand Down Expand Up @@ -744,7 +742,7 @@ bool discord_voice_client::handle_frame(const std::string &data)
}
}
}
}).dump(-1, ' ', false, json::error_handler_t::replace));
}).dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT);
}
}
break;
Expand Down Expand Up @@ -1204,7 +1202,7 @@ void discord_voice_client::one_second_timer()
if (!message_queue.empty()) {
std::string message = message_queue.front();
message_queue.pop_front();
this->write(message);
this->write(message, OP_TEXT);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/dpp/httpsclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ void https_client::connect()
map_headers += k + ": " + v + "\r\n";
}
if (this->sfd != SOCKET_ERROR) {
this->write(
this->socket_write(
this->request_type + " " + this->path + " HTTP/" + http_protocol + "\r\n"
"Host: " + this->hostname + "\r\n"
"pragma: no-cache\r\n"
Expand Down
2 changes: 1 addition & 1 deletion src/dpp/sslclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ void ssl_client::connect()
}
}

void ssl_client::write(const std::string_view data)
void ssl_client::socket_write(const std::string_view data)
{
/* If we are in nonblocking mode, append to the buffer,
* otherwise just use SSL_write directly. The only time we
Expand Down
30 changes: 15 additions & 15 deletions src/dpp/wsclient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void websocket_client::connect()
{
state = HTTP_HEADERS;
/* Send headers synchronously */
this->write(
this->socket_write(
"GET " + this->path + " HTTP/1.1\r\n"
"Host: " + this->hostname + "\r\n"
"pragma: no-cache\r\n"
Expand All @@ -73,7 +73,7 @@ void websocket_client::connect()
);
}

bool websocket_client::handle_frame(const std::string& buffer)
bool websocket_client::handle_frame(const std::string& buffer, ws_opcode opcode)
{
/* This is a stub for classes that derive the websocket client */
return true;
Expand Down Expand Up @@ -111,17 +111,17 @@ size_t websocket_client::fill_header(unsigned char* outbuf, size_t sendlength, w
}


void websocket_client::write(const std::string_view data)
void websocket_client::write(const std::string_view data, ws_opcode _opcode)
{
if (state == HTTP_HEADERS) {
/* Simple write */
ssl_client::write(data);
ssl_client::socket_write(data);
} else {
unsigned char out[MAXHEADERSIZE];
size_t s = this->fill_header(out, data.length(), this->data_opcode);
size_t s = this->fill_header(out, data.length(), _opcode == OP_AUTO ? this->data_opcode : _opcode);
std::string header((const char*)out, s);
ssl_client::write(header);
ssl_client::write(data);
ssl_client::socket_write(header);
ssl_client::socket_write(data);
}
}

Expand Down Expand Up @@ -175,7 +175,7 @@ bool websocket_client::handle_buffer(std::string& buffer)
}
} else if (state == CONNECTED) {
/* Process packets until we can't (buffer will erase data until parseheader returns false) */
while (this->parseheader(buffer)){}
while (this->parseheader(buffer)) { }
}

return true;
Expand Down Expand Up @@ -249,7 +249,7 @@ bool websocket_client::parseheader(std::string& data)
handle_ping(data.substr(payloadstartoffset, len));
} else if ((opcode & ~WS_FINBIT) != OP_PONG) { /* Otherwise, handle everything else apart from a PONG. */
/* Pass this frame to the deriving class */
this->handle_frame(data.substr(payloadstartoffset, len));
this->handle_frame(data.substr(payloadstartoffset, len), static_cast<ws_opcode>(opcode & ~WS_FINBIT));
}

/* Remove this frame from the input buffer */
Expand Down Expand Up @@ -286,8 +286,8 @@ void websocket_client::one_second_timer()
std::string payload = "keepalive";
size_t s = this->fill_header(out, payload.length(), OP_PING);
std::string header((const char*)out, s);
ssl_client::write(header);
ssl_client::write(payload);
ssl_client::socket_write(header);
ssl_client::socket_write(payload);
}
}

Expand All @@ -297,8 +297,8 @@ void websocket_client::handle_ping(const std::string &payload)
unsigned char out[MAXHEADERSIZE];
size_t s = this->fill_header(out, payload.length(), OP_PONG);
std::string header((const char*)out, s);
ssl_client::write(header);
ssl_client::write(payload);
ssl_client::socket_write(header);
ssl_client::socket_write(payload);
}

void websocket_client::send_close_packet()
Expand All @@ -312,8 +312,8 @@ void websocket_client::send_close_packet()

size_t s = this->fill_header(out, payload.length(), OP_CLOSE);
std::string header((const char*)out, s);
ssl_client::write(header);
ssl_client::write(payload);
ssl_client::socket_write(header);
ssl_client::socket_write(payload);
}

void websocket_client::error(uint32_t errorcode)
Expand Down

0 comments on commit 8ec080b

Please sign in to comment.