Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions src/handshake.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fmt::Display, sync::Arc, time::Duration};
use std::{fmt::Display, time::Duration};

use bitcoin::{FeeRate, Network};
use p2p::{
Expand Down Expand Up @@ -68,7 +68,7 @@ impl ConnectionConfig {
self.network
}

/// Request the peer gossip new addresses at the beginning of the conneciton
/// Request the peer gossip new addresses at the beginning of the connection
pub fn request_addr(mut self) -> Self {
self.request_addr = true;
self
Expand Down Expand Up @@ -172,10 +172,9 @@ impl ConnectionConfig {
net_time_difference,
reported_height: version.start_height,
};
let preferences = Arc::new(Preferences::default());
let handshake = InitializedHandshake {
feeler,
their_preferences: preferences,
their_preferences: Preferences::default(),
send_cmpct: self.send_cmpct,
fee_filter: self.fee_filter,
request_addr: self.request_addr,
Expand All @@ -193,15 +192,15 @@ impl Default for ConnectionConfig {
#[derive(Debug, Clone)]
pub(crate) struct InitializedHandshake {
feeler: FeelerData,
their_preferences: Arc<Preferences>,
their_preferences: Preferences,
fee_filter: FeeRate,
send_cmpct: SendCmpct,
request_addr: bool,
}

impl InitializedHandshake {
pub(crate) fn negotiate(
&self,
&mut self,
message: NetworkMessage,
) -> Result<Option<(CompletedHandshake, Vec<NetworkMessage>)>, Error> {
match message {
Expand All @@ -216,25 +215,25 @@ impl InitializedHandshake {
Ok(Some((
CompletedHandshake {
feeler: self.feeler,
their_preferences: Arc::clone(&self.their_preferences),
their_preferences: self.their_preferences,
},
messages,
)))
}
NetworkMessage::WtxidRelay => {
self.their_preferences.prefers_wtxid();
self.their_preferences.sendwtxid = true;
Ok(None)
}
NetworkMessage::SendAddrV2 => {
self.their_preferences.prefers_addrv2();
self.their_preferences.sendaddrv2 = true;
Ok(None)
}
NetworkMessage::SendCmpct(cmpct) => {
self.their_preferences.prefers_cmpct(cmpct.version);
self.their_preferences.sendcmpct = cmpct;
Ok(None)
}
NetworkMessage::SendHeaders => {
self.their_preferences.prefers_header_announcment();
self.their_preferences.sendheaders = true;
Ok(None)
}
e => Err(Error::IrrelevantMessage(e.command())),
Expand All @@ -245,7 +244,7 @@ impl InitializedHandshake {
#[derive(Debug, Clone)]
pub(crate) struct CompletedHandshake {
pub(crate) feeler: FeelerData,
pub(crate) their_preferences: Arc<Preferences>,
pub(crate) their_preferences: Preferences,
}

/// Errors that occur during a handshake
Expand Down Expand Up @@ -307,7 +306,7 @@ mod tests {
let connection_config = ConnectionConfig::new();
let nonce = 43;
let system_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
let (init_handshake, messages) = connection_config
let (mut init_handshake, messages) = connection_config
.start_handshake(system_time, NetworkMessage::Version(mock), nonce)
.unwrap();
let mut message_iter = messages.into_iter();
Expand All @@ -333,9 +332,9 @@ mod tests {
assert!(matches!(cmpct, NetworkMessage::SendCmpct(_)));
let fee_filter = message_iter.next().unwrap();
assert!(matches!(fee_filter, NetworkMessage::FeeFilter(_)));
assert!(completed.their_preferences.wtxid());
assert!(completed.their_preferences.addrv2());
assert!(!completed.their_preferences.announce_by_headers());
assert!(completed.their_preferences.sendwtxid);
assert!(completed.their_preferences.sendaddrv2);
assert!(!completed.their_preferences.sendheaders);
}

#[test]
Expand Down Expand Up @@ -382,7 +381,7 @@ mod tests {
let connection_config = ConnectionConfig::new().request_addr();
let nonce = 43;
let system_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
let (init_handshake, _) = connection_config
let (mut init_handshake, _) = connection_config
.start_handshake(system_time, NetworkMessage::Version(mock), nonce)
.unwrap();
let (_, messages) = init_handshake
Expand Down
90 changes: 24 additions & 66 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@
#![warn(missing_docs)]
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex,
},
sync::{Arc, Mutex},
time::{Duration, Instant},
};

use bitcoin::Network;
use p2p::{ProtocolVersion, ServiceFlags};
use p2p::{message_compact_blocks::SendCmpct, ProtocolVersion, ServiceFlags};

pub extern crate p2p;

Expand All @@ -37,59 +34,30 @@ pub struct FeelerData {

/// The peer's preferences during this connection. These are updated automatically as the peer
/// shares information.
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub struct Preferences {
sendheaders: AtomicBool,
sendaddrv2: AtomicBool,
sendcmpct: AtomicU64,
sendwtxid: AtomicBool,
/// Announce blocks to this peer by block header.
pub sendheaders: bool,
/// Send `Addrv2` addresses.
pub sendaddrv2: bool,
/// Compact block relay preferences.
pub sendcmpct: SendCmpct,
/// Advertise transactions by WTXID.
pub sendwtxid: bool,
}

impl Preferences {
fn new() -> Self {
Self {
sendheaders: AtomicBool::new(false),
sendaddrv2: AtomicBool::new(false),
sendcmpct: AtomicU64::new(0),
sendwtxid: AtomicBool::new(false),
sendheaders: false,
sendaddrv2: false,
sendcmpct: SendCmpct {
send_compact: false,
version: 0x00,
},
sendwtxid: false,
}
}

fn prefers_header_announcment(&self) {
self.sendheaders.store(true, Ordering::Relaxed);
}

fn prefers_addrv2(&self) {
self.sendaddrv2.store(true, Ordering::Relaxed);
}

fn prefers_wtxid(&self) {
self.sendwtxid.store(true, Ordering::Relaxed);
}

fn prefers_cmpct(&self, version: u64) {
self.sendcmpct.store(version, Ordering::Relaxed);
}

/// The peer prefers `addrv2` messages
pub fn addrv2(&self) -> bool {
self.sendaddrv2.load(Ordering::Relaxed)
}

/// The peer prefers headers are announced by a `headers` message instead of `inv`
pub fn announce_by_headers(&self) -> bool {
self.sendheaders.load(Ordering::Relaxed)
}

/// The peer prefers witness transaction IDs
pub fn wtxid(&self) -> bool {
self.sendwtxid.load(Ordering::Relaxed)
}

/// The reported compact block relay version
pub fn cmpct_version(&self) -> u64 {
self.sendcmpct.load(Ordering::Relaxed)
}
}

impl Default for Preferences {
Expand All @@ -102,7 +70,7 @@ impl Default for Preferences {
#[derive(Debug, Clone)]
pub struct ConnectionMetrics {
feeler: FeelerData,
their_preferences: Arc<Preferences>,
their_preferences: Arc<Mutex<Preferences>>,
timed_messages: Arc<Mutex<TimedMessages>>,
start_time: Instant,
outbound_ping_state: Arc<Mutex<OutboundPing>>,
Expand All @@ -114,9 +82,10 @@ impl ConnectionMetrics {
&self.feeler
}

/// Their current preferences for message exchange
pub fn their_preferences(&self) -> &Preferences {
self.their_preferences.as_ref()
/// Their current preferences for message exchange, if not currently being mutated.
pub fn their_preferences(&self) -> Option<Preferences> {
let pref = self.their_preferences.lock().ok();
pref.as_deref().copied()
}

/// The message rate for a time-sensitive message
Expand Down Expand Up @@ -300,7 +269,7 @@ impl SeedsExt for Network {
mod tests {
use std::time::{Duration, Instant};

use crate::{MessageRate, Preferences, TimedMessage, TimedMessages};
use crate::{MessageRate, TimedMessage, TimedMessages};

#[test]
fn test_message_rate() {
Expand All @@ -316,17 +285,6 @@ mod tests {
assert_eq!(rate.messages_per_secs(later).unwrap(), 2.);
}

#[test]
fn test_preferences() {
let pref = Preferences::new();
pref.prefers_wtxid();
pref.prefers_addrv2();
pref.prefers_header_announcment();
assert!(pref.addrv2());
assert!(pref.announce_by_headers());
assert!(pref.wtxid());
}

#[test]
fn test_timed_messages() {
let now = Instant::now();
Expand Down
19 changes: 13 additions & 6 deletions src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl ConnectionExt for ConnectionConfig {
let mut write_half = WriteTransport::V1(self.network().default_network_magic());
let mut read_half = ReadTransport::V1(self.network().default_network_magic());
write_half.write_message(NetworkMessage::Version(version), &mut tcp_stream)?;
let (handshake, messages) = match read_half.read_message(&mut tcp_stream)? {
let (mut handshake, messages) = match read_half.read_message(&mut tcp_stream)? {
Some(message) => self.start_handshake(unix_time, message, nonce)?,
None => return Err(Error::MissingVersion),
};
Expand All @@ -116,9 +116,10 @@ impl ConnectionExt for ConnectionConfig {
feeler,
their_preferences,
} = completed_handshake;
let arc_pref = Arc::new(Mutex::new(their_preferences));
let live_connection = ConnectionMetrics {
feeler,
their_preferences: Arc::clone(&their_preferences),
their_preferences: Arc::clone(&arc_pref),
timed_messages: Arc::clone(&timed_messages),
start_time: Instant::now(),
outbound_ping_state: Arc::clone(&outbound_ping),
Expand All @@ -141,7 +142,7 @@ impl ConnectionExt for ConnectionConfig {
let reader = ConnectionReader {
tcp_stream,
transport: read_half,
their_preferences,
their_preferences: Arc::clone(&arc_pref),
timed_messages,
outbound_ping_state: Arc::clone(&outbound_ping),
};
Expand Down Expand Up @@ -276,7 +277,7 @@ impl OpenWriter {
pub struct ConnectionReader {
tcp_stream: TcpStream,
transport: ReadTransport,
their_preferences: Arc<Preferences>,
their_preferences: Arc<Mutex<Preferences>>,
timed_messages: Arc<Mutex<TimedMessages>>,
outbound_ping_state: Arc<Mutex<OutboundPing>>,
}
Expand All @@ -287,9 +288,15 @@ impl ConnectionReader {
let message = self.transport.read_message(&mut self.tcp_stream)?;
if let Some(message) = &message {
match message {
NetworkMessage::SendHeaders => self.their_preferences.prefers_header_announcment(),
NetworkMessage::SendHeaders => {
if let Ok(mut lock) = self.their_preferences.lock() {
lock.sendheaders = true;
}
}
NetworkMessage::SendCmpct(cmpct) => {
self.their_preferences.prefers_cmpct(cmpct.version)
if let Ok(mut lock) = self.their_preferences.lock() {
lock.sendcmpct = *cmpct;
}
}
NetworkMessage::Block(_) => {
if let Ok(mut lock) = self.timed_messages.lock() {
Expand Down
2 changes: 1 addition & 1 deletion tests/std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ fn can_accept_handshake() {
.connect(bind)
.start();
let (_, _, metadata) = wait.join().unwrap().unwrap();
assert!(metadata.their_preferences().wtxid());
assert!(metadata.their_preferences().unwrap().sendwtxid);
}

#[test]
Expand Down
Loading