Skip to content

Commit e34ea6c

Browse files
committed
refactor: Simplify SSH dialer implementation and improve error handling
1 parent 1379fec commit e34ea6c

File tree

1 file changed

+26
-35
lines changed

1 file changed

+26
-35
lines changed

pkg/sshutils/ssh_dial.go

+26-35
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package sshutils
22

33
import (
44
"context"
5+
"fmt"
56

67
"github.com/stretchr/testify/mock"
78
"golang.org/x/crypto/ssh"
@@ -20,45 +21,43 @@ type SSHDialer interface {
2021
) (SSHClienter, error)
2122
}
2223

23-
func NewSSHDial(host string, port int, config *ssh.ClientConfig) SSHDialer {
24-
return &SSHDial{
25-
DialCreator: func(network, addr string, config *ssh.ClientConfig) (SSHClienter, error) {
26-
client, err := ssh.Dial(network, addr, config)
27-
if err != nil {
28-
return nil, err
29-
}
30-
return &SSHClientWrapper{Client: client}, nil
31-
},
32-
}
24+
type sshDial struct {
25+
host string
26+
port int
27+
config *ssh.ClientConfig
3328
}
3429

35-
type SSHDial struct {
36-
DialCreator func(network, addr string, config *ssh.ClientConfig) (SSHClienter, error)
30+
func NewSSHDial(host string, port int, config *ssh.ClientConfig) SSHDialer {
31+
return &sshDial{
32+
host: host,
33+
port: port,
34+
config: config,
35+
}
3736
}
3837

39-
func (d *SSHDial) Dial(network, addr string, config *ssh.ClientConfig) (SSHClienter, error) {
40-
return d.DialCreator(network, addr, config)
38+
func (s *sshDial) Dial(network, addr string, config *ssh.ClientConfig) (SSHClienter, error) {
39+
client, err := ssh.Dial(network, addr, config)
40+
if err != nil {
41+
return nil, fmt.Errorf("failed to dial: %w", err)
42+
}
43+
return &SSHClientWrapper{Client: client}, nil
4144
}
4245

43-
func (d *SSHDial) DialContext(
44-
ctx context.Context,
45-
network, addr string,
46-
config *ssh.ClientConfig,
47-
) (SSHClienter, error) {
48-
// Create a channel to receive the dial result
46+
func (s *sshDial) DialContext(ctx context.Context, network, addr string, config *ssh.ClientConfig) (SSHClienter, error) {
4947
type dialResult struct {
5048
client SSHClienter
5149
err error
5250
}
51+
5352
result := make(chan dialResult, 1)
5453

5554
// Start dialing in a goroutine
5655
go func() {
57-
client, err := d.Dial(network, addr, config)
56+
client, err := s.Dial(network, addr, config)
5857
result <- dialResult{client, err}
5958
}()
6059

61-
// Wait for either context cancellation or dial completion
60+
// Wait for either dial completion or context cancellation
6261
select {
6362
case <-ctx.Done():
6463
return nil, ctx.Err()
@@ -69,26 +68,18 @@ func (d *SSHDial) DialContext(
6968

7069
// Mock Functions
7170

72-
// MockSSHDialer is a mock implementation of SSHDialer
71+
// MockSSHDialer is a mock implementation of SSHDialer for testing
7372
type MockSSHDialer struct {
74-
mock.Mock
75-
}
76-
77-
// MockSSHClient is a mock implementation of SSHClienter
78-
type MockSFTPClient struct {
79-
mock.Mock
73+
DialFunc func(network, addr string, config *ssh.ClientConfig) (SSHClienter, error)
8074
}
8175

82-
// Dial is a mock implementation of the Dial method
8376
func (m *MockSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (SSHClienter, error) {
84-
args := m.Called(network, addr, config)
85-
if args.Get(1) != nil {
86-
return nil, args.Error(1)
77+
if m.DialFunc != nil {
78+
return m.DialFunc(network, addr, config)
8779
}
88-
return args.Get(0).(SSHClienter), nil
80+
return nil, nil
8981
}
9082

91-
// DialContext is a mock implementation of the DialContext method
9283
func (m *MockSSHDialer) DialContext(
9384
ctx context.Context,
9485
network, addr string,

0 commit comments

Comments
 (0)