diff --git a/config/gobetween.toml b/config/gobetween.toml index 6c950f98..3de69f2e 100644 --- a/config/gobetween.toml +++ b/config/gobetween.toml @@ -189,7 +189,7 @@ protocol = "udp" ## For more details on PROXYPROTOCOL see https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt # # [servers.default.proxy_protocol] # (optional) -# version = "1" # (required) proxy protocol version. only "1" for now. +# version = "1" # (required) proxy protocol version, both "1" and "2" are allowed. # ## -------------------- healthchecks ------------------------- # # diff --git a/go.sum b/go.sum index 1553ace7..46a83135 100644 --- a/go.sum +++ b/go.sum @@ -169,8 +169,8 @@ github.com/opencontainers/runc v0.1.1 h1:GlxAyO6x8rfZYN9Tt0Kti5a/cP41iuiO2yYT0IJ github.com/opencontainers/runc v0.1.1/go.mod h1:qT5XzbpPznkRYVz/mWwUaVBUv2rmF59PVA73FjuZG0U= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c h1:Lgl0gzECD8GnQ5QCWA8o6BtfL6mDH5rQgM4/fX3avOs= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= -github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc h1:lNOt1SMsgHXTdpuGw+RpnJtzUcCb/oRKZP65pBy9pr8= -github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc/go.mod h1:6/gX3+E/IYGa0wMORlSMla999awQFdbaeQCHjSMKIzY= +github.com/pires/go-proxyproto v0.1.3 h1:2XEuhsQluSNA5QIQkiUv8PfgZ51sNYIQkq/yFquiSQM= +github.com/pires/go-proxyproto v0.1.3/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/src/core/context.go b/src/core/context.go index ab7931f1..5c9862e7 100644 --- a/src/core/context.go +++ b/src/core/context.go @@ -6,7 +6,10 @@ package core * @author Yaroslav Pogrebnyak */ -import "net" +import ( + "crypto/tls" + "net" +) type Context interface { String() string @@ -23,7 +26,8 @@ type TcpContext struct { /** * Current client connection */ - Conn net.Conn + Conn net.Conn + TlsState *tls.ConnectionState } func (t TcpContext) String() string { diff --git a/src/go.mod b/src/go.mod index 1eced251..b70f40c8 100644 --- a/src/go.mod +++ b/src/go.mod @@ -28,7 +28,7 @@ require ( github.com/miekg/dns v1.0.14 github.com/mitchellh/go-testing-interface v1.0.0 // indirect github.com/mitchellh/mapstructure v1.1.2 // indirect - github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc + github.com/pires/go-proxyproto v0.1.3 github.com/prometheus/client_golang v0.9.2 github.com/rogpeppe/fastuuid v1.0.0 // indirect github.com/sirupsen/logrus v1.4.0 diff --git a/src/go.sum b/src/go.sum index 77bb7f9c..86af4446 100644 --- a/src/go.sum +++ b/src/go.sum @@ -169,6 +169,8 @@ github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c h1:Lgl0gzECD8GnQ5 github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc h1:lNOt1SMsgHXTdpuGw+RpnJtzUcCb/oRKZP65pBy9pr8= github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc/go.mod h1:6/gX3+E/IYGa0wMORlSMla999awQFdbaeQCHjSMKIzY= +github.com/pires/go-proxyproto v0.1.3 h1:2XEuhsQluSNA5QIQkiUv8PfgZ51sNYIQkq/yFquiSQM= +github.com/pires/go-proxyproto v0.1.3/go.mod h1:Odh9VFOZJCf9G8cLW5o435Xf1J95Jw9Gw5rnCjcwzAY= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/src/manager/manager.go b/src/manager/manager.go index a203a6dd..d30d1c27 100644 --- a/src/manager/manager.go +++ b/src/manager/manager.go @@ -363,15 +363,15 @@ func prepareConfig(name string, server config.Server, defaults config.Connection if server.ProxyProtocol != nil { - if server.Protocol != "tcp" { - return config.Server{}, errors.New("proxy_protocol may be used only with 'tcp' protocol, not with " + server.Protocol) + if server.Protocol != "tcp" && server.Protocol != "tls" { + return config.Server{}, errors.New("proxy_protocol may be used only with 'tcp' or 'tls' protocols, not with " + server.Protocol) } if server.ProxyProtocol.Version == "" { return config.Server{}, errors.New("version field for proxy_protocol is not specified") } - if server.ProxyProtocol.Version != "1" { + if server.ProxyProtocol.Version != "1" && server.ProxyProtocol.Version != "2" { return config.Server{}, errors.New("Unsupported proxy_protocol version " + server.ProxyProtocol.Version) } } diff --git a/src/server/tcp/server.go b/src/server/tcp/server.go index 6c936771..50062ff9 100644 --- a/src/server/tcp/server.go +++ b/src/server/tcp/server.go @@ -242,12 +242,26 @@ func (this *Server) wrap(conn net.Conn, sniEnabled bool) { } if this.tlsConfig != nil { - conn = tls.Server(conn, this.tlsConfig) - } + tlsConn := tls.Server(conn, this.tlsConfig) + err = tlsConn.Handshake() + if err != nil { + log.Error("Failed to complete TLS handshake: ", err) + conn.Close() + return + } - this.connect <- &core.TcpContext{ - hostname, - conn, + tlsState := tlsConn.ConnectionState() + this.connect <- &core.TcpContext{ + hostname, + tlsConn, + &tlsState, + } + } else { + this.connect <- &core.TcpContext{ + hostname, + conn, + nil, + } } } @@ -336,7 +350,14 @@ func (this *Server) handle(ctx *core.TcpContext) { switch this.cfg.ProxyProtocol.Version { case "1": log.Debug("Sending proxy_protocol v1 header ", clientConn.RemoteAddr(), " -> ", this.listener.Addr(), " -> ", backendConn.RemoteAddr()) - err := proxyprotocol.SendProxyProtocolV1(clientConn, backendConn) + err := proxyprotocol.SendProxyProtocolV1(this.cfg.ProxyProtocol, ctx, backendConn) + if err != nil { + log.Error(err) + return + } + case "2": + log.Debug("Sending proxy_protocol v2 header ", clientConn.RemoteAddr(), " -> ", this.listener.Addr(), " -> ", backendConn.RemoteAddr()) + err := proxyprotocol.SendProxyProtocolV2(this.cfg.ProxyProtocol, ctx, backendConn) if err != nil { log.Error(err) return diff --git a/src/utils/proxyprotocol/proxyprotocol.go b/src/utils/proxyprotocol/proxyprotocol.go index ea726c95..258dca9d 100644 --- a/src/utils/proxyprotocol/proxyprotocol.go +++ b/src/utils/proxyprotocol/proxyprotocol.go @@ -6,6 +6,8 @@ import ( "strconv" proxyproto "github.com/pires/go-proxyproto" + "github.com/yyyar/gobetween/config" + "github.com/yyyar/gobetween/core" ) func addrToIPAndPort(addr net.Addr) (ip net.IP, port uint16, err error) { @@ -30,7 +32,8 @@ func addrToIPAndPort(addr net.Addr) (ip net.IP, port uint16, err error) { /// SendProxyProtocolV1 sends a proxy protocol v1 header to initialize the connection /// https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt -func SendProxyProtocolV1(client net.Conn, backend net.Conn) error { +func SendProxyProtocolV1(cfg *config.ProxyProtocol, ctx *core.TcpContext, backend net.Conn) error { + client := ctx.Conn sourceIP, sourcePort, err := addrToIPAndPort(client.RemoteAddr()) if err != nil { return err @@ -60,3 +63,43 @@ func SendProxyProtocolV1(client net.Conn, backend net.Conn) error { } return nil } + +/// SendProxyProtocolV2 sends a proxy protocol v2 header to initialize the connection +/// https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt +func SendProxyProtocolV2(cfg *config.ProxyProtocol, ctx *core.TcpContext, backend net.Conn) error { + client := ctx.Conn + sourceIP, sourcePort, err := addrToIPAndPort(client.RemoteAddr()) + if err != nil { + return err + } + + destinationIP, destinationPort, err := addrToIPAndPort(client.LocalAddr()) + if err != nil { + return err + } + + h := proxyproto.Header{ + Version: 2, + Command: proxyproto.PROXY, + SourceAddress: sourceIP, + SourcePort: sourcePort, + DestinationAddress: destinationIP, + DestinationPort: destinationPort, + } + if sourceIP.To4() != nil { + h.TransportProtocol = proxyproto.TCPv4 + } else { + h.TransportProtocol = proxyproto.TCPv6 + } + + if ctx.TlsState != nil { + // SSL TLV should be appended to header here, but go-proxyproto current + // version does not include any method that performs this operation + } + + _, err = h.WriteTo(backend) + if err != nil { + return nil + } + return nil +} diff --git a/test/proxyprotocol_test.go b/test/proxyprotocol_test.go new file mode 100644 index 00000000..0445144b --- /dev/null +++ b/test/proxyprotocol_test.go @@ -0,0 +1,119 @@ +package test + +import ( + "bytes" + "encoding/binary" + "io/ioutil" + "net" + "strconv" + "testing" + + "github.com/yyyar/gobetween/core" + "github.com/yyyar/gobetween/utils/proxyprotocol" +) + +func testSendProxyProtocol(t *testing.T, addr string, version string) (serverPort, clientPort string, received []byte) { + listener, err := net.Listen("tcp", addr+":0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + _, serverPort, err = net.SplitHostPort(listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + + go func() { + client, err := net.Dial("tcp", addr+":"+serverPort) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + _, clientPort, err = net.SplitHostPort(client.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + + ctx := &core.TcpContext{ + addr, + client, + nil, + } + + switch version { + case "1": + proxyprotocol.SendProxyProtocolV1(nil, ctx, client) + case "2": + proxyprotocol.SendProxyProtocolV2(nil, ctx, client) + default: + t.Fatalf("Unsupported proxy_protocol version " + version + ", aborting connection") + } + }() + + server, err := listener.Accept() + if err != nil { + t.Fatal(err) + } + defer server.Close() + + buf, err := ioutil.ReadAll(server) + if err != nil { + t.Fatal(err) + } + + received = []byte(buf) + + return serverPort, clientPort, received +} + +func TestSendProxyProtocolV1IPv4(t *testing.T) { + serverPort, clientPort, received := testSendProxyProtocol(t, "127.0.0.1", "1") + + expected := "PROXY TCP4 127.0.0.1 127.0.0.1 " + serverPort + " " + clientPort + "\r\n" + if string(received) != expected { + t.Fatalf("%s != %s", string(received), expected) + } +} + +func TestSendProxyProtocolV1IPv6(t *testing.T) { + serverPort, clientPort, received := testSendProxyProtocol(t, "[::1]", "1") + + expected := "PROXY TCP6 ::1 ::1 " + serverPort + " " + clientPort + "\r\n" + if string(received) != expected { + t.Fatalf("%s != %s", string(received), expected) + } +} + +func TestSendProxyProtocolV2IPv4(t *testing.T) { + serverPort, clientPort, received := testSendProxyProtocol(t, "127.0.0.1", "2") + + serverPortInt, _ := strconv.Atoi(serverPort) + clientPortInt, _ := strconv.Atoi(clientPort) + + expected := new(bytes.Buffer) + expected.Write([]byte{13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 17, 0, 12, 127, 0, 0, 1, 127, 0, 0, 1}) + binary.Write(expected, binary.BigEndian, uint16(serverPortInt)) + binary.Write(expected, binary.BigEndian, uint16(clientPortInt)) + + if bytes.Compare(received, expected.Bytes()) != 0 { + t.Fatalf("%v != %v", received, expected.Bytes()) + } +} + +func TestSendProxyProtocolV2IPv6(t *testing.T) { + serverPort, clientPort, received := testSendProxyProtocol(t, "[::1]", "2") + + serverPortInt, _ := strconv.Atoi(serverPort) + clientPortInt, _ := strconv.Atoi(clientPort) + + expected := new(bytes.Buffer) + expected.Write([]byte{13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10, 33, 33, 0, 36, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}) + binary.Write(expected, binary.BigEndian, uint16(serverPortInt)) + binary.Write(expected, binary.BigEndian, uint16(clientPortInt)) + + if bytes.Compare(received, expected.Bytes()) != 0 { + t.Fatalf("%v != %v", received, expected.Bytes()) + } +}