diff --git a/privilege/privileges/ldap/BUILD.bazel b/privilege/privileges/ldap/BUILD.bazel index a3aef02c23f25..a807291caf073 100644 --- a/privilege/privileges/ldap/BUILD.bazel +++ b/privilege/privileges/ldap/BUILD.bazel @@ -21,14 +21,8 @@ go_library( go_test( name = "ldap_test", timeout = "short", - srcs = [ - "ldap_common_test.go", - "mock_ldap_server_test.go", - ], + srcs = ["ldap_common_test.go"], embed = [":ldap"], flaky = True, - deps = [ - "@com_github_go_ldap_ldap_v3//:ldap", - "@com_github_stretchr_testify//require", - ], + deps = ["@com_github_stretchr_testify//require"], ) diff --git a/privilege/privileges/ldap/ldap_common.go b/privilege/privileges/ldap/ldap_common.go index 3a61447950cfa..36e06e2d5ac68 100644 --- a/privilege/privileges/ldap/ldap_common.go +++ b/privilege/privileges/ldap/ldap_common.go @@ -50,8 +50,10 @@ type ldapAuthImpl struct { ldapConnectionPool *pools.ResourcePool } -func (impl *ldapAuthImpl) searchUser(userName string) (string, error) { - l, err := impl.getConnection() +func (impl *ldapAuthImpl) searchUser(userName string) (dn string, err error) { + var l *ldap.Conn + + l, err = impl.getConnection() if err != nil { return "", err } @@ -61,6 +63,12 @@ func (impl *ldapAuthImpl) searchUser(userName string) (string, error) { if err != nil { return "", errors.Wrap(err, "bind root dn to search user") } + defer func() { + // bind to anonymous user + _, err = l.SimpleBind(&ldap.SimpleBindRequest{ + AllowEmptyPassword: true, + }) + }() result, err := l.Search(&ldap.SearchRequest{ BaseDN: impl.bindBaseDN, @@ -68,15 +76,15 @@ func (impl *ldapAuthImpl) searchUser(userName string) (string, error) { Filter: fmt.Sprintf("(%s=%s)", impl.searchAttr, userName), }) if err != nil { - return "", err + return } if len(result.Entries) == 0 { return "", errors.New("LDAP user not found") } - entry := result.Entries[0] - return entry.DN, nil + dn = result.Entries[0].DN + return } // canonicalizeDN turns the `dn` provided in database to the `dn` recognized by LDAP server @@ -138,14 +146,38 @@ func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) { return ldapConnection, nil } +const getConnectionMaxRetry = 10 + func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) { - conn, err := impl.ldapConnectionPool.Get() - if err != nil { - return nil, err - } + retryCount := 0 + for { + conn, err := impl.ldapConnectionPool.Get() + if err != nil { + return nil, err + } + + // try to bind anonymous user. It has two meanings: + // 1. Clear the state of previous binding, to avoid security leaks. (Though it's not serious, because even the current + // connection has binded to other users, the following authentication will still fail. But the ACL for anonymous + // user and a valid user could be different, so it's better to bind back to anonymous user here. + // 2. Detect whether this connection is still valid to use, in case the server has closed this connection. + ldapConnection := conn.(*ldap.Conn) + _, err = ldapConnection.SimpleBind(&ldap.SimpleBindRequest{ + AllowEmptyPassword: true, + }) + if err != nil { + // fail to bind to anonymous user, just release this connection and try to get a new one + impl.ldapConnectionPool.Put(nil) + + retryCount++ + if retryCount >= getConnectionMaxRetry { + return nil, errors.Wrap(err, "fail to bind to anonymous user") + } + continue + } - // FIXME: if the connection in the pool is killed or timeout, re-initialize the connection. - return conn.(*ldap.Conn), nil + return conn.(*ldap.Conn), nil + } } func (impl *ldapAuthImpl) putConnection(conn *ldap.Conn) { diff --git a/privilege/privileges/ldap/ldap_common_test.go b/privilege/privileges/ldap/ldap_common_test.go index 4d0d4dd2ccdbc..6229e9010aaa5 100644 --- a/privilege/privileges/ldap/ldap_common_test.go +++ b/privilege/privileges/ldap/ldap_common_test.go @@ -15,19 +15,11 @@ package ldap import ( - "context" - "math/rand" - "net" "testing" - "time" - "github.com/go-ldap/ldap/v3" "github.com/stretchr/testify/require" ) -const listenPortRangeStart = 11310 -const listenPortRangeLength = 256 - func TestCanonicalizeDN(t *testing.T) { impl := &ldapAuthImpl{ searchAttr: "cn", @@ -35,58 +27,3 @@ func TestCanonicalizeDN(t *testing.T) { require.Equal(t, impl.canonicalizeDN("yka", "cn=y,dc=ping,dc=cap"), "cn=y,dc=ping,dc=cap") require.Equal(t, impl.canonicalizeDN("yka", "+dc=ping,dc=cap"), "cn=yka,dc=ping,dc=cap") } - -func getConnectionsWithinTimeout(t *testing.T, ch chan net.Conn, timeout time.Duration, count int) []net.Conn { - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - result := []net.Conn{} - for count > 0 { - count-- - - select { - case conn := <-ch: - result = append(result, conn) - case <-ctx.Done(): - require.Fail(t, "fail to get connections") - } - } - - return result -} - -func TestLDAPConnectionPool(t *testing.T) { - // allocate a random port between the port range - port := rand.Int()%listenPortRangeLength + listenPortRangeStart - ldapServer := NewMockLDAPServer(t) - ldapServer.Listen(port) - defer ldapServer.Close() - conns := ldapServer.GetConnections() - - impl := &ldapAuthImpl{ldapServerHost: "localhost", ldapServerPort: port} - impl.SetInitCapacity(256) - impl.SetMaxCapacity(1024) - conn, err := impl.getConnection() - require.NoError(t, err) - impl.putConnection(conn) - - getConnectionsWithinTimeout(t, conns, time.Second, 1) - - // test allocating 255 more connections - var clientConnections []*ldap.Conn - for i := 0; i < 256; i++ { - conn, err := impl.getConnection() - require.NoError(t, err) - - clientConnections = append(clientConnections, conn) - } - getConnectionsWithinTimeout(t, conns, time.Second, 255) - for _, conn := range clientConnections { - impl.putConnection(conn) - } - - clientConnections = clientConnections[:] - - // now, the max capacity is somehow meaningless - // TODO: auto scalling the capacity of LDAP connection pool -} diff --git a/privilege/privileges/ldap/mock_ldap_server_test.go b/privilege/privileges/ldap/mock_ldap_server_test.go deleted file mode 100644 index dc65b5b63ee80..0000000000000 --- a/privilege/privileges/ldap/mock_ldap_server_test.go +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright 2023 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package ldap - -import ( - "fmt" - "net" - "sync" - "sync/atomic" - "testing" - - "github.com/stretchr/testify/require" -) - -// MockLDAPServer is a mock LDAP server to help testing the LDAP authentication -type MockLDAPServer struct { - t *testing.T - wg sync.WaitGroup - - listener net.Listener - closed atomic.Bool - - connections chan net.Conn -} - -// NewMockLDAPServer creates the MockLDAPServer -func NewMockLDAPServer(t *testing.T) *MockLDAPServer { - return &MockLDAPServer{ - t: t, - connections: make(chan net.Conn), - } -} - -// Listen listens on the specific port -func (s *MockLDAPServer) Listen(port int) { - l, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) - require.NoError(s.t, err) - - s.wg.Add(1) - go func() { - for { - conn, err := l.Accept() - if err != nil { - if s.closed.Load() { - break - } - - require.NoError(s.t, err) - } - go func() { - s.connections <- conn - }() - } - s.wg.Done() - }() - - s.listener = l -} - -// Close closes the listener -func (s *MockLDAPServer) Close() { - s.closed.Store(true) - err := s.listener.Close() - require.NoError(s.t, err) - - s.wg.Wait() -} - -// GetConnections returns all live connections -func (s *MockLDAPServer) GetConnections() chan net.Conn { - return s.connections -}