Skip to content

Commit dbb8507

Browse files
committed
refactor: better log and args parser
1 parent 7022e1b commit dbb8507

File tree

6 files changed

+251
-177
lines changed

6 files changed

+251
-177
lines changed

Cargo.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ license = "MIT/Apache-2.0"
77
name = "shadow-tls"
88
readme = "README.md"
99
repository = "https://github.com/ihciah/shadow-tls"
10-
version = "0.2.12"
10+
version = "0.2.13"
1111

1212
[dependencies]
1313
monoio = {version = "0.0.9"}

src/client.rs

+64-28
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
1-
use std::net::SocketAddr;
1+
use std::{rc::Rc, sync::Arc};
22

3-
use monoio::{io::Splitable, net::TcpStream};
3+
use monoio::{
4+
io::Splitable,
5+
net::{TcpListener, TcpStream},
6+
};
47
use monoio_rustls::TlsConnector;
58
use rand::seq::SliceRandom;
69
use rustls::{OwnedTrustAnchor, RootCertStore, ServerName};
710

811
use crate::{
912
stream::HashedReadStream,
1013
util::{copy_with_application_data, copy_without_application_data, mod_tcp_conn},
11-
Opts,
1214
};
1315

1416
/// ShadowTlsClient.
15-
pub struct ShadowTlsClient<A> {
17+
#[derive(Clone)]
18+
pub struct ShadowTlsClient<LA, TA> {
19+
listen_addr: Arc<LA>,
20+
target_addr: Arc<TA>,
1621
tls_connector: TlsConnector,
17-
server_names: TlsNames,
18-
address: A,
19-
password: String,
20-
opts: Opts,
22+
tls_names: Arc<TlsNames>,
23+
password: Arc<String>,
24+
nodelay: bool,
2125
}
2226

2327
#[derive(Clone, Debug, PartialEq)]
@@ -61,11 +65,20 @@ pub struct TlsExtConfig {
6165
}
6266

6367
impl TlsExtConfig {
68+
#[allow(unused)]
6469
pub fn new(alpn: Option<Vec<Vec<u8>>>) -> TlsExtConfig {
6570
TlsExtConfig { alpn }
6671
}
6772
}
6873

74+
impl From<Option<Vec<String>>> for TlsExtConfig {
75+
fn from(maybe_alpns: Option<Vec<String>>) -> Self {
76+
Self {
77+
alpn: maybe_alpns.map(|alpns| alpns.into_iter().map(Into::into).collect()),
78+
}
79+
}
80+
}
81+
6982
impl std::fmt::Display for TlsExtConfig {
7083
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
7184
match self.alpn.as_ref() {
@@ -84,14 +97,15 @@ impl std::fmt::Display for TlsExtConfig {
8497
}
8598
}
8699

87-
impl<A> ShadowTlsClient<A> {
100+
impl<LA, TA> ShadowTlsClient<LA, TA> {
88101
/// Create new ShadowTlsClient.
89102
pub fn new(
90-
server_names: TlsNames,
91-
address: A,
92-
password: String,
93-
opts: Opts,
103+
listen_addr: LA,
104+
target_addr: TA,
105+
tls_names: TlsNames,
94106
tls_ext_config: TlsExtConfig,
107+
password: String,
108+
nodelay: bool,
95109
) -> anyhow::Result<Self> {
96110
let mut root_store = RootCertStore::empty();
97111
root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
@@ -115,22 +129,45 @@ impl<A> ShadowTlsClient<A> {
115129
let tls_connector = TlsConnector::from(tls_config);
116130

117131
Ok(Self {
132+
listen_addr: Arc::new(listen_addr),
133+
target_addr: Arc::new(target_addr),
118134
tls_connector,
119-
server_names,
120-
address,
121-
password,
122-
opts,
135+
tls_names: Arc::new(tls_names),
136+
password: Arc::new(password),
137+
nodelay,
123138
})
124139
}
125140

141+
pub async fn serve(self) -> anyhow::Result<()>
142+
where
143+
LA: std::net::ToSocketAddrs + 'static,
144+
TA: std::net::ToSocketAddrs + 'static,
145+
{
146+
let listener = TcpListener::bind(self.listen_addr.as_ref())
147+
.map_err(|e| anyhow::anyhow!("bind failed, check if the port is used: {e}"))?;
148+
let shared = Rc::new(self);
149+
loop {
150+
match listener.accept().await {
151+
Ok((mut conn, addr)) => {
152+
tracing::info!("Accepted a connection from {addr}");
153+
let client = shared.clone();
154+
mod_tcp_conn(&mut conn, true, shared.nodelay);
155+
monoio::spawn(async move {
156+
let _ = client.relay(conn).await;
157+
tracing::info!("Relay for {addr} finished");
158+
});
159+
}
160+
Err(e) => {
161+
tracing::error!("Accept failed: {e}");
162+
}
163+
}
164+
}
165+
}
166+
126167
/// Establish connection with remote and relay data.
127-
pub async fn relay(
128-
&self,
129-
mut in_stream: TcpStream,
130-
in_stream_addr: SocketAddr,
131-
) -> anyhow::Result<()>
168+
async fn relay(&self, mut in_stream: TcpStream) -> anyhow::Result<()>
132169
where
133-
A: std::net::ToSocketAddrs,
170+
TA: std::net::ToSocketAddrs,
134171
{
135172
let (mut out_stream, hash, session) = self.connect().await?;
136173
let mut hash_8b = [0; 8];
@@ -143,20 +180,19 @@ impl<A> ShadowTlsClient<A> {
143180
copy_with_application_data(&mut in_r, &mut out_w, Some(hash_8b))
144181
);
145182
let (_, _) = (a?, b?);
146-
tracing::info!("Relay for {in_stream_addr} finished");
147183
Ok(())
148184
}
149185

150186
/// Connect remote, do handshaking and calculate HMAC.
151187
async fn connect(&self) -> anyhow::Result<(TcpStream, [u8; 20], rustls::ClientConnection)>
152188
where
153-
A: std::net::ToSocketAddrs,
189+
TA: std::net::ToSocketAddrs,
154190
{
155-
let mut stream = TcpStream::connect(&self.address).await?;
156-
mod_tcp_conn(&mut stream, true, !self.opts.disable_nodelay);
191+
let mut stream = TcpStream::connect(self.target_addr.as_ref()).await?;
192+
mod_tcp_conn(&mut stream, true, self.nodelay);
157193
tracing::debug!("tcp connected, start handshaking");
158194
let stream = HashedReadStream::new(stream, self.password.as_bytes())?;
159-
let endpoint = self.server_names.random_choose().clone();
195+
let endpoint = self.tls_names.random_choose().clone();
160196
let tls_stream = self.tls_connector.connect(endpoint, stream).await?;
161197
let (io, session) = tls_stream.into_parts();
162198
let hash = io.hash();

0 commit comments

Comments
 (0)