@@ -2,6 +2,7 @@ package sshutils
2
2
3
3
import (
4
4
"context"
5
+ "fmt"
5
6
6
7
"github.com/stretchr/testify/mock"
7
8
"golang.org/x/crypto/ssh"
@@ -20,45 +21,43 @@ type SSHDialer interface {
20
21
) (SSHClienter , error )
21
22
}
22
23
23
- func NewSSHDial (host string , port int , config * ssh.ClientConfig ) SSHDialer {
24
- return & SSHDial {
25
- DialCreator : func (network , addr string , config * ssh.ClientConfig ) (SSHClienter , error ) {
26
- client , err := ssh .Dial (network , addr , config )
27
- if err != nil {
28
- return nil , err
29
- }
30
- return & SSHClientWrapper {Client : client }, nil
31
- },
32
- }
24
+ type sshDial struct {
25
+ host string
26
+ port int
27
+ config * ssh.ClientConfig
33
28
}
34
29
35
- type SSHDial struct {
36
- DialCreator func (network , addr string , config * ssh.ClientConfig ) (SSHClienter , error )
30
+ func NewSSHDial (host string , port int , config * ssh.ClientConfig ) SSHDialer {
31
+ return & sshDial {
32
+ host : host ,
33
+ port : port ,
34
+ config : config ,
35
+ }
37
36
}
38
37
39
- func (d * SSHDial ) Dial (network , addr string , config * ssh.ClientConfig ) (SSHClienter , error ) {
40
- return d .DialCreator (network , addr , config )
38
+ func (s * sshDial ) Dial (network , addr string , config * ssh.ClientConfig ) (SSHClienter , error ) {
39
+ client , err := ssh .Dial (network , addr , config )
40
+ if err != nil {
41
+ return nil , fmt .Errorf ("failed to dial: %w" , err )
42
+ }
43
+ return & SSHClientWrapper {Client : client }, nil
41
44
}
42
45
43
- func (d * SSHDial ) DialContext (
44
- ctx context.Context ,
45
- network , addr string ,
46
- config * ssh.ClientConfig ,
47
- ) (SSHClienter , error ) {
48
- // Create a channel to receive the dial result
46
+ func (s * sshDial ) DialContext (ctx context.Context , network , addr string , config * ssh.ClientConfig ) (SSHClienter , error ) {
49
47
type dialResult struct {
50
48
client SSHClienter
51
49
err error
52
50
}
51
+
53
52
result := make (chan dialResult , 1 )
54
53
55
54
// Start dialing in a goroutine
56
55
go func () {
57
- client , err := d .Dial (network , addr , config )
56
+ client , err := s .Dial (network , addr , config )
58
57
result <- dialResult {client , err }
59
58
}()
60
59
61
- // Wait for either context cancellation or dial completion
60
+ // Wait for either dial completion or context cancellation
62
61
select {
63
62
case <- ctx .Done ():
64
63
return nil , ctx .Err ()
@@ -69,26 +68,18 @@ func (d *SSHDial) DialContext(
69
68
70
69
// Mock Functions
71
70
72
- // MockSSHDialer is a mock implementation of SSHDialer
71
+ // MockSSHDialer is a mock implementation of SSHDialer for testing
73
72
type MockSSHDialer struct {
74
- mock.Mock
75
- }
76
-
77
- // MockSSHClient is a mock implementation of SSHClienter
78
- type MockSFTPClient struct {
79
- mock.Mock
73
+ DialFunc func (network , addr string , config * ssh.ClientConfig ) (SSHClienter , error )
80
74
}
81
75
82
- // Dial is a mock implementation of the Dial method
83
76
func (m * MockSSHDialer ) Dial (network , addr string , config * ssh.ClientConfig ) (SSHClienter , error ) {
84
- args := m .Called (network , addr , config )
85
- if args .Get (1 ) != nil {
86
- return nil , args .Error (1 )
77
+ if m .DialFunc != nil {
78
+ return m .DialFunc (network , addr , config )
87
79
}
88
- return args . Get ( 0 ).( SSHClienter ) , nil
80
+ return nil , nil
89
81
}
90
82
91
- // DialContext is a mock implementation of the DialContext method
92
83
func (m * MockSSHDialer ) DialContext (
93
84
ctx context.Context ,
94
85
network , addr string ,
0 commit comments