Skip to content

Commit

Permalink
Add timeout for SSH handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
NI committed Aug 19, 2019
1 parent 3307662 commit 1070c2b
Showing 1 changed file with 84 additions and 13 deletions.
97 changes: 84 additions & 13 deletions application/commands/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"io"
"net"
"sync"
"time"

"golang.org/x/crypto/ssh"

Expand Down Expand Up @@ -101,6 +102,57 @@ var (
"Unknown client signal")
)

var (
sshEmptyTime = time.Time{}
)

type sshRemoteConnWrapper struct {
net.Conn

readTimeout time.Duration
enableTimeout bool
}

func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error {
s.enableTimeout = false

return s.Conn.SetReadDeadline(t)
}

func (s *sshRemoteConnWrapper) SetWriteDeadline(t time.Time) error {
s.enableTimeout = false

return s.Conn.SetWriteDeadline(t)
}

func (s *sshRemoteConnWrapper) SetDeadline(t time.Time) error {
s.enableTimeout = false

return s.Conn.SetDeadline(t)
}

func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) {
rLen, rErr := s.Conn.Read(b)

if rErr == nil {
return rLen, nil
}

if !s.enableTimeout {
return rLen, rErr
}

netErr, isNetErr := rErr.(net.Error)

if !isNetErr || !netErr.Timeout() {
return rLen, rErr
}

s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout))

return s.Conn.Read(b)
}

type sshRemoteConn struct {
writer io.Writer
closer func() error
Expand Down Expand Up @@ -295,20 +347,36 @@ func (d *sshClient) comfirmRemoteFingerprint(
}

func (d *sshClient) dialRemote(
network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
network,
addr string,
config *ssh.ClientConfig) (*ssh.Client, func(), error) {
conn, err := d.cfg.Dial(network, addr, config.Timeout)

if err != nil {
return nil, err
return nil, nil, err
}

c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
conn.SetReadDeadline(time.Now().Add(config.Timeout))

sshConn := sshRemoteConnWrapper{
Conn: conn,
readTimeout: config.Timeout,
enableTimeout: true,
}

c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config)

if err != nil {
return nil, err
return nil, nil, err
}

return ssh.NewClient(c, chans, reqs), nil
return ssh.NewClient(c, chans, reqs), func() {
if sshConn.enableTimeout {
sshConn.SetReadDeadline(sshEmptyTime)
}

sshConn.enableTimeout = false
}, nil
}

func (d *sshClient) remote(
Expand All @@ -322,14 +390,15 @@ func (d *sshClient) remote(

buf := [4096]byte{}

conn, dErr := d.dialRemote("tcp", address, &ssh.ClientConfig{
User: user,
Auth: authMethodBuilder(buf[:]),
HostKeyCallback: func(h string, r net.Addr, k ssh.PublicKey) error {
return d.comfirmRemoteFingerprint(h, r, k, buf[:])
},
Timeout: d.cfg.DialTimeout,
})
conn, clearConnInitialDeadline, dErr :=
d.dialRemote("tcp", address, &ssh.ClientConfig{
User: user,
Auth: authMethodBuilder(buf[:]),
HostKeyCallback: func(h string, r net.Addr, k ssh.PublicKey) error {
return d.comfirmRemoteFingerprint(h, r, k, buf[:])
},
Timeout: d.cfg.DialTimeout,
})

if dErr != nil {
errLen := copy(buf[d.w.HeaderSize():], dErr.Error()) + d.w.HeaderSize()
Expand Down Expand Up @@ -418,6 +487,8 @@ func (d *sshClient) remote(

defer session.Wait()

clearConnInitialDeadline()

d.remoteConnReceive <- sshRemoteConn{
writer: in,
closer: func() error {
Expand Down

0 comments on commit 1070c2b

Please sign in to comment.