forked from c653labs/pggateway
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlistener.go
158 lines (133 loc) · 3.25 KB
/
listener.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
package pggateway
import (
	"crypto/tls"
	"io"
	"net"
	"sync"

	"github.com/c653labs/pgproto"
)
// Listener accepts PostgreSQL client connections on a TCP address and serves
// each accepted connection on its own goroutine.
type Listener struct {
	l       net.Listener
	config  *ListenerConfig
	plugins *PluginRegistry

	// mu guards stopping. Handle blocks in Accept and reads stopping when
	// Accept fails, so Close has to be called from another goroutine to stop
	// it; without the lock that write/read pair is a data race.
	mu       sync.Mutex
	stopping bool
}

// NewListener returns a Listener for config. Call Listen to bind the socket
// and then Handle to start accepting connections.
func NewListener(config *ListenerConfig) *Listener {
	return &Listener{config: config}
}

// setStopping records whether the listener is being shut down.
func (l *Listener) setStopping(v bool) {
	l.mu.Lock()
	l.stopping = v
	l.mu.Unlock()
}

// isStopping reports whether Close has been called since the last Listen.
func (l *Listener) isStopping() bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	return l.stopping
}

// Listen builds the plugin registry from the configured authentication and
// logging settings and binds a TCP listener on the configured bind address.
// It clears any stopping state left over from a previous Close.
func (l *Listener) Listen() error {
	l.setStopping(false)

	var err error
	l.plugins, err = NewPluginRegistry(l.config.Authentication, l.config.Logging)
	if err != nil {
		return err
	}

	l.l, err = net.Listen("tcp", l.config.Bind)
	if err != nil {
		return err
	}
	return nil
}

// Close marks the listener as stopping and closes the underlying network
// listener, which makes a blocked Handle return nil. It returns the error from
// closing the network listener (for example on a second Close); closing a
// Listener whose socket was never opened is a no-op.
func (l *Listener) Close() error {
	l.setStopping(true)
	if l.l == nil {
		return nil
	}
	return l.l.Close()
}

// Handle accepts client connections until the listener is closed, serving each
// connection on its own goroutine. Accept timeouts are retried. It returns nil
// once Close has been called, otherwise the (logged) Accept error.
func (l *Listener) Handle() error {
	for {
		conn, err := l.l.Accept()
		if err != nil {
			if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() {
				continue
			}
			// An Accept error after Close is the expected shutdown path.
			if l.isStopping() {
				return nil
			}
			l.plugins.LogError(nil, "error accepting client: %s", err)
			return err
		}

		go func(conn net.Conn) {
			defer conn.Close()
			err := l.handleClient(conn)
			// io.EOF is a normal client disconnect, not worth logging.
			if err != nil && err != io.EOF {
				l.plugins.LogError(nil, "error handling client session: %s", err)
			}
		}(conn)
	}
}
// handleClient runs the PostgreSQL start-up phase for one accepted client
// connection. If that succeeds, it hands the connection to a new Session for
// the rest of its life.
//
// Start-up handling:
//   - An SSLRequest is answered by upgrading to TLS when SSL is enabled, or
//     with 'N' otherwise. Either way the client then sends its real
//     StartupMessage on the same connection (over TLS or in plaintext), which
//     is parsed next. Previously the 'N' path closed the connection, which
//     locked out clients that merely prefer SSL.
//   - If the configuration requires SSL but the connection is not using TLS,
//     an error is written to the client and returned.
//   - The "user" and "database" startup options are mandatory.
//
// The returned error is whatever ended the start-up phase or the session;
// io.EOF is passed through unchanged so the caller can treat it as a normal
// disconnect.
func (l *Listener) handleClient(client net.Conn) error {
	startup, err := pgproto.ParseStartupMessage(client)
	if err != nil {
		return err
	}

	isSSL := false
	if startup.SSLRequest {
		if l.config.SSL.Enabled {
			client, err = l.upgradeSSLConnection(client)
			if err != nil {
				return err
			}
			isSSL = true
		} else {
			// Decline SSL but keep the connection open: the client is allowed
			// to continue in plaintext on this same socket.
			if _, err = client.Write([]byte{'N'}); err != nil {
				return err
			}
		}

		// After SSL negotiation the client sends its actual StartupMessage.
		startup, err = pgproto.ParseStartupMessage(client)
		if err != nil {
			return err
		}
	}

	if !isSSL && l.config.SSL.Required {
		return RetunErrorfAndWritePGMsg(client, "SSL is required, but the connection is not using SSL")
	}

	user, ok := startup.Options["user"]
	if !ok {
		// No username was provided
		return RetunErrorfAndWritePGMsg(client, "user startup option is required")
	}
	database, ok := startup.Options["database"]
	if !ok {
		// No database was provided
		return RetunErrorfAndWritePGMsg(client, "database startup option is required")
	}

	sess, err := NewSession(startup, user, database, isSSL, client, nil, l.plugins)
	if err != nil {
		l.plugins.LogError(nil, "error creating new client session: %s", err)
		client.Close()
		return err
	}
	defer sess.Close()

	l.plugins.LogInfo(sess.loggingContext(), "new client session")
	err = sess.Handle()
	if err != nil && err != io.EOF {
		l.plugins.LogError(sess.loggingContext(), "client session end: %s", err)
	} else {
		l.plugins.LogInfo(sess.loggingContext(), "client session end")
	}
	return err
}
// upgradeSSLConnection accepts a client's SSLRequest and performs the TLS
// server handshake over client. It uses the certificate and key files named
// in the listener's SSL configuration.
//
// The key pair is loaded before the 'S' acknowledgement is written. That way a
// broken certificate configuration fails without first telling the client to
// start a TLS handshake that can never complete. On any error the returned
// net.Conn is nil.
//
// NOTE(review): the key pair is re-read from disk on every connection; that
// allows certificate rotation without a restart, but costs file I/O per client.
func (l *Listener) upgradeSSLConnection(client net.Conn) (net.Conn, error) {
	cert, err := tls.LoadX509KeyPair(l.config.SSL.Certificate, l.config.SSL.Key)
	if err != nil {
		return nil, err
	}

	// Tell the client SSL is accepted; it begins the TLS handshake next.
	if _, err := client.Write([]byte{'S'}); err != nil {
		return nil, err
	}

	// Upgrade the client connection to a TLS connection
	sslClient := tls.Server(client, &tls.Config{
		Certificates: []tls.Certificate{cert},
	})
	if err := sslClient.Handshake(); err != nil {
		return nil, err
	}
	return sslClient, nil
}
// String implements fmt.Stringer by reporting the bind address this listener
// was configured with.
func (l *Listener) String() string {
	cfg := l.config
	return cfg.Bind
}