Skip to content

Commit 49d4d0f

Browse files
now joins MLS group in both flows
1 parent 6188fc5 commit 49d4d0f

5 files changed

+104
-59
lines changed

include/dpp/discordvoiceclient.h

+24-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,30 @@ struct dave_binary_header_t {
175175
*/
176176
uint8_t package[];
177177

178-
std::vector<uint8_t> get_data(size_t length) const;
178+
/**
179+
* Get the data package from the packed binary frame, as a vector of uint8_t
180+
* for use in the libdave functions
181+
*
182+
* @param length Length of the data, use the websocket frame size here
183+
* @return data blob
184+
*/
185+
[[nodiscard]] std::vector<uint8_t> get_data(size_t length) const;
186+
187+
/**
188+
* Get the data package from the packed binary frame for ProcessWelcome,
189+
* as a vector of uint8_t for use in the libdave functions.
190+
*
191+
* @param length Length of the data, use the websocket frame size here
192+
* @return data blob
193+
*/
194+
[[nodiscard]] std::vector<uint8_t> get_welcome_data(size_t length) const;
195+
196+
/**
197+
* Get transition ID for ProcessWelcome
198+
*
199+
* @return Transition ID
200+
*/
201+
[[nodiscard]] uint16_t get_welcome_transition_id() const;
179202
};
180203
#pragma pack(pop)
181204

src/dpp/dave/persisted_key_pair.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ static std::shared_ptr<::mlspp::SignaturePrivateKey> GetPersistedKeyPair(
3838
std::string id = MakeKeyID(sessionID, suite);
3939

4040
if (auto it = map.find(id); it != map.end()) {
41-
std::cout << "5\n";
4241
return it->second;
4342
}
4443

src/dpp/dave/persisted_key_pair_generic.cpp

-9
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,6 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPair
6868
std::string curstr;
6969
std::filesystem::path dir = GetKeyStorageDirectory();
7070

71-
std::cout << "KSD: " << dir << "\n";
72-
7371
if (dir.empty()) {
7472
DISCORD_LOG(LS_ERROR) << "Failed to determine key storage directory in GetPersistedKeyPair";
7573
return nullptr;
@@ -85,8 +83,6 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPair
8583

8684
std::filesystem::path file = dir / (id + ".key");
8785

88-
std::cout << "FILE: " << file << "\n";
89-
9086
if (std::filesystem::exists(file)) {
9187
std::ifstream ifs(file, std::ios_base::in | std::ios_base::binary);
9288
if (!ifs) {
@@ -102,8 +98,6 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPair
10298
return nullptr;
10399
}
104100

105-
std::cout << "CURSTR: " << curstr << "\n";
106-
107101
try {
108102
ret = ::mlspp::SignaturePrivateKey::from_jwk(suite, curstr);
109103
}
@@ -113,16 +107,13 @@ std::shared_ptr<::mlspp::SignaturePrivateKey> GetGenericPersistedKeyPair(KeyPair
113107
}
114108
}
115109
else {
116-
std::cout << "GEN NEW\n";
117110
ret = ::mlspp::SignaturePrivateKey::generate(suite);
118111

119112
std::string newstr = ret.to_jwk(suite);
120113

121114
std::filesystem::path tmpfile = file;
122115
tmpfile += ".tmp";
123116

124-
std::cout << "TMPFILE " << tmpfile << "\n";
125-
126117
#ifdef _WIN32
127118
int fd = _wopen(tmpfile.c_str(), _O_WRONLY | _O_CREAT | _O_TRUNC, _S_IREAD | _S_IWRITE);
128119
#else

src/dpp/dave/session.cpp

+1-21
Original file line numberDiff line numberDiff line change
@@ -332,57 +332,37 @@ try {
332332
DISCORD_LOG(LS_INFO) << "Processing commit";
333333
DISCORD_LOG(LS_INFO) << "Commit: " << ::mlspp::bytes_ns::bytes(commit);
334334

335-
std::cout << "1\n";
336-
337335
auto commitMessage = ::mlspp::tls::get<::mlspp::MLSMessage>(commit);
338336

339-
std::cout << "2\n";
340-
341337
if (!CanProcessCommit(commitMessage)) {
342338
DISCORD_LOG(LS_ERROR) << "ProcessCommit called with unprocessable MLS commit";
343339
return ignored_t{};
344340
}
345341

346-
std::cout << "3\n";
347-
348342
// in case we're the sender of this commit
349343
// we need to pull the cached state from our outbound cache
350344
std::optional<::mlspp::State> optionalCachedState = std::nullopt;
351345
if (outboundCachedGroupState_) {
352346
optionalCachedState = *(outboundCachedGroupState_.get());
353347
}
354348

355-
std::cout << "4\n";
356-
357349
auto newState = stateWithProposals_->handle(commitMessage, optionalCachedState);
358350

359-
std::cout << "5\n";
360-
361351
if (!newState) {
362352
DISCORD_LOG(LS_ERROR) << "MLS commit handling did not produce a new state";
363353
return failed_t{};
364354
}
365355

366-
std::cout << "6\n";
367-
368356
DISCORD_LOG(LS_INFO) << "Successfully processed MLS commit, updating state; our leaf index is "
369357
<< newState->index().val << "; current epoch is " << newState->epoch();
370358

371-
std::cout << "7\n";
372-
373359
RosterMap ret = ReplaceState(std::make_unique<::mlspp::State>(std::move(*newState)));
374360

375-
std::cout << "8\n";
376-
377361
// reset the outbound cached group since we handled the commit for this epoch
378362
outboundCachedGroupState_.reset();
379363

380-
std::cout << "9\n";
381-
382364
ClearPendingState();
383365

384-
std::cout << "10\n";
385-
386366
return ret;
387367
}
388368
catch (const std::exception& e) {
@@ -396,7 +376,7 @@ std::optional<RosterMap> Session::ProcessWelcome(
396376
std::set<std::string> const& recognizedUserIDs) noexcept
397377
try {
398378
if (!HasCryptographicStateForWelcome()) {
399-
DISCORD_LOG(LS_ERROR) << "Missing local cyrpto state necessary to process MLS welcome";
379+
DISCORD_LOG(LS_ERROR) << "Missing local crypto state necessary to process MLS welcome";
400380
return std::nullopt;
401381
}
402382

src/dpp/discordvoiceclient.cpp

+79-27
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ static std::string external_ip;
7777
struct dave_transient_key {
7878
std::shared_ptr<::mlspp::SignaturePrivateKey> mls_key;
7979
std::vector<uint8_t> cached_commit;
80+
uint64_t transition_id{0};
8081
};
8182

8283
struct dave_encryptors {
@@ -497,10 +498,20 @@ int discord_voice_client::udp_recv(char* data, size_t max_length)
497498
return (int) recv(this->fd, data, (int)max_length, 0);
498499
}
499500

501+
uint16_t dave_binary_header_t::get_welcome_transition_id() const {
502+
uint16_t transition{0};
503+
std::memcpy(&transition, package, sizeof(uint16_t));
504+
return ntohs(transition);
505+
}
506+
500507
std::vector<uint8_t> dave_binary_header_t::get_data(size_t length) const {
501508
return std::vector<uint8_t>(package, package + length - sizeof(dave_binary_header_t));
502509
}
503510

511+
std::vector<uint8_t> dave_binary_header_t::get_welcome_data(size_t length) const {
512+
return std::vector<uint8_t>(package + sizeof(uint16_t), package + length - sizeof(dave_binary_header_t));
513+
}
514+
504515
bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcode)
505516
{
506517
json j;
@@ -516,31 +527,14 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
516527
case voice_client_dave_mls_external_sender: {
517528
log(ll_debug, "voice_client_dave_mls_external_sender");
518529

519-
520-
dave_session = std::make_unique<dave::mls::Session>(
521-
nullptr, sessionid, [this](std::string const& s1, std::string const& s2) {
522-
log(ll_debug, "Dave session constructor callback: " + s1 + ", " + s2);
523-
});
524-
525530
dave_session->SetExternalSender(dave_header->get_data(data.length()));
526531

527-
transient_key = std::make_unique<dave_transient_key>();
528-
dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), transient_key->mls_key);
529-
530532
encryptors = std::make_unique<dave_encryptors>();
531533
encryptors->encryptor = std::make_unique<dave::Encryptor>();
532534
/**
533535
* TODO: There should be one of these per user but only one of the encryptor, above
534536
*/
535537
encryptors->decryptor = std::make_unique<dave::Decryptor>();
536-
537-
auto epoch = dave_session->GetLastEpochAuthenticator();
538-
539-
auto key_response = dave_session->GetMarshalledKeyPackage();
540-
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
541-
this->write(std::string_view(reinterpret_cast<const char*>(key_response.data()), key_response.size()), OP_BINARY);
542-
543-
encryptors->encryptor->SetKeyRatchet(dave_session->GetKeyRatchet(creator->me.id.str()));
544538
}
545539
break;
546540
case voice_client_dave_mls_proposals: {
@@ -562,20 +556,30 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
562556
log(ll_debug, "Setting decryptor key ratchet for user: " + user + ", protocol version: " + std::to_string(dave_session->GetProtocolVersion()));
563557
encryptors->decryptor->TransitionToKeyRatchet(dave_session->GetKeyRatchet(user));
564558
}
559+
encryptors->encryptor->SetKeyRatchet(dave_session->GetKeyRatchet(creator->me.id.str()));
560+
561+
/**
562+
* https://www.ietf.org/archive/id/draft-ietf-mls-protocol-14.html#name-epoch-authenticators
563+
* 9.7. Epoch Authenticators
564+
* The main MLS key schedule provides a per-epoch epoch_authenticator. If one member of the group is being impersonated by an active attacker,
565+
* the epoch_authenticator computed by their client will differ from those computed by the other group members.
566+
*/
567+
auto epoch = dave_session->GetLastEpochAuthenticator();
568+
log(ll_debug, "DAVE epoch authenticator: " + dpp::base64_encode((unsigned char const*)epoch.data(), epoch.size()));
565569
}
566570
break;
567571
case voice_client_dave_mls_welcome: {
568-
log(ll_debug, "voice_client_dave_mls_welcome");
569-
auto user_list_with_me = dave_mls_user_list;
570-
user_list_with_me.emplace(creator->me.id.str());
571-
for (const auto& user : user_list_with_me) {
572-
std::cout << "USER: " << user << "\n";
572+
this->transient_key->transition_id = dave_header->get_welcome_transition_id();
573+
log(ll_debug, "voice_client_dave_mls_welcome with transition id " + std::to_string(this->transient_key->transition_id));
574+
auto r = dave_session->ProcessWelcome(dave_header->get_welcome_data(data.length()), dave_mls_user_list);
575+
if (r.has_value()) {
576+
for (const auto& user_key_pair : r.value()) {
577+
std::cout << "WEL: " << user_key_pair.first << "\n";
578+
}
579+
encryptors->encryptor->SetKeyRatchet(dave_session->GetKeyRatchet(creator->me.id.str()));
580+
} else {
581+
std::cout << "Welcome has no value\n";
573582
}
574-
auto r = dave_session->ProcessWelcome(dave_header->get_data(data.length()), user_list_with_me);
575-
}
576-
break;
577-
case voice_client_dave_mls_invalid_commit_welcome: {
578-
log(ll_debug, "voice_client_dave_mls_invalid_commit_welcome");
579583
}
580584
break;
581585
default:
@@ -632,6 +636,43 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
632636
log(ll_debug, "Number of clients in voice channel: " + std::to_string(dave_mls_user_list.size()));
633637
}
634638
break;
639+
case voice_client_dave_mls_invalid_commit_welcome: {
640+
this->transient_key->transition_id = j["d"]["transition_id"];
641+
log(ll_debug, "voice_client_dave_mls_invalid_commit_welcome transition id " + std::to_string(this->transient_key->transition_id));
642+
}
643+
break;
644+
case voice_client_dave_execute_transition: {
645+
log(ll_debug, "voice_client_dave_execute_transition");
646+
this->transient_key->transition_id = j["d"]["transition_id"];
647+
json obj = {
648+
{ "op", voice_client_dave_transition_ready },
649+
{
650+
"d",
651+
{
652+
{ "transition_id", this->transient_key->transition_id },
653+
}
654+
}
655+
};
656+
this->write(obj.dump(-1, ' ', false, json::error_handler_t::replace), OP_TEXT);
657+
}
658+
break;
659+
/* "The protocol only uses this opcode to indicate when a downgrade to protocol version 0 is upcoming." */
660+
case voice_client_dave_prepare_transition: {
661+
uint64_t transition_id = j["d"]["transition_id"];
662+
uint64_t protocol_version = j["d"]["protocol_version"];
663+
log(ll_debug, "voice_client_dave_prepare_transition version=" + std::to_string(protocol_version) + " for transition " + std::to_string(transition_id));
664+
}
665+
break;
666+
case voice_client_dave_prepare_epoch: {
667+
uint64_t protocol_version = j["d"]["protocol_version"];
668+
uint64_t epoch = j["d"]["epoch"];
669+
log(ll_debug, "voice_client_dave_prepare_epoch version=" + std::to_string(protocol_version) + " for epoch " + std::to_string(epoch));
670+
if (epoch == 1) {
671+
dave_session->Reset();
672+
dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), transient_key->mls_key);
673+
}
674+
}
675+
break;
635676
/* Client Disconnect */
636677
case voice_opcode_client_disconnect: {
637678
if (j.find("d") != j.end() && j["d"].find("user_id") != j["d"].end() && !j["d"]["user_id"].is_null()) {
@@ -739,6 +780,17 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
739780
dave_version = dave_version_none;
740781
send_silence(20);
741782
}
783+
784+
dave_session = std::make_unique<dave::mls::Session>(
785+
nullptr, "" /* sessionid */, [this](std::string const& s1, std::string const& s2) {
786+
log(ll_debug, "Dave session constructor callback: " + s1 + ", " + s2);
787+
});
788+
transient_key = std::make_unique<dave_transient_key>();
789+
dave_session->Init(dave::MaxSupportedProtocolVersion(), channel_id, creator->me.id.str(), transient_key->mls_key);
790+
auto key_response = dave_session->GetMarshalledKeyPackage();
791+
key_response.insert(key_response.begin(), voice_client_dave_mls_key_package);
792+
this->write(std::string_view(reinterpret_cast<const char*>(key_response.data()), key_response.size()), OP_BINARY);
793+
742794
} else {
743795
/* This is needed to start voice receiving and make sure that the start of sending isn't cut off */
744796
send_silence(20);

0 commit comments

Comments
 (0)