Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: refactor ssh pool implementation (backport #1041) #1042

Merged
merged 1 commit into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions ssh_toolkit/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,17 @@ import (
// ExecCommandOverSSH runs cmd over SSH by delegating to
// ExecCommandOverSSHWithOptions with validation switched on.
//
// stdoutBuf, stderrBuf and sessionTimeoutSeconds describe the target task;
// host, port, user and privateKey identify the SSH client to use.
func ExecCommandOverSSH(
	cmd string,
	stdoutBuf, stderrBuf *bytes.Buffer,
	sessionTimeoutSeconds int,
	host string,
	port int,
	user string,
	privateKey string,
) error {
	// This entry point always requests the server-online validation step.
	const validate = true
	return ExecCommandOverSSHWithOptions(
		cmd,
		stdoutBuf, stderrBuf, sessionTimeoutSeconds,
		host, port, user, privateKey,
		validate,
	)
}

func ExecCommandOverSSHWithOptions(cmd string,
stdoutBuf, stderrBuf *bytes.Buffer, sessionTimeoutSeconds int, // for target task
host string, port int, user string, privateKey string, // for ssh client
validate bool, // if true, will validate if server is online
) error {
// fetch ssh client
sshRecord, err := getSSHClient(host, port, user, privateKey)
sshRecord, err := getSSHClientWithOptions(host, port, user, privateKey, validate)
if err != nil {
if isErrorWhenSSHClientNeedToBeRecreated(err) {
DeleteSSHClient(host)
Expand Down Expand Up @@ -49,12 +57,9 @@ func ExecCommandOverSSH(cmd string,
// run command
err = session.Run(cmd)
if err != nil {
if isErrorWhenSSHClientNeedToBeRecreated(err) {
DeleteSSHClient(host)
}
if isErrorWhenSSHClientNeedToBeRecreated(errors.New(stderrBuf.String())) {
if isErrorWhenSSHClientNeedToBeRecreated(err) || isErrorWhenSSHClientNeedToBeRecreated(errors.New(stderrBuf.String())) {
DeleteSSHClient(host)
return fmt.Errorf("%s - %s", err, stderrBuf.String())
return fmt.Errorf("%s - %s", err.Error(), stderrBuf.String())
}
return err
}
Expand Down
11 changes: 8 additions & 3 deletions ssh_toolkit/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package ssh_toolkit

import "strings"
import (
"strings"
)

var errorsWhenSSHClientNeedToBeRecreated = []string{
"dial timeout",
Expand All @@ -23,17 +25,20 @@ var errorsWhenSSHClientNeedToBeRecreated = []string{
"open failed",
"handshake failed",
"subsystem request failed",
"EOF",
"eof",
"broken pipe",
"closing write end of pipe",
"connection reset by peer",
"unexpected packet in response to channel open",
}

func isErrorWhenSSHClientNeedToBeRecreated(err error) bool {
if err == nil {
return false
}
errMsg := strings.ToLower(err.Error())
for _, msg := range errorsWhenSSHClientNeedToBeRecreated {
if strings.Contains(err.Error(), msg) {
if strings.Contains(errMsg, msg) {
return true
}
}
Expand Down
2 changes: 1 addition & 1 deletion ssh_toolkit/net_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func NetConnOverSSH(
host string, port int, user string, privateKey string, // for ssh client
) (net.Conn, error) {
// fetch ssh client
sshRecord, err := getSSHClient(host, port, user, privateKey)
sshRecord, err := getSSHClientWithOptions(host, port, user, privateKey, true)
if err != nil {
if isErrorWhenSSHClientNeedToBeRecreated(err) {
DeleteSSHClient(host)
Expand Down
20 changes: 16 additions & 4 deletions ssh_toolkit/pool.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ssh_toolkit

import (
"errors"
"fmt"
"log"
"sync"
Expand All @@ -13,12 +14,23 @@ var sshClientPool *sshConnectionPool

func init() {
sshClientPool = &sshConnectionPool{
clients: make(map[string]*sshClient),
mutex: &sync.RWMutex{},
clients: make(map[string]*sshClient),
mutex: &sync.RWMutex{},
validator: nil,
}
}

func getSSHClient(host string, port int, user string, privateKey string) (*ssh.Client, error) {
// SetValidator installs validator as the pool-wide server online check.
// The pool's write lock is held while the pointer to it is stored.
func SetValidator(validator ServerOnlineStatusValidator) {
	pool := sshClientPool
	pool.mutex.Lock()
	defer pool.mutex.Unlock()
	pool.validator = &validator
}

func getSSHClientWithOptions(host string, port int, user string, privateKey string, validate bool) (*ssh.Client, error) {
// reject if server is offline
if validate && sshClientPool.validator != nil && !(*sshClientPool.validator)(host) {
return nil, errors.New("server is offline, cannot connect to it")
}
sshClientPool.mutex.RLock()
clientEntry, ok := sshClientPool.clients[host]
sshClientPool.mutex.RUnlock()
Expand Down Expand Up @@ -102,7 +114,7 @@ func DeleteSSHClient(host string) {
}
}
clientEntry.mutex.Unlock()
delete(sshClientPool.clients, host)
}
delete(sshClientPool.clients, host)
sshClientPool.mutex.Unlock()
}
10 changes: 7 additions & 3 deletions ssh_toolkit/types.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package ssh_toolkit

import (
"golang.org/x/crypto/ssh"
"sync"

"golang.org/x/crypto/ssh"
)

type sshConnectionPool struct {
clients map[string]*sshClient // map of <host:port> to sshClient
mutex *sync.RWMutex
clients map[string]*sshClient // map of <host:port> to sshClient
mutex *sync.RWMutex
validator *ServerOnlineStatusValidator
}

// ServerOnlineStatusValidator is a callback that receives a host and reports
// a bool. When validation is requested, the SSH pool refuses to hand out a
// client for a host for which the installed validator returns false.
type ServerOnlineStatusValidator func(host string) bool

type sshClient struct {
client *ssh.Client
mutex *sync.RWMutex
Expand Down
13 changes: 12 additions & 1 deletion swiftwave_service/core/server.operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package core
import (
"errors"
"fmt"
"gorm.io/gorm"
"net"
"time"

"gorm.io/gorm"
)

// CreateServer creates a new server in the database
Expand Down Expand Up @@ -109,6 +110,16 @@ func FetchServerByID(db *gorm.DB, id uint) (*Server, error) {
return &server, err
}

// FetchServerByIP fetches the first server whose ip column equals ip.
//
// It returns an error with the message "server not found" when no row
// matches, and the underlying query error for any other failure. The returned
// *Server is nil whenever the error is non-nil.
func FetchServerByIP(db *gorm.DB, ip string) (*Server, error) {
	var server Server
	if err := db.Where("ip = ?", ip).First(&server).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, errors.New("server not found")
		}
		// Never hand back a zero-valued Server alongside a real error.
		return nil, err
	}
	return &server, nil
}

// FetchServerIDByHostName fetches a server by its hostname from the database
func FetchServerIDByHostName(db *gorm.DB, hostName string) (uint, error) {
var server Server
Expand Down
79 changes: 51 additions & 28 deletions swiftwave_service/cronjob/server_status_monitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@ package cronjob

import (
"bytes"
"strings"
"sync"
"time"

"github.com/swiftwave-org/swiftwave/ssh_toolkit"
"github.com/swiftwave-org/swiftwave/swiftwave_service/core"
"github.com/swiftwave-org/swiftwave/swiftwave_service/logger"
"strings"
"time"
)

func (m Manager) MonitorServerStatus() {
logger.CronJobLogger.Println("Starting server status monitor [cronjob]")
for {
m.monitorServerStatus()
time.Sleep(1 * time.Minute)
time.Sleep(2 * time.Second)
}
}

Expand All @@ -29,43 +31,64 @@ func (m Manager) monitorServerStatus() {
if len(servers) == 0 {
logger.CronJobLogger.Println("Skipping ! No server found")
return
}

var wg sync.WaitGroup
for _, server := range servers {
if server.Status == core.ServerNeedsSetup || server.Status == core.ServerPreparing {
continue
}
wg.Add(1)
go func(server core.Server) {
defer wg.Done()
m.checkAndUpdateServerStatus(server)
}(server)
}
wg.Wait()
}

func (m Manager) checkAndUpdateServerStatus(server core.Server) {
if m.isServerOnline(server) {
if server.Status != core.ServerOnline {
err := core.MarkServerAsOnline(&m.ServiceManager.DbClient, &server)
if err != nil {
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as online >", server.HostName, err)
} else {
logger.CronJobLogger.Println("Server marked as online >", server.HostName)
}
}
} else {
for _, server := range servers {
if server.Status == core.ServerNeedsSetup || server.Status == core.ServerPreparing {
continue
if server.Status != core.ServerOffline {
err := core.MarkServerAsOffline(&m.ServiceManager.DbClient, &server)
if err != nil {
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as offline >", server.HostName, err)
} else {
logger.CronJobLogger.Println("Server marked as offline >", server.HostName)
}
go func(server core.Server) {
if server.Status == core.ServerOffline {
ssh_toolkit.DeleteSSHClient(server.HostName)
}
if m.isServerOnline(server) {
err = core.MarkServerAsOnline(&m.ServiceManager.DbClient, &server)
if err != nil {
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as online > ", server.HostName)
} else {
logger.CronJobLogger.Println("Server marked as online > ", server.HostName)
}
} else {
err = core.MarkServerAsOffline(&m.ServiceManager.DbClient, &server)
if err != nil {
logger.CronJobLoggerError.Println("DB Error : Failed to mark server as offline > ", server.HostName)
} else {
logger.CronJobLogger.Println("Server marked as offline > ", server.HostName)
}
}
}(server)
} else {
logger.CronJobLogger.Println("Server already offline >", server.HostName)
}
}
}

func (m Manager) isServerOnline(server core.Server) bool {
retries := 3 // try for 3 times before giving up
if server.Status == core.ServerOffline {
/**
* If server is offline, try only once
* Else, it will take total 30 seconds (3 retries * 10 seconds of default SSH timeout)
*/
retries = 1
}
// try for 3 times
for i := 0; i < 3; i++ {
for i := 0; i < retries; i++ {
cmd := "echo ok"
stdoutBuf := new(bytes.Buffer)
stderrBuf := new(bytes.Buffer)
err := ssh_toolkit.ExecCommandOverSSH(cmd, stdoutBuf, stderrBuf, 3, server.IP, server.SSHPort, server.User, m.Config.SystemConfig.SshPrivateKey)
err := ssh_toolkit.ExecCommandOverSSHWithOptions(cmd, stdoutBuf, stderrBuf, 3, server.IP, server.SSHPort, server.User, m.Config.SystemConfig.SshPrivateKey, false)
if err != nil {
logger.CronJobLoggerError.Println("Error while checking if server is online", server.HostName, err.Error())
time.Sleep(1 * time.Second)
continue
}
if strings.Compare(strings.TrimSpace(stdoutBuf.String()), "ok") == 0 {
Expand Down
19 changes: 15 additions & 4 deletions swiftwave_service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@ package swiftwave
import (
"context"
"fmt"
"log"
"net/http"
"strings"

"github.com/fatih/color"
"github.com/golang-jwt/jwt/v5"
echojwt "github.com/labstack/echo-jwt/v4"
"github.com/swiftwave-org/swiftwave/ssh_toolkit"
"github.com/swiftwave-org/swiftwave/swiftwave_service/config"
"github.com/swiftwave-org/swiftwave/swiftwave_service/console"
"github.com/swiftwave-org/swiftwave/swiftwave_service/core"
"github.com/swiftwave-org/swiftwave/swiftwave_service/dashboard"
"github.com/swiftwave-org/swiftwave/swiftwave_service/logger"
"github.com/swiftwave-org/swiftwave/swiftwave_service/service_manager"
"log"
"net/http"
"strings"

"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
Expand All @@ -32,6 +34,15 @@ func StartSwiftwave(config *config.Config) {
}
manager.Load(*config)

// Set the server status validator for ssh
ssh_toolkit.SetValidator(func(host string) bool {
server, err := core.FetchServerByIP(&manager.DbClient, host)
if err != nil {
return false
}
return server.Status != core.ServerOffline
})

// Create pubsub default topics
err := manager.PubSubClient.CreateTopic(manager.CancelImageBuildTopic)
if err != nil {
Expand All @@ -50,7 +61,7 @@ func StartSwiftwave(config *config.Config) {
cronjobManager.Start(true)

// create a channel to block the main thread
var waitForever chan struct{}
waitForever := make(chan struct{})

// StartSwiftwave the swift wave server
go StartServer(config, manager, workerManager)
Expand Down
Loading