Skip to content

Commit 029d248

Browse files
committed
add some basic test
Signed-off-by: Yang Keao <[email protected]>
1 parent 74fdcf7 commit 029d248

File tree

6 files changed

+199
-6
lines changed

6 files changed

+199
-6
lines changed

privilege/privileges/ldap/BUILD.bazel

+8-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,14 @@ go_library(
2121
go_test(
2222
name = "ldap_test",
2323
timeout = "short",
24-
srcs = ["ldap_common_test.go"],
24+
srcs = [
25+
"ldap_common_test.go",
26+
"mock_ldap_server_test.go",
27+
],
2528
embed = [":ldap"],
2629
flaky = True,
27-
deps = ["@com_github_stretchr_testify//require"],
30+
deps = [
31+
"@com_github_go_ldap_ldap_v3//:ldap",
32+
"@com_github_stretchr_testify//require",
33+
],
2834
)

privilege/privileges/ldap/ldap_common.go

+2
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) {
143143
if err != nil {
144144
return nil, err
145145
}
146+
147+
// FIXME: if the connection in the pool is killed or timeout, re-initialize the connection.
146148
return conn.(*ldap.Conn), nil
147149
}
148150

privilege/privileges/ldap/ldap_common_test.go

+62-1
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,14 @@
1515
package ldap
1616

1717
import (
18-
"github.com/stretchr/testify/require"
18+
"context"
19+
"math/rand"
20+
"net"
1921
"testing"
22+
"time"
23+
24+
"github.com/go-ldap/ldap/v3"
25+
"github.com/stretchr/testify/require"
2026
)
2127

2228
const listenPortRangeStart = 11310
@@ -29,3 +35,58 @@ func TestCanonicalizeDN(t *testing.T) {
2935
require.Equal(t, impl.canonicalizeDN("yka", "cn=y,dc=ping,dc=cap"), "cn=y,dc=ping,dc=cap")
3036
require.Equal(t, impl.canonicalizeDN("yka", "+dc=ping,dc=cap"), "cn=yka,dc=ping,dc=cap")
3137
}
38+
39+
func getConnectionsWithinTimeout(t *testing.T, ch chan net.Conn, timeout time.Duration, count int) []net.Conn {
40+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
41+
defer cancel()
42+
43+
result := []net.Conn{}
44+
for count > 0 {
45+
count--
46+
47+
select {
48+
case conn := <-ch:
49+
result = append(result, conn)
50+
case <-ctx.Done():
51+
require.Fail(t, "fail to get connections")
52+
}
53+
}
54+
55+
return result
56+
}
57+
58+
func TestLDAPConnectionPool(t *testing.T) {
59+
// allocate a random port between the port range
60+
port := rand.Int()%listenPortRangeLength + listenPortRangeStart
61+
ldapServer := NewMockLDAPServer(t)
62+
ldapServer.Listen(port)
63+
defer ldapServer.Close()
64+
conns := ldapServer.GetConnections()
65+
66+
impl := &ldapAuthImpl{ldapServerHost: "localhost", ldapServerPort: port}
67+
impl.SetInitCapacity(256)
68+
impl.SetMaxCapacity(1024)
69+
conn, err := impl.getConnection()
70+
require.NoError(t, err)
71+
impl.putConnection(conn)
72+
73+
getConnectionsWithinTimeout(t, conns, time.Second, 1)
74+
75+
// test allocating 255 more connections
76+
var clientConnections []*ldap.Conn
77+
for i := 0; i < 256; i++ {
78+
conn, err := impl.getConnection()
79+
require.NoError(t, err)
80+
81+
clientConnections = append(clientConnections, conn)
82+
}
83+
getConnectionsWithinTimeout(t, conns, time.Second, 255)
84+
for _, conn := range clientConnections {
85+
impl.putConnection(conn)
86+
}
87+
88+
clientConnections = clientConnections[:]
89+
90+
// now, the max capacity is somehow meaningless
91+
// TODO: auto scalling the capacity of LDAP connection pool
92+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Copyright 2023 PingCAP, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package ldap
16+
17+
import (
18+
"fmt"
19+
"net"
20+
"sync"
21+
"sync/atomic"
22+
"testing"
23+
24+
"github.com/stretchr/testify/require"
25+
)
26+
27+
// MockLDAPServer is a mock LDAP server to help testing the LDAP authentication
28+
type MockLDAPServer struct {
29+
t *testing.T
30+
wg sync.WaitGroup
31+
32+
listener net.Listener
33+
closed atomic.Bool
34+
35+
connections chan net.Conn
36+
}
37+
38+
// NewMockLDAPServer creates the MockLDAPServer
39+
func NewMockLDAPServer(t *testing.T) *MockLDAPServer {
40+
return &MockLDAPServer{
41+
t: t,
42+
connections: make(chan net.Conn),
43+
}
44+
}
45+
46+
// Listen listens on the specific port
47+
func (s *MockLDAPServer) Listen(port int) {
48+
l, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
49+
require.NoError(s.t, err)
50+
51+
s.wg.Add(1)
52+
go func() {
53+
for {
54+
conn, err := l.Accept()
55+
if err != nil {
56+
if s.closed.Load() {
57+
break
58+
}
59+
60+
require.NoError(s.t, err)
61+
}
62+
go func() {
63+
s.connections <- conn
64+
}()
65+
}
66+
s.wg.Done()
67+
}()
68+
69+
s.listener = l
70+
return
71+
}
72+
73+
// Close closes the listener
74+
func (s *MockLDAPServer) Close() {
75+
s.closed.Store(true)
76+
err := s.listener.Close()
77+
require.NoError(s.t, err)
78+
79+
s.wg.Wait()
80+
}
81+
82+
// GetConnections returns all live connections
83+
func (s *MockLDAPServer) GetConnections() chan net.Conn {
84+
return s.connections
85+
}

server/conn.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ func (cc *clientConn) authSwitchRequest(ctx context.Context, plugin string) ([]b
203203
clientPlugin = mysql.AuthMySQLClearPassword
204204
}
205205
failpoint.Inject("FakeAuthSwitch", func() {
206-
failpoint.Return([]byte(plugin), nil)
206+
failpoint.Return([]byte(clientPlugin), nil)
207207
})
208208
enclen := 1 + len(clientPlugin) + 1 + len(cc.salt) + 1
209209
data := cc.alloc.AllocWithLen(4, enclen)

server/conn_test.go

+41-2
Original file line numberDiff line numberDiff line change
@@ -1630,13 +1630,13 @@ func TestAuthSessionTokenPlugin(t *testing.T) {
16301630
// create a token without TLS
16311631
tk1 := testkit.NewTestKitWithSession(t, store, tc.Session)
16321632
tc.Session.GetSessionVars().ConnectionInfo = cc.connectInfo()
1633-
tk1.Session().Auth(&auth.UserIdentity{Username: "auth_session_token", Hostname: "localhost"}, nil, nil)
1633+
tk1.Session().Auth(&auth.UserIdentity{Username: "auth_session_token", Hostname: "localhost"}, nil, nil, nil)
16341634
tk1.MustQuery("show session_states")
16351635

16361636
// create a token with TLS
16371637
cc.tlsConn = &tls.Conn{}
16381638
tc.Session.GetSessionVars().ConnectionInfo = cc.connectInfo()
1639-
tk1.Session().Auth(&auth.UserIdentity{Username: "auth_session_token", Hostname: "localhost"}, nil, nil)
1639+
tk1.Session().Auth(&auth.UserIdentity{Username: "auth_session_token", Hostname: "localhost"}, nil, nil, nil)
16401640
tk1.MustQuery("show session_states")
16411641

16421642
// create a token with UnixSocket
@@ -1986,3 +1986,42 @@ func TestProcessInfoForExecuteCommand(t *testing.T) {
19861986
0x0A, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}))
19871987
require.Equal(t, cc.ctx.Session.ShowProcess().Info, "select sum(col1) from t where col1 < ? and col1 > 100")
19881988
}
1989+
1990+
func TestLDAPAuthSwitch(t *testing.T) {
1991+
store := testkit.CreateMockStore(t)
1992+
cfg := newTestConfig()
1993+
cfg.Port = 0
1994+
cfg.Status.StatusPort = 0
1995+
drv := NewTiDBDriver(store)
1996+
srv, err := NewServer(cfg, drv)
1997+
require.NoError(t, err)
1998+
tk := testkit.NewTestKit(t, store)
1999+
tk.MustExec("CREATE USER test_simple_ldap IDENTIFIED WITH authentication_ldap_simple AS 'uid=test_simple_ldap,dc=example,dc=com'")
2000+
2001+
cc := &clientConn{
2002+
connectionID: 1,
2003+
alloc: arena.NewAllocator(1024),
2004+
chunkAlloc: chunk.NewAllocator(),
2005+
pkt: &packetIO{
2006+
bufWriter: bufio.NewWriter(bytes.NewBuffer(nil)),
2007+
},
2008+
server: srv,
2009+
user: "test_simple_ldap",
2010+
}
2011+
se, _ := session.CreateSession4Test(store)
2012+
tc := &TiDBContext{
2013+
Session: se,
2014+
stmts: make(map[int]*TiDBStatement),
2015+
}
2016+
cc.setCtx(tc)
2017+
cc.isUnixSocket = true
2018+
2019+
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/server/FakeAuthSwitch", "return(1)"))
2020+
respAuthSwitch, err := cc.checkAuthPlugin(context.Background(), &handshakeResponse41{
2021+
Capability: mysql.ClientProtocol41 | mysql.ClientPluginAuth,
2022+
User: "test_simple_ldap",
2023+
})
2024+
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/server/FakeAuthSwitch"))
2025+
require.NoError(t, err)
2026+
require.Equal(t, []byte(mysql.AuthMySQLClearPassword), respAuthSwitch)
2027+
}

0 commit comments

Comments
 (0)