Skip to content

Commit 83a86f6

Browse files
refactor: dont use strings and uint64_t in dave, use snowflake type (#1307)
1 parent 7d0128c commit 83a86f6

File tree

4 files changed

+50
-59
lines changed

4 files changed

+50
-59
lines changed

include/dpp/discordvoiceclient.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -487,13 +487,13 @@ class DPP_EXPORT discord_voice_client : public websocket_client
487487
* @brief The list of users that have E2EE potentially enabled for
488488
* DAVE protocol.
489489
*/
490-
std::set<std::string> dave_mls_user_list;
490+
std::set<dpp::snowflake> dave_mls_user_list;
491491

492492
/**
493493
* @brief The list of users that have left the voice channel but
494494
* not yet removed from MLS group.
495495
*/
496-
std::set<std::string> dave_mls_pending_remove_list;
496+
std::set<dpp::snowflake> dave_mls_pending_remove_list;
497497

498498
/**
499499
* @brief File descriptor for UDP connection

src/dpp/dave/session.cpp

+21-30
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
#include <iostream>
3030
#include <mls/crypto.h>
3131
#include <mls/messages.h>
32+
#include <dpp/export.h>
33+
#include <dpp/snowflake.h>
3234
#include <mls/state.h>
3335
#include <dpp/cluster.h>
3436
#include "mls_key_ratchet.h"
@@ -50,20 +52,20 @@ struct queued_proposal {
5052
::mlspp::bytes_ns::bytes ref;
5153
};
5254

53-
session::session(dpp::cluster& cluster, key_pair_context_type context, const std::string& auth_session_id, mls_failure_callback callback) noexcept
55+
session::session(dpp::cluster& cluster, key_pair_context_type context, dpp::snowflake auth_session_id, mls_failure_callback callback) noexcept
5456
: signing_key_id(auth_session_id), key_pair_context(context), failure_callback(std::move(callback)), creator(cluster)
5557
{
5658
creator.log(dpp::ll_debug, "Creating a new MLS session");
5759
}
5860

5961
session::~session() noexcept = default;
6062

61-
void session::init(protocol_version version, uint64_t group_id, std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept {
63+
void session::init(protocol_version version, dpp::snowflake group_id, dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept {
6264
reset();
6365

6466
bot_user_id = self_user_id;
6567

66-
creator.log(dpp::ll_debug, "Initializing MLS session with protocol version " + std::to_string(version) + " and group ID " + std::to_string(group_id));
68+
creator.log(dpp::ll_debug, "Initializing MLS session with protocol version " + std::to_string(version) + " and group ID " + group_id.str());
6769
session_protocol_version = version;
6870
session_group_id = std::move(big_endian_bytes_from(group_id).as_vec());
6971

@@ -123,7 +125,7 @@ catch (const std::exception& e) {
123125
return;
124126
}
125127

126-
std::optional<std::vector<uint8_t>> session::process_proposals(std::vector<uint8_t> proposals, std::set<std::string> const& recognised_user_ids) noexcept
128+
std::optional<std::vector<uint8_t>> session::process_proposals(std::vector<uint8_t> proposals, std::set<dpp::snowflake> const& recognised_user_ids) noexcept
127129
try {
128130
if (!pending_group_state && !current_state) {
129131
creator.log(dpp::ll_debug, "Cannot process proposals without any pending or established MLS group state");
@@ -183,9 +185,7 @@ try {
183185
for (const auto& proposal_message : messages) {
184186
auto validated_content = state_with_proposals->unwrap(proposal_message);
185187

186-
if (!validate_proposal_message(validated_content.authenticated_content(),
187-
*state_with_proposals,
188-
recognised_user_ids)) {
188+
if (!validate_proposal_message(validated_content.authenticated_content(), *state_with_proposals, recognised_user_ids)) {
189189
return std::nullopt;
190190
}
191191

@@ -238,9 +238,9 @@ catch (const std::exception& e) {
238238
return std::nullopt;
239239
}
240240

241-
bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set<std::string> const& recognised_user_ids) const
241+
bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set<dpp::snowflake> const& recognised_user_ids) const
242242
{
243-
std::string uid = user_credential_to_string(cred, session_protocol_version);
243+
dpp::snowflake uid(user_credential_to_string(cred, session_protocol_version));
244244
if (uid.empty()) {
245245
creator.log(dpp::ll_warning, "Attempted to verify credential of unexpected type");
246246
return false;
@@ -254,7 +254,7 @@ bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set<st
254254
return true;
255255
}
256256

257-
bool session::validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<std::string> const& recognised_user_ids) const {
257+
bool session::validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<dpp::snowflake> const& recognised_user_ids) const {
258258
if (message.wire_format != ::mlspp::WireFormat::mls_public_message) {
259259
creator.log(dpp::ll_warning, "MLS proposal message must be PublicMessage");
260260
TRACK_MLS_ERROR("Invalid proposal wire format");
@@ -357,7 +357,7 @@ catch (const std::exception& e) {
357357
return failed_t{};
358358
}
359359

360-
std::optional<roster_map> session::process_welcome(std::vector<uint8_t> welcome, std::set<std::string> const& recognised_user_ids) noexcept
360+
std::optional<roster_map> session::process_welcome(std::vector<uint8_t> welcome, std::set<dpp::snowflake> const& recognised_user_ids) noexcept
361361
try {
362362
if (!has_cryptographic_state_for_welcome()) {
363363
creator.log(dpp::ll_warning, "Missing local crypto state necessary to process MLS welcome");
@@ -461,7 +461,7 @@ bool session::has_cryptographic_state_for_welcome() const noexcept
461461
return join_key_package && join_init_private_key && signature_private_key && hpke_private_key;
462462
}
463463

464-
bool session::verify_welcome_state(::mlspp::State const& state, std::set<std::string> const& recognised_user_ids) const
464+
bool session::verify_welcome_state(::mlspp::State const& state, std::set<dpp::snowflake> const& recognised_user_ids) const
465465
{
466466
if (!mls_external_sender) {
467467
creator.log(dpp::ll_warning, "Cannot verify MLS welcome without an external sender");
@@ -502,13 +502,13 @@ bool session::verify_welcome_state(::mlspp::State const& state, std::set<std::st
502502
return true;
503503
}
504504

505-
void session::init_leaf_node(std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept
505+
void session::init_leaf_node(dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept
506506
try {
507507
auto ciphersuite = ciphersuite_for_protocol_version(session_protocol_version);
508508

509509
if (!transient_key) {
510510
if (!signing_key_id.empty()) {
511-
transient_key = get_persisted_key_pair(creator, key_pair_context, signing_key_id, session_protocol_version);
511+
transient_key = get_persisted_key_pair(creator, key_pair_context, signing_key_id.str(), session_protocol_version);
512512
if (!transient_key) {
513513
creator.log(dpp::ll_warning, "Did not receive MLS signature private key from get_persisted_key_pair; aborting");
514514
return;
@@ -522,7 +522,7 @@ try {
522522

523523
signature_private_key = transient_key;
524524

525-
auto self_credential = create_user_credential(self_user_id, session_protocol_version);
525+
auto self_credential = create_user_credential(self_user_id.str(), session_protocol_version);
526526
hpke_private_key = std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate(ciphersuite));
527527
self_leaf_node = std::make_unique<::mlspp::LeafNode>(
528528
ciphersuite, hpke_private_key->public_key, signature_private_key->public_key, std::move(self_credential),
@@ -608,7 +608,7 @@ catch (const std::exception& e) {
608608
return {};
609609
}
610610

611-
std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(std::string const& user_id) const noexcept
611+
std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(dpp::snowflake user_id) const noexcept
612612
{
613613
if (!current_state) {
614614
creator.log(dpp::ll_warning, "Cannot get key ratchet without an established MLS group");
@@ -617,7 +617,7 @@ std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(std::string cons
617617

618618
// change the string user ID to a little endian 64 bit user ID
619619
// TODO: Make this use dpp::snowflake
620-
auto u64_user_id = strtoull(user_id.c_str(), nullptr, 10);
620+
uint64_t u64_user_id = user_id;
621621
auto user_id_bytes = ::mlspp::bytes_ns::bytes(sizeof(u64_user_id));
622622
memcpy(user_id_bytes.data(), &u64_user_id, sizeof(u64_user_id));
623623

@@ -629,14 +629,14 @@ std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(std::string cons
629629
return std::make_unique<mls_key_ratchet>(creator, current_state->cipher_suite(), std::move(secret));
630630
}
631631

632-
void session::get_pairwise_fingerprint(uint16_t version, std::string const& user_id, pairwise_fingerprint_callback callback) const noexcept
632+
void session::get_pairwise_fingerprint(uint16_t version, dpp::snowflake user_id, pairwise_fingerprint_callback callback) const noexcept
633633
try {
634634
if (!current_state || !signature_private_key) {
635635
throw std::invalid_argument("No established MLS group");
636636
}
637637

638-
uint64_t remote_user_id = strtoull(user_id.c_str(), nullptr, 10);
639-
uint64_t self_user_id = strtoull(bot_user_id.c_str(), nullptr, 10);
638+
uint64_t remote_user_id = user_id;
639+
uint64_t self_user_id = bot_user_id;
640640

641641
auto it = roster.find(remote_user_id);
642642
if (it == roster.end()) {
@@ -687,16 +687,7 @@ try {
687687

688688
std::vector<uint8_t> out(hash_len);
689689

690-
int ret = EVP_PBE_scrypt((const char*)data.data(),
691-
data.size(),
692-
salt,
693-
sizeof(salt),
694-
N,
695-
r,
696-
p,
697-
max_mem,
698-
out.data(),
699-
out.size());
690+
int ret = EVP_PBE_scrypt((const char*)data.data(), data.size(), salt, sizeof(salt), N, r, p, max_mem, out.data(), out.size());
700691

701692
if (ret == 1) {
702693
callback(out);

src/dpp/dave/session.h

+14-12
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
#include <vector>
3434
#include <map>
3535
#include <set>
36+
#include <dpp/export.h>
37+
#include <dpp/snowflake.h>
3638
#include "persisted_key_pair.h"
3739
#include "key_ratchet.h"
3840
#include "version.h"
@@ -73,7 +75,7 @@ class session { // NOLINT
7375
* @param auth_session_id auth session id (set to empty string to use a transient key pair)
7476
* @param callback callback for failure
7577
*/
76-
session(dpp::cluster& cluster, key_pair_context_type context, const std::string& auth_session_id, mls_failure_callback callback) noexcept;
78+
session(dpp::cluster& cluster, key_pair_context_type context, dpp::snowflake auth_session_id, mls_failure_callback callback) noexcept;
7779

7880
/**
7981
* @brief Destructor
@@ -90,7 +92,7 @@ class session { // NOLINT
9092
* @param self_user_id bot's user id
9193
* @param transient_key transient private key
9294
*/
93-
void init(protocol_version version, uint64_t group_id, std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;
95+
void init(protocol_version version, dpp::snowflake group_id, dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;
9496

9597
/**
9698
* @brief Reset the session to defaults
@@ -129,7 +131,7 @@ class session { // NOLINT
129131
* @param recognised_user_ids list of recognised user IDs
130132
* @return optional vector to send in reply as commit welcome
131133
*/
132-
std::optional<std::vector<uint8_t>> process_proposals(std::vector<uint8_t> proposals, std::set<std::string> const& recognised_user_ids) noexcept;
134+
std::optional<std::vector<uint8_t>> process_proposals(std::vector<uint8_t> proposals, std::set<dpp::snowflake> const& recognised_user_ids) noexcept;
133135

134136
/**
135137
* @brief Process commit message from discord websocket
@@ -144,7 +146,7 @@ class session { // NOLINT
144146
* @param recognised_user_ids Recognised user ID list
145147
* @return roster list of people in the vc
146148
*/
147-
std::optional<roster_map> process_welcome(std::vector<uint8_t> welcome, std::set<std::string> const& recognised_user_ids) noexcept;
149+
std::optional<roster_map> process_welcome(std::vector<uint8_t> welcome, std::set<dpp::snowflake> const& recognised_user_ids) noexcept;
148150

149151
/**
150152
* @brief Get the bot user's key package for sending to websocket
@@ -157,7 +159,7 @@ class session { // NOLINT
157159
* @param user_id User id to get ratchet for
158160
* @return The user's key ratchet for use in an encryptor or decryptor
159161
*/
160-
[[nodiscard]] std::unique_ptr<key_ratchet_interface> get_key_ratchet(std::string const& user_id) const noexcept;
162+
[[nodiscard]] std::unique_ptr<key_ratchet_interface> get_key_ratchet(dpp::snowflake user_id) const noexcept;
161163

162164
/**
163165
* @brief callback for completion of pairwise fingerprint
@@ -172,15 +174,15 @@ class session { // NOLINT
172174
* @param user_id User ID to get fingerprint for
173175
* @param callback Callback for completion
174176
*/
175-
void get_pairwise_fingerprint(uint16_t version, std::string const& user_id, pairwise_fingerprint_callback callback) const noexcept;
177+
void get_pairwise_fingerprint(uint16_t version, dpp::snowflake user_id, pairwise_fingerprint_callback callback) const noexcept;
176178

177179
private:
178180
/**
179181
* @brief Initialise leaf node
180182
* @param self_user_id Bot user id
181183
* @param transient_key Transient key
182184
*/
183-
void init_leaf_node(std::string const& self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;
185+
void init_leaf_node(dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept;
184186

185187
/**
186188
* @brief Reset join key
@@ -204,7 +206,7 @@ class session { // NOLINT
204206
* @param recognised_user_ids list of recognised user IDs
205207
* @return
206208
*/
207-
[[nodiscard]] bool is_recognized_user_id(const ::mlspp::Credential& cred, std::set<std::string> const& recognised_user_ids) const;
209+
[[nodiscard]] bool is_recognized_user_id(const ::mlspp::Credential& cred, std::set<dpp::snowflake> const& recognised_user_ids) const;
208210

209211
/**
210212
* @brief Validate proposals message
@@ -213,15 +215,15 @@ class session { // NOLINT
213215
* @param recognised_user_ids recognised list of user IDs
214216
* @return true if validated
215217
*/
216-
[[nodiscard]] bool validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<std::string> const& recognised_user_ids) const;
218+
[[nodiscard]] bool validate_proposal_message(::mlspp::AuthenticatedContent const& message, ::mlspp::State const& target_state, std::set<dpp::snowflake> const& recognised_user_ids) const;
217219

218220
/**
219221
* @brief Verify that welcome state is valid
220222
* @param state current state
221223
* @param recognised_user_ids list of recognised user IDs
222224
* @return
223225
*/
224-
[[nodiscard]] bool verify_welcome_state(::mlspp::State const& state, std::set<std::string> const& recognised_user_ids) const;
226+
[[nodiscard]] bool verify_welcome_state(::mlspp::State const& state, std::set<dpp::snowflake> const& recognised_user_ids) const;
225227

226228
/**
227229
* @brief Check if can process a commit now
@@ -260,12 +262,12 @@ class session { // NOLINT
260262
/**
261263
* @brief Signing key id
262264
*/
263-
std::string signing_key_id;
265+
dpp::snowflake signing_key_id;
264266

265267
/**
266268
* @brief The bot's user snowflake ID
267269
*/
268-
std::string bot_user_id;
270+
dpp::snowflake bot_user_id;
269271

270272
/**
271273
* @brief The bot's key pair context

src/dpp/voice/enabled/handle_frame.cpp

+13-15
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,15 @@ void discord_voice_client::update_ratchets(bool force) {
5050
*/
5151
log(ll_debug, "Updating MLS ratchets for " + std::to_string(dave_mls_user_list.size() + 1) + " user(s)");
5252
for (const auto& user : dave_mls_user_list) {
53-
dpp::snowflake u{user};
54-
if (u == creator->me.id) {
53+
if (user == creator->me.id) {
5554
continue;
5655
}
5756
decryptor_list::iterator decryptor;
5857
/* New user join/old user leave - insert new ratchets if they don't exist */
59-
decryptor = mls_state->decryptors.find(u);
58+
decryptor = mls_state->decryptors.find(user.str());
6059
if (decryptor == mls_state->decryptors.end()) {
61-
log(ll_debug, "Inserting decryptor key ratchet for NEW user: " + user + ", protocol version: " + std::to_string(mls_state->dave_session->get_protocol_version()));
62-
auto [iter, inserted] = mls_state->decryptors.emplace(u, std::make_unique<dpp::dave::decryptor>(*creator));
60+
log(ll_debug, "Inserting decryptor key ratchet for NEW user: " + user.str() + ", protocol version: " + std::to_string(mls_state->dave_session->get_protocol_version()));
61+
auto [iter, inserted] = mls_state->decryptors.emplace(user.str(), std::make_unique<dpp::dave::decryptor>(*creator));
6362
decryptor = iter;
6463
}
6564
decryptor->second->transition_to_key_ratchet(mls_state->dave_session->get_key_ratchet(user), RATCHET_EXPIRY);
@@ -72,7 +71,7 @@ void discord_voice_client::update_ratchets(bool force) {
7271
if (mls_state->encryptor) {
7372
/* Updating key rachet should always be done on execute transition. Generally after group member add/remove. */
7473
log(ll_debug, "Setting key ratchet for sending audio...");
75-
mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id.str()));
74+
mls_state->encryptor->set_key_ratchet(mls_state->dave_session->get_key_ratchet(creator->me.id));
7675
}
7776

7877
/**
@@ -146,7 +145,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
146145
log(ll_debug, "voice_client_dave_mls_welcome with transition id " + std::to_string(this->mls_state->transition_id));
147146

148147
/* We should always recognize our own selves, but do we? */
149-
dave_mls_user_list.insert(this->creator->me.id.str());
148+
dave_mls_user_list.insert(this->creator->me.id);
150149

151150
auto r = mls_state->dave_session->process_welcome(dave_header.get_data(), dave_mls_user_list);
152151

@@ -222,7 +221,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
222221

223222
/* Remove this user from pending remove list if exist */
224223
for (const auto &user : joining_dave_users) {
225-
dave_mls_pending_remove_list.erase(user);
224+
dave_mls_pending_remove_list.erase(dpp::snowflake(user));
226225
}
227226

228227
log(ll_debug, "New of clients in voice channel: " + std::to_string(joining_dave_users.size()) + " total is " + std::to_string(dave_mls_user_list.size()));
@@ -298,7 +297,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
298297
}
299298

300299
/* Mark this user for remove on immediate upgrade */
301-
dave_mls_pending_remove_list.insert(u_id.str());
300+
dave_mls_pending_remove_list.insert(u_id);
302301

303302
if (!creator->on_voice_client_disconnect.empty()) {
304303
voice_client_disconnect_t vcd(nullptr, data);
@@ -575,12 +574,12 @@ void discord_voice_client::reinit_dave_mls_group() {
575574
if (mls_state->dave_session == nullptr) {
576575
mls_state->dave_session = std::make_unique<dave::mls::session>(
577576
*creator,
578-
nullptr, "", [this](std::string const &s1, std::string const &s2) {
577+
nullptr, snowflake(), [this](std::string const &s1, std::string const &s2) {
579578
log(ll_debug, "DAVE: " + s1 + ", " + s2);
580579
});
581580
}
582581

583-
mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id.str(), mls_state->mls_key);
582+
mls_state->dave_session->init(dave::max_protocol_version(), channel_id, creator->me.id, mls_state->mls_key);
584583

585584
auto key_response = mls_state->dave_session->get_marshalled_key_package();
586585
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
@@ -630,12 +629,11 @@ void discord_voice_client::process_mls_group_rosters(const dave::roster_map &rma
630629
}
631630

632631
dpp::snowflake u_id(k);
633-
auto u_id_str = u_id.str();
634632

635-
log(ll_debug, "Removed user from MLS Group: " + u_id_str);
633+
log(ll_debug, "Removed user from MLS Group: " + u_id.str());
636634

637-
dave_mls_user_list.erase(u_id_str);
638-
dave_mls_pending_remove_list.erase(u_id_str);
635+
dave_mls_user_list.erase(u_id);
636+
dave_mls_pending_remove_list.erase(u_id);
639637

640638
/* Remove this user's key ratchet */
641639
mls_state->decryptors.erase(u_id);

0 commit comments

Comments
 (0)