Skip to content

Commit 5afa97c

Browse files
move dave session into mls state
1 parent 89e412e commit 5afa97c

File tree

2 files changed

+22
-26
lines changed

2 files changed

+22
-26
lines changed

include/dpp/discordvoiceclient.h

+1-6
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ inline constexpr size_t send_audio_raw_max_length = 11520;
7272
inline constexpr size_t secret_key_size = 32;
7373

7474
struct dave_state;
75-
struct dave_encryptors;
7675

7776
/*
7877
* @brief For holding a moving average of the number of current voice users, for applying a smooth gain ramp.
@@ -413,9 +412,7 @@ class DPP_EXPORT discord_voice_client : public websocket_client
413412
*/
414413
OpusRepacketizer* repacketizer;
415414

416-
std::unique_ptr<dave::mls::Session> dave_session{};
417-
418-
std::unique_ptr<dave_state> mls_state{};
415+
std::unique_ptr<dave_state> mls_state;
419416

420417
#else
421418
/**
@@ -429,8 +426,6 @@ class DPP_EXPORT discord_voice_client : public websocket_client
429426
*/
430427
void* repacketizer;
431428

432-
std::unique_ptr<int> dave_session{};
433-
434429
std::unique_ptr<int> mls_state{};
435430
#endif
436431

src/dpp/discordvoiceclient.cpp

+21-20
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ constexpr uint8_t voice_protocol_version = 8;
7575
static std::string external_ip;
7676

7777
struct dave_state {
78+
std::unique_ptr<dave::mls::Session> dave_session{};
7879
std::shared_ptr<::mlspp::SignaturePrivateKey> mls_key;
7980
std::vector<uint8_t> cached_commit;
8081
uint64_t transition_id{0};
@@ -538,14 +539,14 @@ void discord_voice_client::get_user_privacy_code(const dpp::snowflake user, priv
538539
callback("");
539540
return;
540541
}
541-
dave_session->GetPairwiseFingerprint(0x0000, user.str(), [callback](const std::vector<uint8_t>& data) {
542+
mls_state->dave_session->GetPairwiseFingerprint(0x0000, user.str(), [callback](const std::vector<uint8_t>& data) {
542543
std::cout << dpp::utility::debug_dump((uint8_t*)data.data(), data.size());
543544
callback(data.size() == 64 ? generate_displayable_code(data, 45) : "");
544545
});
545546
}
546547

547548
bool discord_voice_client::is_end_to_end_encrypted() const {
548-
return dave_session && mls_state && !mls_state->privacy_code.empty();
549+
return mls_state && !mls_state->privacy_code.empty();
549550
}
550551

551552
bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcode) {
@@ -562,7 +563,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
562563
case voice_client_dave_mls_external_sender: {
563564
log(ll_debug, "voice_client_dave_mls_external_sender");
564565

565-
dave_session->SetExternalSender(dave_header->get_data(data.length()));
566+
mls_state->dave_session->SetExternalSender(dave_header->get_data(data.length()));
566567

567568
mls_state->encryptor = std::make_unique<dave::Encryptor>();
568569
mls_state->decryptors.clear();
@@ -571,7 +572,7 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
571572
case voice_client_dave_mls_proposals: {
572573
log(ll_debug, "voice_client_dave_mls_proposals");
573574

574-
std::optional<std::vector<uint8_t>> response = dave_session->ProcessProposals(dave_header->get_data(data.length()), dave_mls_user_list);
575+
std::optional<std::vector<uint8_t>> response = mls_state->dave_session->ProcessProposals(dave_header->get_data(data.length()), dave_mls_user_list);
575576
if (response.has_value()) {
576577
auto r = response.value();
577578
mls_state->cached_commit = r;
@@ -582,39 +583,39 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
582583
break;
583584
case voice_client_dave_announce_commit_transaction: {
584585
log(ll_debug, "voice_client_dave_announce_commit_transaction");
585-
auto r = dave_session->ProcessCommit(mls_state->cached_commit);
586+
auto r = mls_state->dave_session->ProcessCommit(mls_state->cached_commit);
586587
for (const auto& user : dave_mls_user_list) {
587-
log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(dave_session->GetProtocolVersion()));
588+
log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(mls_state->dave_session->GetProtocolVersion()));
588589
dpp::snowflake u{user};
589590
mls_state->decryptors.emplace(u, std::make_unique<dpp::dave::Decryptor>());
590-
mls_state->decryptors.find(u)->second->TransitionToKeyRatchet(dave_session->GetKeyRatchet(user));
591+
mls_state->decryptors.find(u)->second->TransitionToKeyRatchet(mls_state->dave_session->GetKeyRatchet(user));
591592
}
592-
mls_state->encryptor->SetKeyRatchet(dave_session->GetKeyRatchet(creator->me.id.str()));
593+
mls_state->encryptor->SetKeyRatchet(mls_state->dave_session->GetKeyRatchet(creator->me.id.str()));
593594

594595
/**
595596
* https://www.ietf.org/archive/id/draft-ietf-mls-protocol-14.html#name-epoch-authenticators
596597
* 9.7. Epoch Authenticators
597598
* The main MLS key schedule provides a per-epoch epoch_authenticator. If one member of the group is being impersonated by an active attacker,
598599
* the epoch_authenticator computed by their client will differ from those computed by the other group members.
599600
*/
600-
mls_state->privacy_code = generate_displayable_code(dave_session->GetLastEpochAuthenticator());
601+
mls_state->privacy_code = generate_displayable_code(mls_state->dave_session->GetLastEpochAuthenticator());
601602
log(ll_debug, "E2EE Privacy Code: " + mls_state->privacy_code);
602603
}
603604
break;
604605
case voice_client_dave_mls_welcome: {
605606
this->mls_state->transition_id = dave_header->get_welcome_transition_id();
606607
log(ll_debug, "voice_client_dave_mls_welcome with transition id " + std::to_string(this->mls_state->transition_id));
607-
auto r = dave_session->ProcessWelcome(dave_header->get_welcome_data(data.length()), dave_mls_user_list);
608+
auto r = mls_state->dave_session->ProcessWelcome(dave_header->get_welcome_data(data.length()), dave_mls_user_list);
608609
if (r.has_value()) {
609610
for (const auto& user : dave_mls_user_list) {
610-
log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(dave_session->GetProtocolVersion()));
611+
log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(mls_state->dave_session->GetProtocolVersion()));
611612
dpp::snowflake u{user};
612613
mls_state->decryptors.emplace(u, std::make_unique<dpp::dave::Decryptor>());
613-
mls_state->decryptors.find(u)->second->TransitionToKeyRatchet(dave_session->GetKeyRatchet(user));
614+
mls_state->decryptors.find(u)->second->TransitionToKeyRatchet(mls_state->dave_session->GetKeyRatchet(user));
614615
}
615-
mls_state->encryptor->SetKeyRatchet(dave_session->GetKeyRatchet(creator->me.id.str()));
616+
mls_state->encryptor->SetKeyRatchet(mls_state->dave_session->GetKeyRatchet(creator->me.id.str()));
616617
}
617-
mls_state->privacy_code = generate_displayable_code(dave_session->GetLastEpochAuthenticator());
618+
mls_state->privacy_code = generate_displayable_code(mls_state->dave_session->GetLastEpochAuthenticator());
618619
log(ll_debug, "E2EE Privacy Code: " + mls_state->privacy_code);
619620
}
620621
break;
@@ -704,8 +705,8 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
704705
uint64_t epoch = j["d"]["epoch"];
705706
log(ll_debug, "voice_client_dave_prepare_epoch version=" + std::to_string(protocol_version) + " for epoch " + std::to_string(epoch));
706707
if (epoch == 1) {
707-
dave_session->Reset();
708-
dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), mls_state->mls_key);
708+
mls_state->dave_session->Reset();
709+
mls_state->dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), mls_state->mls_key);
709710
}
710711
}
711712
break;
@@ -817,13 +818,13 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
817818
send_silence(20);
818819
}
819820

820-
dave_session = std::make_unique<dave::mls::Session>(
821+
mls_state = std::make_unique<dave_state>();
822+
mls_state->dave_session = std::make_unique<dave::mls::Session>(
821823
nullptr, "" /* sessionid */, [this](std::string const& s1, std::string const& s2) {
822824
log(ll_debug, "Dave session constructor callback: " + s1 + ", " + s2);
823825
});
824-
mls_state = std::make_unique<dave_state>();
825-
dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), mls_state->mls_key);
826-
auto key_response = dave_session->GetMarshalledKeyPackage();
826+
mls_state->dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), mls_state->mls_key);
827+
auto key_response = mls_state->dave_session->GetMarshalledKeyPackage();
827828
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
828829
this->write(std::string_view(reinterpret_cast<const char*>(key_response.data()), key_response.size()), OP_BINARY);
829830

0 commit comments

Comments
 (0)