From 1070c2bcf2e255d104b348c0244aa099921fb30b Mon Sep 17 00:00:00 2001 From: NI Date: Mon, 19 Aug 2019 21:29:47 +0800 Subject: [PATCH] Add timeout for SSH handshake --- application/commands/ssh.go | 97 ++++++++++++++++++++++++++++++++----- 1 file changed, 84 insertions(+), 13 deletions(-) diff --git a/application/commands/ssh.go b/application/commands/ssh.go index 14ed4717..e482f8ea 100644 --- a/application/commands/ssh.go +++ b/application/commands/ssh.go @@ -22,6 +22,7 @@ import ( "io" "net" "sync" + "time" "golang.org/x/crypto/ssh" @@ -101,6 +102,57 @@ var ( "Unknown client signal") ) +var ( + sshEmptyTime = time.Time{} +) + +type sshRemoteConnWrapper struct { + net.Conn + + readTimeout time.Duration + enableTimeout bool +} + +func (s *sshRemoteConnWrapper) SetReadDeadline(t time.Time) error { + s.enableTimeout = false + + return s.Conn.SetReadDeadline(t) +} + +func (s *sshRemoteConnWrapper) SetWriteDeadline(t time.Time) error { + s.enableTimeout = false + + return s.Conn.SetWriteDeadline(t) +} + +func (s *sshRemoteConnWrapper) SetDeadline(t time.Time) error { + s.enableTimeout = false + + return s.Conn.SetDeadline(t) +} + +func (s *sshRemoteConnWrapper) Read(b []byte) (int, error) { + rLen, rErr := s.Conn.Read(b) + + if rErr == nil { + return rLen, nil + } + + if !s.enableTimeout { + return rLen, rErr + } + + netErr, isNetErr := rErr.(net.Error) + + if !isNetErr || !netErr.Timeout() { + return rLen, rErr + } + + s.Conn.SetReadDeadline(time.Now().Add(s.readTimeout)) + + return s.Conn.Read(b) +} + type sshRemoteConn struct { writer io.Writer closer func() error @@ -295,20 +347,36 @@ func (d *sshClient) comfirmRemoteFingerprint( } func (d *sshClient) dialRemote( - network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { + network, + addr string, + config *ssh.ClientConfig) (*ssh.Client, func(), error) { conn, err := d.cfg.Dial(network, addr, config.Timeout) if err != nil { - return nil, err + return nil, nil, err } - c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + conn.SetReadDeadline(time.Now().Add(config.Timeout)) + + sshConn := sshRemoteConnWrapper{ + Conn: conn, + readTimeout: config.Timeout, + enableTimeout: true, + } + + c, chans, reqs, err := ssh.NewClientConn(&sshConn, addr, config) if err != nil { - return nil, err + return nil, nil, err } - return ssh.NewClient(c, chans, reqs), nil + return ssh.NewClient(c, chans, reqs), func() { + if sshConn.enableTimeout { + sshConn.SetReadDeadline(sshEmptyTime) + } + + sshConn.enableTimeout = false + }, nil } func (d *sshClient) remote( @@ -322,14 +390,15 @@ func (d *sshClient) remote( buf := [4096]byte{} - conn, dErr := d.dialRemote("tcp", address, &ssh.ClientConfig{ - User: user, - Auth: authMethodBuilder(buf[:]), - HostKeyCallback: func(h string, r net.Addr, k ssh.PublicKey) error { - return d.comfirmRemoteFingerprint(h, r, k, buf[:]) - }, - Timeout: d.cfg.DialTimeout, - }) + conn, clearConnInitialDeadline, dErr := + d.dialRemote("tcp", address, &ssh.ClientConfig{ + User: user, + Auth: authMethodBuilder(buf[:]), + HostKeyCallback: func(h string, r net.Addr, k ssh.PublicKey) error { + return d.comfirmRemoteFingerprint(h, r, k, buf[:]) + }, + Timeout: d.cfg.DialTimeout, + }) if dErr != nil { errLen := copy(buf[d.w.HeaderSize():], dErr.Error()) + d.w.HeaderSize() @@ -418,6 +487,8 @@ func (d *sshClient) remote( defer session.Wait() + clearConnInitialDeadline() + d.remoteConnReceive <- sshRemoteConn{ writer: in, closer: func() error {