Skip to content

Commit

Permalink
chore: refactor ssh pool implementation (#1041) (#1042)
Browse files Browse the repository at this point in the history
* chore: modularize the MonitorServerStatus function

* feat: no need to delete the ssh client if the server is already offline
- because the ssh check function has already done that job

* feat: reduce delete ssh client call

* feat: optimize the server status fixing algo

* feat: pre-validation for ssh has been added

* chore: reduce complexity

(cherry picked from commit 38e6664)

Co-authored-by: Tanmoy Sarkar <[email protected]>
  • Loading branch information
mergify[bot] and tanmoysrt authored Sep 28, 2024
1 parent a360b95 commit ff43f86
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 50 deletions.
17 changes: 11 additions & 6 deletions ssh_toolkit/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@ import (
// ExecCommandOverSSH is the default entry point for running cmd over SSH.
// It forwards every argument unchanged to ExecCommandOverSSHWithOptions with
// the server-online validation flag switched on, and returns that call's error.
func ExecCommandOverSSH(cmd string,
	stdoutBuf, stderrBuf *bytes.Buffer, sessionTimeoutSeconds int, // for target task
	host string, port int, user string, privateKey string, // for ssh client
) error {
	// Callers of this wrapper always get the online pre-check.
	const validate = true
	return ExecCommandOverSSHWithOptions(
		cmd, stdoutBuf, stderrBuf, sessionTimeoutSeconds,
		host, port, user, privateKey,
		validate,
	)
}

func ExecCommandOverSSHWithOptions(cmd string,
stdoutBuf, stderrBuf *bytes.Buffer, sessionTimeoutSeconds int, // for target task
host string, port int, user string, privateKey string, // for ssh client
validate bool, // if true, will validate if server is online
) error {
// fetch ssh client
sshRecord, err := getSSHClient(host, port, user, privateKey)
sshRecord, err := getSSHClientWithOptions(host, port, user, privateKey, validate)
if err != nil {
if isErrorWhenSSHClientNeedToBeRecreated(err) {
DeleteSSHClient(host)
Expand Down Expand Up @@ -49,12 +57,9 @@ func ExecCommandOverSSH(cmd string,
// run command
err = session.Run(cmd)
if err != nil {
if isErrorWhenSSHClientNeedToBeRecreated(err) {
DeleteSSHClient(host)
}
if isErrorWhenSSHClientNeedToBeRecreated(errors.New(stderrBuf.String())) {
if isErrorWhenSSHClientNeedToBeRecreated(err) || isErrorWhenSSHClientNeedToBeRecreated(errors.New(stderrBuf.String())) {
DeleteSSHClient(host)
return fmt.Errorf("%s - %s", err, stderrBuf.String())
return fmt.Errorf("%s - %s", err.Error(), stderrBuf.String())
}
return err
}
Expand Down
11 changes: 8 additions & 3 deletions ssh_toolkit/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package ssh_toolkit

import "strings"
import (
"strings"
)

var errorsWhenSSHClientNeedToBeRecreated = []string{
"dial timeout",
Expand All @@ -23,17 +25,20 @@ var errorsWhenSSHClientNeedToBeRecreated = []string{
"open failed",
"handshake failed",
"subsystem request failed",
"EOF",
"eof",
"broken pipe",
"closing write end of pipe",
"connection reset by peer",
"unexpected packet in response to channel open",
}

func isErrorWhenSSHClientNeedToBeRecreated(err error) bool {
if err == nil {
return false
}
errMsg := strings.ToLower(err.Error())
for _, msg := range errorsWhenSSHClientNeedToBeRecreated {
if strings.Contains(err.Error(), msg) {
if strings.Contains(errMsg, msg) {
return true
}
}
Expand Down
2 changes: 1 addition & 1 deletion ssh_toolkit/net_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func NetConnOverSSH(
host string, port int, user string, privateKey string, // for ssh client
) (net.Conn, error) {
// fetch ssh client
sshRecord, err := getSSHClient(host, port, user, privateKey)
sshRecord, err := getSSHClientWithOptions(host, port, user, privateKey, true)
if err != nil {
if isErrorWhenSSHClientNeedToBeRecreated(err) {
DeleteSSHClient(host)
Expand Down
20 changes: 16 additions & 4 deletions ssh_toolkit/pool.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ssh_toolkit

import (
"errors"
"fmt"
"log"
"sync"
Expand All @@ -13,12 +14,23 @@ var sshClientPool *sshConnectionPool

func init() {
sshClientPool = &sshConnectionPool{
clients: make(map[string]*sshClient),
mutex: &sync.RWMutex{},
clients: make(map[string]*sshClient),
mutex: &sync.RWMutex{},
validator: nil,
}
}

func getSSHClient(host string, port int, user string, privateKey string) (*ssh.Client, error) {
// SetValidator registers the callback that the pool consults before handing
// out an SSH client when validation is requested.
//
// Passing nil clears any previously registered validator. Without this guard,
// taking the address of a nil func value would store a non-nil pointer, the
// pool's "validator != nil" check would pass, and the subsequent call through
// the pointer would panic.
func SetValidator(validator ServerOnlineStatusValidator) {
	sshClientPool.mutex.Lock()
	defer sshClientPool.mutex.Unlock()
	if validator == nil {
		sshClientPool.validator = nil
		return
	}
	sshClientPool.validator = &validator
}

func getSSHClientWithOptions(host string, port int, user string, privateKey string, validate bool) (*ssh.Client, error) {
// reject if server is offline
if validate && sshClientPool.validator != nil && !(*sshClientPool.validator)(host) {
return nil, errors.New("server is offline, cannot connect to it")
}
sshClientPool.mutex.RLock()
clientEntry, ok := sshClientPool.clients[host]
sshClientPool.mutex.RUnlock()
Expand Down Expand Up @@ -102,7 +114,7 @@ func DeleteSSHClient(host string) {
}
}
clientEntry.mutex.Unlock()
delete(sshClientPool.clients, host)
}
delete(sshClientPool.clients, host)
sshClientPool.mutex.Unlock()
}
10 changes: 7 additions & 3 deletions ssh_toolkit/types.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package ssh_toolkit

import (
"golang.org/x/crypto/ssh"
"sync"

"golang.org/x/crypto/ssh"
)

type sshConnectionPool struct {
clients map[string]*sshClient // map of <host:port> to sshClient
mutex *sync.RWMutex
clients map[string]*sshClient // map of <host:port> to sshClient
mutex *sync.RWMutex
validator *ServerOnlineStatusValidator
}

// ServerOnlineStatusValidator is a callback that receives a host and returns a
// bool. When validation is enabled, the connection pool treats a false result
// as "server is offline" and refuses to provide an SSH client for that host.
type ServerOnlineStatusValidator func(host string) bool

type sshClient struct {
client *ssh.Client
mutex *sync.RWMutex
Expand Down
13 changes: 12 additions & 1 deletion swiftwave_service/core/server.operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package core
import (
"errors"
"fmt"
"gorm.io/gorm"
"net"
"time"

"gorm.io/gorm"
)

// CreateServer creates a new server in the database
Expand Down Expand Up @@ -109,6 +110,16 @@ func FetchServerByID(db *gorm.DB, id uint) (*Server, error) {
return &server, err
}

// FetchServerByIP loads the first server row whose ip column equals ip.
// A missing row is reported as a plain "server not found" error with a nil
// server; any other query error is returned as-is alongside the server value.
func FetchServerByIP(db *gorm.DB, ip string) (*Server, error) {
	record := Server{}
	lookupErr := db.Where("ip = ?", ip).First(&record).Error
	if !errors.Is(lookupErr, gorm.ErrRecordNotFound) {
		return &record, lookupErr
	}
	return nil, errors.New("server not found")
}

// FetchServerIDByHostName fetches a server by its hostname from the database
func FetchServerIDByHostName(db *gorm.DB, hostName string) (uint, error) {
var server Server
Expand Down
79 changes: 51 additions & 28 deletions swiftwave_service/cronjob/server_status_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@ package cronjob

import (
"bytes"
"strings"
"sync"
"time"

"github.com/swiftwave-org/swiftwave/ssh_toolkit"
"github.com/swiftwave-org/swiftwave/swiftwave_service/core"
"github.com/swiftwave-org/swiftwave/swiftwave_service/logger"
"strings"
"time"
)

func (m Manager) MonitorServerStatus() {
logger.CronJobLogger.Println("Starting server status monitor [cronjob]")
for {
m.monitorServerStatus()
time.Sleep(1 * time.Minute)
time.Sleep(2 * time.Second)
}
}

Expand All @@ -29,43 +31,64 @@ func (m Manager) monitorServerStatus() {
if len(servers) == 0 {
logger.CronJobLogger.Println("Skipping ! No server found")
return
}

var wg sync.WaitGroup
for _, server := range servers {
if server.Status == core.ServerNeedsSetup || server.Status == core.ServerPreparing {
continue
}
wg.Add(1)
go func(server core.Server) {
defer wg.Done()
m.checkAndUpdateServerStatus(server)
}(server)
}
wg.Wait()
}

func (m Manager) checkAndUpdateServerStatus(server core.Server) {
if m.isServerOnline(server) {
if server.Status != core.ServerOnline {
err := core.MarkServerAsOnline(&m.ServiceManager.DbClient, &server)
if err != nil {
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as online >", server.HostName, err)
} else {
logger.CronJobLogger.Println("Server marked as online >", server.HostName)
}
}
} else {
for _, server := range servers {
if server.Status == core.ServerNeedsSetup || server.Status == core.ServerPreparing {
continue
if server.Status != core.ServerOffline {
err := core.MarkServerAsOffline(&m.ServiceManager.DbClient, &server)
if err != nil {
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as offline >", server.HostName, err)
} else {
logger.CronJobLogger.Println("Server marked as offline >", server.HostName)
}
go func(server core.Server) {
if server.Status == core.ServerOffline {
ssh_toolkit.DeleteSSHClient(server.HostName)
}
if m.isServerOnline(server) {
err = core.MarkServerAsOnline(&m.ServiceManager.DbClient, &server)
if err != nil {
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as online > ", server.HostName)
} else {
logger.CronJobLogger.Println("Server marked as online > ", server.HostName)
}
} else {
err = core.MarkServerAsOffline(&m.ServiceManager.DbClient, &server)
if err != nil {
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as offline > ", server.HostName)
} else {
logger.CronJobLogger.Println("Server marked as offline > ", server.HostName)
}
}
}(server)
} else {
logger.CronJobLogger.Println("Server already offline >", server.HostName)
}
}
}

func (m Manager) isServerOnline(server core.Server) bool {
retries := 3 // try for 3 times before giving up
if server.Status == core.ServerOffline {
/**
* If server is offline, try only once
* Else, it will take total 30 seconds (3 retries * 10 seconds of default SSH timeout)
*/
retries = 1
}
// try for 3 times
for i := 0; i < 3; i++ {
for i := 0; i < retries; i++ {
cmd := "echo ok"
stdoutBuf := new(bytes.Buffer)
stderrBuf := new(bytes.Buffer)
err := ssh_toolkit.ExecCommandOverSSH(cmd, stdoutBuf, stderrBuf, 3, server.IP, server.SSHPort, server.User, m.Config.SystemConfig.SshPrivateKey)
err := ssh_toolkit.ExecCommandOverSSHWithOptions(cmd, stdoutBuf, stderrBuf, 3, server.IP, server.SSHPort, server.User, m.Config.SystemConfig.SshPrivateKey, false)
if err != nil {
logger.CronJobLoggerError.Println("Error while checking if server is online", server.HostName, err.Error())
time.Sleep(1 * time.Second)
continue
}
if strings.Compare(strings.TrimSpace(stdoutBuf.String()), "ok") == 0 {
Expand Down
19 changes: 15 additions & 4 deletions swiftwave_service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@ package swiftwave
import (
"context"
"fmt"
"log"
"net/http"
"strings"

"github.com/fatih/color"
"github.com/golang-jwt/jwt/v5"
echojwt "github.com/labstack/echo-jwt/v4"
"github.com/swiftwave-org/swiftwave/ssh_toolkit"
"github.com/swiftwave-org/swiftwave/swiftwave_service/config"
"github.com/swiftwave-org/swiftwave/swiftwave_service/console"
"github.com/swiftwave-org/swiftwave/swiftwave_service/core"
"github.com/swiftwave-org/swiftwave/swiftwave_service/dashboard"
"github.com/swiftwave-org/swiftwave/swiftwave_service/logger"
"github.com/swiftwave-org/swiftwave/swiftwave_service/service_manager"
"log"
"net/http"
"strings"

"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
Expand All @@ -32,6 +34,15 @@ func StartSwiftwave(config *config.Config) {
}
manager.Load(*config)

// Set the server status validator for ssh
ssh_toolkit.SetValidator(func(host string) bool {
server, err := core.FetchServerByIP(&manager.DbClient, host)
if err != nil {
return false
}
return server.Status != core.ServerOffline
})

// Create pubsub default topics
err := manager.PubSubClient.CreateTopic(manager.CancelImageBuildTopic)
if err != nil {
Expand All @@ -50,7 +61,7 @@ func StartSwiftwave(config *config.Config) {
cronjobManager.Start(true)

// create a channel to block the main thread
var waitForever chan struct{}
waitForever := make(chan struct{})

// StartSwiftwave the swift wave server
go StartServer(config, manager, workerManager)
Expand Down

0 comments on commit ff43f86

Please sign in to comment.