1
- use std:: net :: SocketAddr ;
1
+ use std:: { rc :: Rc , sync :: Arc } ;
2
2
3
- use monoio:: { io:: Splitable , net:: TcpStream } ;
3
+ use monoio:: {
4
+ io:: Splitable ,
5
+ net:: { TcpListener , TcpStream } ,
6
+ } ;
4
7
use monoio_rustls:: TlsConnector ;
5
8
use rand:: seq:: SliceRandom ;
6
9
use rustls:: { OwnedTrustAnchor , RootCertStore , ServerName } ;
7
10
8
11
use crate :: {
9
12
stream:: HashedReadStream ,
10
13
util:: { copy_with_application_data, copy_without_application_data, mod_tcp_conn} ,
11
- Opts ,
12
14
} ;
13
15
14
16
/// ShadowTlsClient.
15
- pub struct ShadowTlsClient < A > {
17
+ #[ derive( Clone ) ]
18
+ pub struct ShadowTlsClient < LA , TA > {
19
+ listen_addr : Arc < LA > ,
20
+ target_addr : Arc < TA > ,
16
21
tls_connector : TlsConnector ,
17
- server_names : TlsNames ,
18
- address : A ,
19
- password : String ,
20
- opts : Opts ,
22
+ tls_names : Arc < TlsNames > ,
23
+ password : Arc < String > ,
24
+ nodelay : bool ,
21
25
}
22
26
23
27
#[ derive( Clone , Debug , PartialEq ) ]
@@ -61,11 +65,20 @@ pub struct TlsExtConfig {
61
65
}
62
66
63
67
impl TlsExtConfig {
68
+ #[ allow( unused) ]
64
69
pub fn new ( alpn : Option < Vec < Vec < u8 > > > ) -> TlsExtConfig {
65
70
TlsExtConfig { alpn }
66
71
}
67
72
}
68
73
74
+ impl From < Option < Vec < String > > > for TlsExtConfig {
75
+ fn from ( maybe_alpns : Option < Vec < String > > ) -> Self {
76
+ Self {
77
+ alpn : maybe_alpns. map ( |alpns| alpns. into_iter ( ) . map ( Into :: into) . collect ( ) ) ,
78
+ }
79
+ }
80
+ }
81
+
69
82
impl std:: fmt:: Display for TlsExtConfig {
70
83
fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
71
84
match self . alpn . as_ref ( ) {
@@ -84,14 +97,15 @@ impl std::fmt::Display for TlsExtConfig {
84
97
}
85
98
}
86
99
87
- impl < A > ShadowTlsClient < A > {
100
+ impl < LA , TA > ShadowTlsClient < LA , TA > {
88
101
/// Create new ShadowTlsClient.
89
102
pub fn new (
90
- server_names : TlsNames ,
91
- address : A ,
92
- password : String ,
93
- opts : Opts ,
103
+ listen_addr : LA ,
104
+ target_addr : TA ,
105
+ tls_names : TlsNames ,
94
106
tls_ext_config : TlsExtConfig ,
107
+ password : String ,
108
+ nodelay : bool ,
95
109
) -> anyhow:: Result < Self > {
96
110
let mut root_store = RootCertStore :: empty ( ) ;
97
111
root_store. add_server_trust_anchors ( webpki_roots:: TLS_SERVER_ROOTS . 0 . iter ( ) . map ( |ta| {
@@ -115,22 +129,45 @@ impl<A> ShadowTlsClient<A> {
115
129
let tls_connector = TlsConnector :: from ( tls_config) ;
116
130
117
131
Ok ( Self {
132
+ listen_addr : Arc :: new ( listen_addr) ,
133
+ target_addr : Arc :: new ( target_addr) ,
118
134
tls_connector,
119
- server_names,
120
- address,
121
- password,
122
- opts,
135
+ tls_names : Arc :: new ( tls_names) ,
136
+ password : Arc :: new ( password) ,
137
+ nodelay,
123
138
} )
124
139
}
125
140
141
+ pub async fn serve ( self ) -> anyhow:: Result < ( ) >
142
+ where
143
+ LA : std:: net:: ToSocketAddrs + ' static ,
144
+ TA : std:: net:: ToSocketAddrs + ' static ,
145
+ {
146
+ let listener = TcpListener :: bind ( self . listen_addr . as_ref ( ) )
147
+ . map_err ( |e| anyhow:: anyhow!( "bind failed, check if the port is used: {e}" ) ) ?;
148
+ let shared = Rc :: new ( self ) ;
149
+ loop {
150
+ match listener. accept ( ) . await {
151
+ Ok ( ( mut conn, addr) ) => {
152
+ tracing:: info!( "Accepted a connection from {addr}" ) ;
153
+ let client = shared. clone ( ) ;
154
+ mod_tcp_conn ( & mut conn, true , shared. nodelay ) ;
155
+ monoio:: spawn ( async move {
156
+ let _ = client. relay ( conn) . await ;
157
+ tracing:: info!( "Relay for {addr} finished" ) ;
158
+ } ) ;
159
+ }
160
+ Err ( e) => {
161
+ tracing:: error!( "Accept failed: {e}" ) ;
162
+ }
163
+ }
164
+ }
165
+ }
166
+
126
167
/// Establish connection with remote and relay data.
127
- pub async fn relay (
128
- & self ,
129
- mut in_stream : TcpStream ,
130
- in_stream_addr : SocketAddr ,
131
- ) -> anyhow:: Result < ( ) >
168
+ async fn relay ( & self , mut in_stream : TcpStream ) -> anyhow:: Result < ( ) >
132
169
where
133
- A : std:: net:: ToSocketAddrs ,
170
+ TA : std:: net:: ToSocketAddrs ,
134
171
{
135
172
let ( mut out_stream, hash, session) = self . connect ( ) . await ?;
136
173
let mut hash_8b = [ 0 ; 8 ] ;
@@ -143,20 +180,19 @@ impl<A> ShadowTlsClient<A> {
143
180
copy_with_application_data( & mut in_r, & mut out_w, Some ( hash_8b) )
144
181
) ;
145
182
let ( _, _) = ( a?, b?) ;
146
- tracing:: info!( "Relay for {in_stream_addr} finished" ) ;
147
183
Ok ( ( ) )
148
184
}
149
185
150
186
/// Connect remote, do handshaking and calculate HMAC.
151
187
async fn connect ( & self ) -> anyhow:: Result < ( TcpStream , [ u8 ; 20 ] , rustls:: ClientConnection ) >
152
188
where
153
- A : std:: net:: ToSocketAddrs ,
189
+ TA : std:: net:: ToSocketAddrs ,
154
190
{
155
- let mut stream = TcpStream :: connect ( & self . address ) . await ?;
156
- mod_tcp_conn ( & mut stream, true , ! self . opts . disable_nodelay ) ;
191
+ let mut stream = TcpStream :: connect ( self . target_addr . as_ref ( ) ) . await ?;
192
+ mod_tcp_conn ( & mut stream, true , self . nodelay ) ;
157
193
tracing:: debug!( "tcp connected, start handshaking" ) ;
158
194
let stream = HashedReadStream :: new ( stream, self . password . as_bytes ( ) ) ?;
159
- let endpoint = self . server_names . random_choose ( ) . clone ( ) ;
195
+ let endpoint = self . tls_names . random_choose ( ) . clone ( ) ;
160
196
let tls_stream = self . tls_connector . connect ( endpoint, stream) . await ?;
161
197
let ( io, session) = tls_stream. into_parts ( ) ;
162
198
let hash = io. hash ( ) ;
0 commit comments