29
29
#include < iostream>
30
30
#include < mls/crypto.h>
31
31
#include < mls/messages.h>
32
+ #include < dpp/export.h>
33
+ #include < dpp/snowflake.h>
32
34
#include < mls/state.h>
33
35
#include < dpp/cluster.h>
34
36
#include " mls_key_ratchet.h"
@@ -50,20 +52,20 @@ struct queued_proposal {
50
52
::mlspp::bytes_ns::bytes ref;
51
53
};
52
54
53
- session::session (dpp::cluster& cluster, key_pair_context_type context, const std::string& auth_session_id, mls_failure_callback callback) noexcept
55
+ session::session (dpp::cluster& cluster, key_pair_context_type context, dpp::snowflake auth_session_id, mls_failure_callback callback) noexcept
54
56
: signing_key_id(auth_session_id), key_pair_context(context), failure_callback(std::move(callback)), creator(cluster)
55
57
{
56
58
creator.log (dpp::ll_debug, " Creating a new MLS session" );
57
59
}
58
60
59
61
session::~session () noexcept = default ;
60
62
61
- void session::init (protocol_version version, uint64_t group_id, std::string const & self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept {
63
+ void session::init (protocol_version version, dpp::snowflake group_id, dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept {
62
64
reset ();
63
65
64
66
bot_user_id = self_user_id;
65
67
66
- creator.log (dpp::ll_debug, " Initializing MLS session with protocol version " + std::to_string (version) + " and group ID " + std::to_string ( group_id));
68
+ creator.log (dpp::ll_debug, " Initializing MLS session with protocol version " + std::to_string (version) + " and group ID " + group_id. str ( ));
67
69
session_protocol_version = version;
68
70
session_group_id = std::move (big_endian_bytes_from (group_id).as_vec ());
69
71
@@ -123,7 +125,7 @@ catch (const std::exception& e) {
123
125
return ;
124
126
}
125
127
126
- std::optional<std::vector<uint8_t >> session::process_proposals (std::vector<uint8_t > proposals, std::set<std::string > const & recognised_user_ids) noexcept
128
+ std::optional<std::vector<uint8_t >> session::process_proposals (std::vector<uint8_t > proposals, std::set<dpp::snowflake > const & recognised_user_ids) noexcept
127
129
try {
128
130
if (!pending_group_state && !current_state) {
129
131
creator.log (dpp::ll_debug, " Cannot process proposals without any pending or established MLS group state" );
@@ -183,9 +185,7 @@ try {
183
185
for (const auto & proposal_message : messages) {
184
186
auto validated_content = state_with_proposals->unwrap (proposal_message);
185
187
186
- if (!validate_proposal_message (validated_content.authenticated_content (),
187
- *state_with_proposals,
188
- recognised_user_ids)) {
188
+ if (!validate_proposal_message (validated_content.authenticated_content (), *state_with_proposals, recognised_user_ids)) {
189
189
return std::nullopt;
190
190
}
191
191
@@ -238,9 +238,9 @@ catch (const std::exception& e) {
238
238
return std::nullopt;
239
239
}
240
240
241
- bool session::is_recognized_user_id (const ::mlspp::Credential& cred, std::set<std::string > const & recognised_user_ids) const
241
+ bool session::is_recognized_user_id (const ::mlspp::Credential& cred, std::set<dpp::snowflake > const & recognised_user_ids) const
242
242
{
243
- std::string uid = user_credential_to_string (cred, session_protocol_version);
243
+ dpp::snowflake uid ( user_credential_to_string (cred, session_protocol_version) );
244
244
if (uid.empty ()) {
245
245
creator.log (dpp::ll_warning, " Attempted to verify credential of unexpected type" );
246
246
return false ;
@@ -254,7 +254,7 @@ bool session::is_recognized_user_id(const ::mlspp::Credential& cred, std::set<st
254
254
return true ;
255
255
}
256
256
257
- bool session::validate_proposal_message (::mlspp::AuthenticatedContent const & message, ::mlspp::State const & target_state, std::set<std::string > const & recognised_user_ids) const {
257
+ bool session::validate_proposal_message (::mlspp::AuthenticatedContent const & message, ::mlspp::State const & target_state, std::set<dpp::snowflake > const & recognised_user_ids) const {
258
258
if (message.wire_format != ::mlspp::WireFormat::mls_public_message) {
259
259
creator.log (dpp::ll_warning, " MLS proposal message must be PublicMessage" );
260
260
TRACK_MLS_ERROR (" Invalid proposal wire format" );
@@ -357,7 +357,7 @@ catch (const std::exception& e) {
357
357
return failed_t {};
358
358
}
359
359
360
- std::optional<roster_map> session::process_welcome (std::vector<uint8_t > welcome, std::set<std::string > const & recognised_user_ids) noexcept
360
+ std::optional<roster_map> session::process_welcome (std::vector<uint8_t > welcome, std::set<dpp::snowflake > const & recognised_user_ids) noexcept
361
361
try {
362
362
if (!has_cryptographic_state_for_welcome ()) {
363
363
creator.log (dpp::ll_warning, " Missing local crypto state necessary to process MLS welcome" );
@@ -461,7 +461,7 @@ bool session::has_cryptographic_state_for_welcome() const noexcept
461
461
return join_key_package && join_init_private_key && signature_private_key && hpke_private_key;
462
462
}
463
463
464
- bool session::verify_welcome_state (::mlspp::State const & state, std::set<std::string > const & recognised_user_ids) const
464
+ bool session::verify_welcome_state (::mlspp::State const & state, std::set<dpp::snowflake > const & recognised_user_ids) const
465
465
{
466
466
if (!mls_external_sender) {
467
467
creator.log (dpp::ll_warning, " Cannot verify MLS welcome without an external sender" );
@@ -502,13 +502,13 @@ bool session::verify_welcome_state(::mlspp::State const& state, std::set<std::st
502
502
return true ;
503
503
}
504
504
505
- void session::init_leaf_node (std::string const & self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept
505
+ void session::init_leaf_node (dpp::snowflake self_user_id, std::shared_ptr<::mlspp::SignaturePrivateKey>& transient_key) noexcept
506
506
try {
507
507
auto ciphersuite = ciphersuite_for_protocol_version (session_protocol_version);
508
508
509
509
if (!transient_key) {
510
510
if (!signing_key_id.empty ()) {
511
- transient_key = get_persisted_key_pair (creator, key_pair_context, signing_key_id, session_protocol_version);
511
+ transient_key = get_persisted_key_pair (creator, key_pair_context, signing_key_id. str () , session_protocol_version);
512
512
if (!transient_key) {
513
513
creator.log (dpp::ll_warning, " Did not receive MLS signature private key from get_persisted_key_pair; aborting" );
514
514
return ;
@@ -522,7 +522,7 @@ try {
522
522
523
523
signature_private_key = transient_key;
524
524
525
- auto self_credential = create_user_credential (self_user_id, session_protocol_version);
525
+ auto self_credential = create_user_credential (self_user_id. str () , session_protocol_version);
526
526
hpke_private_key = std::make_unique<::mlspp::HPKEPrivateKey>(::mlspp::HPKEPrivateKey::generate (ciphersuite));
527
527
self_leaf_node = std::make_unique<::mlspp::LeafNode>(
528
528
ciphersuite, hpke_private_key->public_key , signature_private_key->public_key , std::move (self_credential),
@@ -608,7 +608,7 @@ catch (const std::exception& e) {
608
608
return {};
609
609
}
610
610
611
- std::unique_ptr<key_ratchet_interface> session::get_key_ratchet (std::string const & user_id) const noexcept
611
+ std::unique_ptr<key_ratchet_interface> session::get_key_ratchet (dpp::snowflake user_id) const noexcept
612
612
{
613
613
if (!current_state) {
614
614
creator.log (dpp::ll_warning, " Cannot get key ratchet without an established MLS group" );
@@ -617,7 +617,7 @@ std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(std::string cons
617
617
618
618
// change the string user ID to a little endian 64 bit user ID
619
619
// TODO: Make this use dpp::snowflake
620
- auto u64_user_id = strtoull ( user_id. c_str (), nullptr , 10 ) ;
620
+ uint64_t u64_user_id = user_id;
621
621
auto user_id_bytes = ::mlspp::bytes_ns::bytes (sizeof (u64_user_id));
622
622
memcpy (user_id_bytes.data (), &u64_user_id, sizeof (u64_user_id));
623
623
@@ -629,14 +629,14 @@ std::unique_ptr<key_ratchet_interface> session::get_key_ratchet(std::string cons
629
629
return std::make_unique<mls_key_ratchet>(creator, current_state->cipher_suite (), std::move (secret));
630
630
}
631
631
632
- void session::get_pairwise_fingerprint (uint16_t version, std::string const & user_id, pairwise_fingerprint_callback callback) const noexcept
632
+ void session::get_pairwise_fingerprint (uint16_t version, dpp::snowflake user_id, pairwise_fingerprint_callback callback) const noexcept
633
633
try {
634
634
if (!current_state || !signature_private_key) {
635
635
throw std::invalid_argument (" No established MLS group" );
636
636
}
637
637
638
- uint64_t remote_user_id = strtoull ( user_id. c_str (), nullptr , 10 ) ;
639
- uint64_t self_user_id = strtoull ( bot_user_id. c_str (), nullptr , 10 ) ;
638
+ uint64_t remote_user_id = user_id;
639
+ uint64_t self_user_id = bot_user_id;
640
640
641
641
auto it = roster.find (remote_user_id);
642
642
if (it == roster.end ()) {
@@ -687,16 +687,7 @@ try {
687
687
688
688
std::vector<uint8_t > out (hash_len);
689
689
690
- int ret = EVP_PBE_scrypt ((const char *)data.data (),
691
- data.size (),
692
- salt,
693
- sizeof (salt),
694
- N,
695
- r,
696
- p,
697
- max_mem,
698
- out.data (),
699
- out.size ());
690
+ int ret = EVP_PBE_scrypt ((const char *)data.data (), data.size (), salt, sizeof (salt), N, r, p, max_mem, out.data (), out.size ());
700
691
701
692
if (ret == 1 ) {
702
693
callback (out);
0 commit comments