@@ -77,6 +77,7 @@ static std::string external_ip;
77
77
struct dave_transient_key {
78
78
std::shared_ptr<::mlspp::SignaturePrivateKey> mls_key;
79
79
std::vector<uint8_t > cached_commit;
80
+ uint64_t transition_id{0 };
80
81
};
81
82
82
83
struct dave_encryptors {
@@ -497,10 +498,20 @@ int discord_voice_client::udp_recv(char* data, size_t max_length)
497
498
return (int ) recv (this ->fd , data, (int )max_length, 0 );
498
499
}
499
500
501
+ uint16_t dave_binary_header_t::get_welcome_transition_id () const {
502
+ uint16_t transition{0 };
503
+ std::memcpy (&transition, package, sizeof (uint16_t ));
504
+ return ntohs (transition);
505
+ }
506
+
500
507
std::vector<uint8_t > dave_binary_header_t::get_data (size_t length) const {
501
508
return std::vector<uint8_t >(package, package + length - sizeof (dave_binary_header_t ));
502
509
}
503
510
511
+ std::vector<uint8_t > dave_binary_header_t::get_welcome_data (size_t length) const {
512
+ return std::vector<uint8_t >(package + sizeof (uint16_t ), package + length - sizeof (dave_binary_header_t ));
513
+ }
514
+
504
515
bool discord_voice_client::handle_frame (const std::string &data, ws_opcode opcode)
505
516
{
506
517
json j;
@@ -516,31 +527,14 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
516
527
case voice_client_dave_mls_external_sender: {
517
528
log (ll_debug, " voice_client_dave_mls_external_sender" );
518
529
519
-
520
- dave_session = std::make_unique<dave::mls::Session>(
521
- nullptr , sessionid, [this ](std::string const & s1, std::string const & s2) {
522
- log (ll_debug, " Dave session constructor callback: " + s1 + " , " + s2);
523
- });
524
-
525
530
dave_session->SetExternalSender (dave_header->get_data (data.length ()));
526
531
527
- transient_key = std::make_unique<dave_transient_key>();
528
- dave_session->Init (dave::MaxSupportedProtocolVersion (), channel_id, creator->me .id .str (), transient_key->mls_key );
529
-
530
532
encryptors = std::make_unique<dave_encryptors>();
531
533
encryptors->encryptor = std::make_unique<dave::Encryptor>();
532
534
/* *
533
535
* TODO: There should be one of these per user but only one of the encryptor, above
534
536
*/
535
537
encryptors->decryptor = std::make_unique<dave::Decryptor>();
536
-
537
- auto epoch = dave_session->GetLastEpochAuthenticator ();
538
-
539
- auto key_response = dave_session->GetMarshalledKeyPackage ();
540
- key_response.insert (key_response.begin (), voice_client_dave_mls_key_package);
541
- this ->write (std::string_view (reinterpret_cast <const char *>(key_response.data ()), key_response.size ()), OP_BINARY);
542
-
543
- encryptors->encryptor ->SetKeyRatchet (dave_session->GetKeyRatchet (creator->me .id .str ()));
544
538
}
545
539
break ;
546
540
case voice_client_dave_mls_proposals: {
@@ -562,20 +556,30 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
562
556
log (ll_debug, " Setting decryptor key ratchet for user: " + user + " , protocol version: " + std::to_string (dave_session->GetProtocolVersion ()));
563
557
encryptors->decryptor ->TransitionToKeyRatchet (dave_session->GetKeyRatchet (user));
564
558
}
559
+ encryptors->encryptor ->SetKeyRatchet (dave_session->GetKeyRatchet (creator->me .id .str ()));
560
+
561
+ /* *
562
+ * https://www.ietf.org/archive/id/draft-ietf-mls-protocol-14.html#name-epoch-authenticators
563
+ * 9.7. Epoch Authenticators
564
+ * The main MLS key schedule provides a per-epoch epoch_authenticator. If one member of the group is being impersonated by an active attacker,
565
+ * the epoch_authenticator computed by their client will differ from those computed by the other group members.
566
+ */
567
+ auto epoch = dave_session->GetLastEpochAuthenticator ();
568
+ log (ll_debug, " DAVE epoch authenticator: " + dpp::base64_encode ((unsigned char const *)epoch.data (), epoch.size ()));
565
569
}
566
570
break ;
567
571
case voice_client_dave_mls_welcome: {
568
- log (ll_debug, " voice_client_dave_mls_welcome" );
569
- auto user_list_with_me = dave_mls_user_list;
570
- user_list_with_me.emplace (creator->me .id .str ());
571
- for (const auto & user : user_list_with_me) {
572
- std::cout << " USER: " << user << " \n " ;
572
+ this ->transient_key ->transition_id = dave_header->get_welcome_transition_id ();
573
+ log (ll_debug, " voice_client_dave_mls_welcome with transition id " + std::to_string (this ->transient_key ->transition_id ));
574
+ auto r = dave_session->ProcessWelcome (dave_header->get_welcome_data (data.length ()), dave_mls_user_list);
575
+ if (r.has_value ()) {
576
+ for (const auto & user_key_pair : r.value ()) {
577
+ std::cout << " WEL: " << user_key_pair.first << " \n " ;
578
+ }
579
+ encryptors->encryptor ->SetKeyRatchet (dave_session->GetKeyRatchet (creator->me .id .str ()));
580
+ } else {
581
+ std::cout << " Welcome has no value\n " ;
573
582
}
574
- auto r = dave_session->ProcessWelcome (dave_header->get_data (data.length ()), user_list_with_me);
575
- }
576
- break ;
577
- case voice_client_dave_mls_invalid_commit_welcome: {
578
- log (ll_debug, " voice_client_dave_mls_invalid_commit_welcome" );
579
583
}
580
584
break ;
581
585
default :
@@ -632,6 +636,43 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
632
636
log (ll_debug, " Number of clients in voice channel: " + std::to_string (dave_mls_user_list.size ()));
633
637
}
634
638
break ;
639
+ case voice_client_dave_mls_invalid_commit_welcome: {
640
+ this ->transient_key ->transition_id = j[" d" ][" transition_id" ];
641
+ log (ll_debug, " voice_client_dave_mls_invalid_commit_welcome transition id " + std::to_string (this ->transient_key ->transition_id ));
642
+ }
643
+ break ;
644
+ case voice_client_dave_execute_transition: {
645
+ log (ll_debug, " voice_client_dave_execute_transition" );
646
+ this ->transient_key ->transition_id = j[" d" ][" transition_id" ];
647
+ json obj = {
648
+ { " op" , voice_client_dave_transition_ready },
649
+ {
650
+ " d" ,
651
+ {
652
+ { " transition_id" , this ->transient_key ->transition_id },
653
+ }
654
+ }
655
+ };
656
+ this ->write (obj.dump (-1 , ' ' , false , json::error_handler_t ::replace), OP_TEXT);
657
+ }
658
+ break ;
659
+ /* "The protocol only uses this opcode to indicate when a downgrade to protocol version 0 is upcoming." */
660
+ case voice_client_dave_prepare_transition: {
661
+ uint64_t transition_id = j[" d" ][" transition_id" ];
662
+ uint64_t protocol_version = j[" d" ][" protocol_version" ];
663
+ log (ll_debug, " voice_client_dave_prepare_transition version=" + std::to_string (protocol_version) + " for transition " + std::to_string (transition_id));
664
+ }
665
+ break ;
666
+ case voice_client_dave_prepare_epoch: {
667
+ uint64_t protocol_version = j[" d" ][" protocol_version" ];
668
+ uint64_t epoch = j[" d" ][" epoch" ];
669
+ log (ll_debug, " voice_client_dave_prepare_epoch version=" + std::to_string (protocol_version) + " for epoch " + std::to_string (epoch));
670
+ if (epoch == 1 ) {
671
+ dave_session->Reset ();
672
+ dave_session->Init (dave::MaxSupportedProtocolVersion (), channel_id, creator->me .id .str (), transient_key->mls_key );
673
+ }
674
+ }
675
+ break ;
635
676
/* Client Disconnect */
636
677
case voice_opcode_client_disconnect: {
637
678
if (j.find (" d" ) != j.end () && j[" d" ].find (" user_id" ) != j[" d" ].end () && !j[" d" ][" user_id" ].is_null ()) {
@@ -739,6 +780,17 @@ bool discord_voice_client::handle_frame(const std::string &data, ws_opcode opcod
739
780
dave_version = dave_version_none;
740
781
send_silence (20 );
741
782
}
783
+
784
+ dave_session = std::make_unique<dave::mls::Session>(
785
+ nullptr , " " /* sessionid */ , [this ](std::string const & s1, std::string const & s2) {
786
+ log (ll_debug, " Dave session constructor callback: " + s1 + " , " + s2);
787
+ });
788
+ transient_key = std::make_unique<dave_transient_key>();
789
+ dave_session->Init (dave::MaxSupportedProtocolVersion (), channel_id, creator->me .id .str (), transient_key->mls_key );
790
+ auto key_response = dave_session->GetMarshalledKeyPackage ();
791
+ key_response.insert (key_response.begin (), voice_client_dave_mls_key_package);
792
+ this ->write (std::string_view (reinterpret_cast <const char *>(key_response.data ()), key_response.size ()), OP_BINARY);
793
+
742
794
} else {
743
795
/* This is needed to start voice receiving and make sure that the start of sending isn't cut off */
744
796
send_silence (20 );
0 commit comments