From 2a9c07269642773206659e8d6bbe2f14c7865f9c Mon Sep 17 00:00:00 2001 From: Tanmoy Sarkar <57363826+tanmoysrt@users.noreply.github.com> Date: Sun, 29 Sep 2024 00:47:48 +0530 Subject: [PATCH] chore: refactor ssh pool implementation (#1041) * chore: modularize the MonitorServerStatus function * feat: no need to delete ssh client if server is already offline - because ssh check function already done the job * feat: reduce delete ssh client call * feat: optimize the server status fixing algo * feat: pre-validation for ssh has been added * chore: reduce complexity (cherry picked from commit 38e6664a732064e56fcff5fdbd52387d4b07a6e7) --- ssh_toolkit/command.go | 17 ++-- ssh_toolkit/errors.go | 11 ++- ssh_toolkit/net_conn.go | 2 +- ssh_toolkit/pool.go | 20 ++++- ssh_toolkit/types.go | 10 ++- swiftwave_service/core/server.operations.go | 13 ++- .../cronjob/server_status_monitor.go | 79 ++++++++++++------- swiftwave_service/main.go | 19 ++++- 8 files changed, 121 insertions(+), 50 deletions(-) diff --git a/ssh_toolkit/command.go b/ssh_toolkit/command.go index ef49947657..0c9358a29e 100644 --- a/ssh_toolkit/command.go +++ b/ssh_toolkit/command.go @@ -14,9 +14,17 @@ import ( func ExecCommandOverSSH(cmd string, stdoutBuf, stderrBuf *bytes.Buffer, sessionTimeoutSeconds int, // for target task host string, port int, user string, privateKey string, // for ssh client +) error { + return ExecCommandOverSSHWithOptions(cmd, stdoutBuf, stderrBuf, sessionTimeoutSeconds, host, port, user, privateKey, true) +} + +func ExecCommandOverSSHWithOptions(cmd string, + stdoutBuf, stderrBuf *bytes.Buffer, sessionTimeoutSeconds int, // for target task + host string, port int, user string, privateKey string, // for ssh client + validate bool, // if true, will validate if server is online ) error { // fetch ssh client - sshRecord, err := getSSHClient(host, port, user, privateKey) + sshRecord, err := getSSHClientWithOptions(host, port, user, privateKey, validate) if err != nil { if isErrorWhenSSHClientNeedToBeRecreated(err) { DeleteSSHClient(host) @@ -49,12 +57,9 @@ func ExecCommandOverSSH(cmd string, // run command err = session.Run(cmd) if err != nil { - if isErrorWhenSSHClientNeedToBeRecreated(err) { - DeleteSSHClient(host) - } - if isErrorWhenSSHClientNeedToBeRecreated(errors.New(stderrBuf.String())) { + if isErrorWhenSSHClientNeedToBeRecreated(err) || isErrorWhenSSHClientNeedToBeRecreated(errors.New(stderrBuf.String())) { DeleteSSHClient(host) - return fmt.Errorf("%s - %s", err, stderrBuf.String()) + return fmt.Errorf("%s - %s", err.Error(), stderrBuf.String()) } return err } diff --git a/ssh_toolkit/errors.go b/ssh_toolkit/errors.go index 976e9ebf7c..3733fe992e 100644 --- a/ssh_toolkit/errors.go +++ b/ssh_toolkit/errors.go @@ -1,6 +1,8 @@ package ssh_toolkit -import "strings" +import ( + "strings" +) var errorsWhenSSHClientNeedToBeRecreated = []string{ "dial timeout", @@ -23,17 +25,20 @@ var errorsWhenSSHClientNeedToBeRecreated = []string{ "open failed", "handshake failed", "subsystem request failed", - "EOF", + "eof", "broken pipe", "closing write end of pipe", + "connection reset by peer", + "unexpected packet in response to channel open", } func isErrorWhenSSHClientNeedToBeRecreated(err error) bool { if err == nil { return false } + errMsg := strings.ToLower(err.Error()) for _, msg := range errorsWhenSSHClientNeedToBeRecreated { - if strings.Contains(err.Error(), msg) { + if strings.Contains(errMsg, msg) { return true } } diff --git a/ssh_toolkit/net_conn.go b/ssh_toolkit/net_conn.go index eddc153ef6..bc5b6ab638 100644 --- a/ssh_toolkit/net_conn.go +++ b/ssh_toolkit/net_conn.go @@ -12,7 +12,7 @@ func NetConnOverSSH( host string, port int, user string, privateKey string, // for ssh client ) (net.Conn, error) { // fetch ssh client - sshRecord, err := getSSHClient(host, port, user, privateKey) + sshRecord, err := getSSHClientWithOptions(host, port, user, privateKey, true) if err != nil { if isErrorWhenSSHClientNeedToBeRecreated(err) { DeleteSSHClient(host) diff --git a/ssh_toolkit/pool.go b/ssh_toolkit/pool.go index bb876f1aec..f6ec1ed7e3 100644 --- a/ssh_toolkit/pool.go +++ b/ssh_toolkit/pool.go @@ -1,6 +1,7 @@ package ssh_toolkit import ( + "errors" "fmt" "log" "sync" @@ -13,12 +14,23 @@ var sshClientPool *sshConnectionPool func init() { sshClientPool = &sshConnectionPool{ - clients: make(map[string]*sshClient), - mutex: &sync.RWMutex{}, + clients: make(map[string]*sshClient), + mutex: &sync.RWMutex{}, + validator: nil, } } -func getSSHClient(host string, port int, user string, privateKey string) (*ssh.Client, error) { +func SetValidator(validator ServerOnlineStatusValidator) { + sshClientPool.mutex.Lock() + defer sshClientPool.mutex.Unlock() + sshClientPool.validator = &validator +} + +func getSSHClientWithOptions(host string, port int, user string, privateKey string, validate bool) (*ssh.Client, error) { + // reject if server is offline + if validate && sshClientPool.validator != nil && !(*sshClientPool.validator)(host) { + return nil, errors.New("server is offline, cannot connect to it") + } sshClientPool.mutex.RLock() clientEntry, ok := sshClientPool.clients[host] sshClientPool.mutex.RUnlock() @@ -102,7 +114,7 @@ func DeleteSSHClient(host string) { } } clientEntry.mutex.Unlock() + delete(sshClientPool.clients, host) } - delete(sshClientPool.clients, host) sshClientPool.mutex.Unlock() } diff --git a/ssh_toolkit/types.go b/ssh_toolkit/types.go index 392ed22a38..6b059a8e1d 100644 --- a/ssh_toolkit/types.go +++ b/ssh_toolkit/types.go @@ -1,15 +1,19 @@ package ssh_toolkit import ( - "golang.org/x/crypto/ssh" "sync" + + "golang.org/x/crypto/ssh" ) type sshConnectionPool struct { - clients map[string]*sshClient // map of to sshClient - mutex *sync.RWMutex + clients map[string]*sshClient // map of to sshClient + mutex *sync.RWMutex + validator *ServerOnlineStatusValidator } +type ServerOnlineStatusValidator func(host string) bool + type sshClient struct { client *ssh.Client mutex *sync.RWMutex diff --git a/swiftwave_service/core/server.operations.go b/swiftwave_service/core/server.operations.go index 94f3602a84..f95efaecb5 100644 --- a/swiftwave_service/core/server.operations.go +++ b/swiftwave_service/core/server.operations.go @@ -3,9 +3,10 @@ package core import ( "errors" "fmt" - "gorm.io/gorm" "net" "time" + + "gorm.io/gorm" ) // CreateServer creates a new server in the database @@ -109,6 +110,16 @@ func FetchServerByID(db *gorm.DB, id uint) (*Server, error) { return &server, err } +// FetchServerByIP fetches a server by its IP from the database +func FetchServerByIP(db *gorm.DB, ip string) (*Server, error) { + var server Server + err := db.Where("ip = ?", ip).First(&server).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("server not found") + } + return &server, err +} + // FetchServerIDByHostName fetches a server by its hostname from the database func FetchServerIDByHostName(db *gorm.DB, hostName string) (uint, error) { var server Server diff --git a/swiftwave_service/cronjob/server_status_monitor.go b/swiftwave_service/cronjob/server_status_monitor.go index e3ba268ef2..c7e22773c5 100644 --- a/swiftwave_service/cronjob/server_status_monitor.go +++ b/swiftwave_service/cronjob/server_status_monitor.go @@ -2,18 +2,20 @@ package cronjob import ( "bytes" + "strings" + "sync" + "time" + "github.com/swiftwave-org/swiftwave/ssh_toolkit" "github.com/swiftwave-org/swiftwave/swiftwave_service/core" "github.com/swiftwave-org/swiftwave/swiftwave_service/logger" - "strings" - "time" ) func (m Manager) MonitorServerStatus() { logger.CronJobLogger.Println("Starting server status monitor [cronjob]") for { m.monitorServerStatus() - time.Sleep(1 * time.Minute) + time.Sleep(2 * time.Second) } } @@ -29,43 +31,64 @@ func (m Manager) monitorServerStatus() { if len(servers) == 0 { logger.CronJobLogger.Println("Skipping ! No server found") return + } + + var wg sync.WaitGroup + for _, server := range servers { + if server.Status == core.ServerNeedsSetup || server.Status == core.ServerPreparing { + continue + } + wg.Add(1) + go func(server core.Server) { + defer wg.Done() + m.checkAndUpdateServerStatus(server) + }(server) + } + wg.Wait() +} + +func (m Manager) checkAndUpdateServerStatus(server core.Server) { + if m.isServerOnline(server) { + if server.Status != core.ServerOnline { + err := core.MarkServerAsOnline(&m.ServiceManager.DbClient, &server) + if err != nil { + logger.CronJobLoggerError.Println("DB Error : Failed to mark server as online >", server.HostName, err) + } else { + logger.CronJobLogger.Println("Server marked as online >", server.HostName) + } + } } else { - for _, server := range servers { - if server.Status == core.ServerNeedsSetup || server.Status == core.ServerPreparing { - continue + if server.Status != core.ServerOffline { + err := core.MarkServerAsOffline(&m.ServiceManager.DbClient, &server) + if err != nil { + logger.CronJobLoggerError.Println("DB Error : Failed to mark server as offline >", server.HostName, err) + } else { + logger.CronJobLogger.Println("Server marked as offline >", server.HostName) } - go func(server core.Server) { - if server.Status == core.ServerOffline { - ssh_toolkit.DeleteSSHClient(server.HostName) - } - if m.isServerOnline(server) { - err = core.MarkServerAsOnline(&m.ServiceManager.DbClient, &server) - if err != nil { - logger.CronJobLoggerError.Println("DB Error : Failed to mark server as online > ", server.HostName) - } else { - logger.CronJobLogger.Println("Server marked as online > ", server.HostName) - } - } else { - err = core.MarkServerAsOffline(&m.ServiceManager.DbClient, &server) - if err != nil { - logger.CronJobLoggerError.Println("DB Error : Failed to mark server as offline > ", server.HostName) - } else { - logger.CronJobLogger.Println("Server marked as offline > ", server.HostName) - } - } - }(server) + } else { + logger.CronJobLogger.Println("Server already offline >", server.HostName) } } } func (m Manager) isServerOnline(server core.Server) bool { + retries := 3 // try for 3 times before giving up + if server.Status == core.ServerOffline { + /** + * If server is offline, try only once + * Else, it will take total 30 seconds (3 retries * 10 seconds of default SSH timeout) + */ + retries = 1 + } // try for 3 times - for i := 0; i < 3; i++ { + for i := 0; i < retries; i++ { cmd := "echo ok" stdoutBuf := new(bytes.Buffer) stderrBuf := new(bytes.Buffer) - err := ssh_toolkit.ExecCommandOverSSH(cmd, stdoutBuf, stderrBuf, 3, server.IP, server.SSHPort, server.User, m.Config.SystemConfig.SshPrivateKey) + err := ssh_toolkit.ExecCommandOverSSHWithOptions(cmd, stdoutBuf, stderrBuf, 3, server.IP, server.SSHPort, server.User, m.Config.SystemConfig.SshPrivateKey, false) if err != nil { + logger.CronJobLoggerError.Println("Error while checking if server is online", server.HostName, err.Error()) + time.Sleep(1 * time.Second) continue } if strings.Compare(strings.TrimSpace(stdoutBuf.String()), "ok") == 0 { diff --git a/swiftwave_service/main.go b/swiftwave_service/main.go index 37ccb65b64..6fe9c71e03 100644 --- a/swiftwave_service/main.go +++ b/swiftwave_service/main.go @@ -3,18 +3,20 @@ package swiftwave import ( "context" "fmt" + "log" + "net/http" + "strings" + "github.com/fatih/color" "github.com/golang-jwt/jwt/v5" echojwt "github.com/labstack/echo-jwt/v4" + "github.com/swiftwave-org/swiftwave/ssh_toolkit" "github.com/swiftwave-org/swiftwave/swiftwave_service/config" "github.com/swiftwave-org/swiftwave/swiftwave_service/console" "github.com/swiftwave-org/swiftwave/swiftwave_service/core" "github.com/swiftwave-org/swiftwave/swiftwave_service/dashboard" "github.com/swiftwave-org/swiftwave/swiftwave_service/logger" "github.com/swiftwave-org/swiftwave/swiftwave_service/service_manager" - "log" - "net/http" - "strings" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" @@ -32,6 +34,15 @@ func StartSwiftwave(config *config.Config) { } manager.Load(*config) + // Set the server status validator for ssh + ssh_toolkit.SetValidator(func(host string) bool { + server, err := core.FetchServerByIP(&manager.DbClient, host) + if err != nil { + return false + } + return server.Status != core.ServerOffline + }) + // Create pubsub default topics err := manager.PubSubClient.CreateTopic(manager.CancelImageBuildTopic) if err != nil { @@ -50,7 +61,7 @@ func StartSwiftwave(config *config.Config) { cronjobManager.Start(true) // create a channel to block the main thread - var waitForever chan struct{} + waitForever := make(chan struct{}) // StartSwiftwave the swift wave server go StartServer(config, manager, workerManager)