-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.go
203 lines (181 loc) · 6.05 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
package main
import (
	"errors"
	"flag"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"strings"

	"github.com/WofWca/snowflake-generalized/common"
	"github.com/xtaci/smux"
	"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/ptutil/safelog"
	snowflakeServer "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/server/lib"
	"golang.org/x/crypto/acme/autocert"
)
// main parses the command-line flags, starts a Snowflake server listener
// (with TLS certificates obtained via ACME unless --disable-tls is set), and
// forwards every accepted Snowflake client connection to --destination-address
// using --destination-protocol ("tcp" or "udp").
func main() {
	var listenAddr string
	var destinationAddr string
	var destinationProtocol string
	var acmeEmail string
	var acmeHostnamesCommas string
	var acmeCertCacheDir string
	var disableTLS bool
	// var logFilename string
	var unsafeLogging bool
	// var versionFlag bool

	// For the original Snowflake server CLI parameters, see
	// https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/blob/6d2011ded71dc53662fa0f256fbf9c3036c474a4/server/server.go#L139-144
	flag.StringVar(
		&listenAddr,
		"listen-address",
		"localhost:7901",
		"Listen for proxy connections on this `address` and forward them to \"destination-address\".\nSet to \":7901\" to listen on port 7901 on all interfaces.",
	)
	flag.StringVar(
		&destinationAddr,
		"destination-address",
		"", // "localhost:1080", we probably should not have a default address for security reasons
		"Forward client connections to this `address`.\nThis can also be a remote address.",
	)
	flag.StringVar(
		&destinationProtocol,
		"destination-protocol",
		"tcp",
		"what type of packets to send to the destination, "+
			"i.e. what protocol the target application (WireGuard, SOCKS server) "+
			"is using, \"udp\" or \"tcp\".\n",
	)
	flag.StringVar(&acmeEmail, "acme-email", "", "optional contact email for Let's Encrypt notifications")
	flag.StringVar(&acmeHostnamesCommas, "acme-hostnames", "", "comma-separated hostnames for TLS certificate")
	flag.StringVar(&acmeCertCacheDir, "acme-cert-cache", "acme-cert-cache", "directory in which certificates should be cached")
	flag.BoolVar(&disableTLS, "disable-tls", false, "don't use HTTPS")
	// flag.StringVar(&logFilename, "log", "", "log file to write to")
	flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed")
	// flag.BoolVar(&versionFlag, "version", false, "display version info to stderr and quit")
	flag.Parse()

	if destinationProtocol != "tcp" && destinationProtocol != "udp" {
		log.Fatal("`destination-protocol` must either be \"tcp\" or \"udp\"")
	}
	if destinationAddr == "" {
		flag.Usage()
		log.Fatal("\"destination-address\" must be specified")
	}

	listenAddrStruct, err := net.ResolveTCPAddr("tcp", listenAddr)
	if err != nil {
		log.Fatalf("error resolving listen address: %s", err.Error())
	}

	var transport *snowflakeServer.Transport
	if !disableTLS {
		// Trim whitespace and drop empty entries so that e.g.
		// "a.example, b.example," is accepted as two hostnames.
		acmeHostnames := parseHostnameList(acmeHostnamesCommas)
		if len(acmeHostnames) == 0 {
			log.Fatal("the --acme-hostnames option is required, unless --disable-tls")
		}
		log.Printf("ACME hostnames: %q", acmeHostnames)

		var cache autocert.Cache
		if acmeCertCacheDir != "" {
			log.Printf("caching ACME certificates in directory %q", acmeCertCacheDir)
			cache = autocert.DirCache(acmeCertCacheDir)
		} else {
			// There is no error here; the cache is disabled because the user
			// explicitly passed an empty cache directory.
			log.Printf("disabling ACME certificate cache")
		}
		certManager := &autocert.Manager{
			Cache:      cache,
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(acmeHostnames...),
			Email:      acmeEmail,
		}
		// Serve ACME HTTP-01 challenges on port 80; the whole process exits
		// if that listener fails.
		go func() {
			log.Printf("Starting HTTP-01 listener")
			log.Fatal(http.ListenAndServe(":80", certManager.HTTPHandler(nil)))
		}()
		transport = snowflakeServer.NewSnowflakeServer(certManager.GetCertificate)
	} else {
		transport = snowflakeServer.NewSnowflakeServer(nil)
	}

	numKCPInstances := 1
	ln, err := transport.Listen(listenAddrStruct, numKCPInstances)
	if err != nil {
		log.Fatalf("error opening listener: %s", err.Error())
	}
	log.Printf(
		"Listening for proxy connections on \"%v\" and forwarding them to %v \"%v\"",
		listenAddrStruct,
		destinationProtocol,
		destinationAddr,
	)

	// Setting scrubber _after_ initial checks
	// so that addresses are printed properly.
	logOutput := os.Stdout
	if unsafeLogging {
		log.SetOutput(logOutput)
	} else {
		log.SetOutput(&safelog.LogScrubber{Output: logOutput})
	}

	for {
		clientConn, err := ln.Accept()
		if err != nil {
			// net.Error.Temporary is deprecated, but this mirrors the
			// original Snowflake server's accept loop.
			var netErr net.Error
			if errors.As(err, &netErr) && netErr.Temporary() {
				continue
			}
			log.Printf("Failed to accept proxy connection: %s", err)
			// This will terminate the server.
			// The original Snowflake server does the same:
			// https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/blob/6d2011ded71dc53662fa0f256fbf9c3036c474a4/server/server.go#L99-111
			break
		}
		log.Printf(
			"Got Snowflake client connection! Forwarding to %v \"%v\"",
			destinationProtocol,
			destinationAddr,
		)
		go serveSnowflakeConnection(
			&clientConn,
			&destinationAddr,
			&destinationProtocol,
		)
	}
}

// parseHostnameList splits a comma-separated list of hostnames, trimming
// surrounding whitespace and skipping empty entries. It returns nil when no
// non-empty hostname is present.
func parseHostnameList(commaSeparated string) []string {
	var hostnames []string
	for _, h := range strings.Split(commaSeparated, ",") {
		if h = strings.TrimSpace(h); h != "" {
			hostnames = append(hostnames, h)
		}
	}
	return hostnames
}
// serveSnowflakeConnection runs an smux server session over snowflakeConn.
// For every stream the client opens, it dials *destinationProtocol /
// *destinationAddr and copies data between the two with common.CopyLoop.
//
// It returns once AcceptStream fails (e.g. the session or the underlying
// connection was closed), closing the smux session and snowflakeConn.
func serveSnowflakeConnection(
	snowflakeConn *net.Conn,
	destinationAddr *string,
	destinationProtocol *string,
) {
	defer (*snowflakeConn).Close()

	smuxConfig := smux.DefaultConfig()
	// Let's not close the connection on our own, and let Snowflake handle that.
	smuxConfig.KeepAliveDisabled = true
	muxSession, err := smux.Server(*snowflakeConn, smuxConfig)
	if err != nil {
		log.Printf("Mux session open error: %v", err)
		return
	}
	defer muxSession.Close()

	for {
		stream, err := muxSession.AcceptStream()
		if err != nil {
			// io.ErrClosedPipe (session closed) and io.EOF (underlying
			// connection closed) are a regular end of the connection.
			// TODO: smux can also return ErrTimeout if a deadline is set;
			// none is set here.
			if !errors.Is(err, io.ErrClosedPipe) && !errors.Is(err, io.EOF) {
				log.Printf("AcceptStream error: %v", err)
			}
			return
		}
		log.Printf("New stream! %v", stream.ID())
		go func() {
			defer stream.Close()
			destinationConn, err := net.Dial(*destinationProtocol, *destinationAddr)
			if err != nil {
				log.Printf("Failed to dial destination address: %v", err)
				// TODO: should we also close snowflakeConn here?
				return
			}
			defer destinationConn.Close()
			// TODO should we utilize `shutdownChan`?
			shutdownChan := make(chan struct{})
			common.CopyLoop(stream, destinationConn, shutdownChan)
		}()
	}
}