-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: lstocchi <[email protected]>
- Loading branch information
Showing
2 changed files
with
262 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,250 @@ | ||
//go:build windows | ||
|
||
package vsock | ||
|
||
import ( | ||
"errors" | ||
"fmt" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"golang.org/x/sys/windows/registry" | ||
) | ||
|
||
func TestCreateHVSockRegistryEntry(t *testing.T) { | ||
originalFindOpenHVSockPort := hfindOpenHVSockPort | ||
defer func() { hfindOpenHVSockPort = originalFindOpenHVSockPort }() | ||
|
||
tests := []struct { | ||
name string | ||
machineName string | ||
purpose HVSockPurpose | ||
wantPort uint64 | ||
wantErr bool | ||
findOpenHVSockPortMock func() (uint64, error) | ||
}{ | ||
{ | ||
name: "ValidInput", | ||
machineName: "test-machine", | ||
purpose: Network, | ||
wantPort: 1111, | ||
wantErr: false, | ||
findOpenHVSockPortMock: func() (uint64, error) { | ||
return 1111, nil | ||
}, | ||
}, | ||
{ | ||
name: "ErrorFindPort", | ||
machineName: "test-machine", | ||
purpose: Events, | ||
wantErr: true, | ||
findOpenHVSockPortMock: func() (uint64, error) { | ||
return 0, errors.New("error") | ||
}, | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
hfindOpenHVSockPort = tt.findOpenHVSockPortMock | ||
got, err := CreateHVSockRegistryEntry(tt.machineName, tt.purpose) | ||
|
||
if tt.wantErr { | ||
assert.Error(t, err) | ||
assert.Nil(t, got) | ||
} else { | ||
assert.NoError(t, err) | ||
assert.Equal(t, tt.wantPort, got.Port) | ||
assert.Equal(t, tt.machineName, got.MachineName) | ||
assert.Equal(t, tt.purpose, got.Purpose) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestElevateAndAddEntries(t *testing.T) { | ||
resultTestScript := make(map[string]string) | ||
|
||
tests := []struct { | ||
name string | ||
entries []HVSockRegistryEntry | ||
err error | ||
wantErr bool | ||
registryOpenKeyMock func(k registry.Key, path string, access uint32) (registry.Key, error) | ||
windowsLaunchElevatedWaitWithWindowMode func(exe string, cwd string, args string, windowMode int) error | ||
script string | ||
}{ | ||
{ | ||
name: "ErrorInvalidPort", | ||
entries: []HVSockRegistryEntry{ | ||
{ | ||
KeyName: "key", | ||
Purpose: Events, | ||
Port: 0, | ||
MachineName: "test-machine", | ||
}, | ||
}, | ||
err: errors.New("port must be larger than 1"), | ||
wantErr: true, | ||
}, | ||
{ | ||
name: "ErrorInvalidPurpose", | ||
entries: []HVSockRegistryEntry{ | ||
{ | ||
KeyName: "key", | ||
Purpose: HVSockPurpose(999), | ||
Port: 8888, | ||
MachineName: "test-machine", | ||
}, | ||
}, | ||
err: errors.New("required field purpose is empty"), | ||
wantErr: true, | ||
}, | ||
{ | ||
name: "ErrorEmptyMachineName", | ||
entries: []HVSockRegistryEntry{ | ||
{ | ||
KeyName: "key", | ||
Purpose: Events, | ||
Port: 8888, | ||
MachineName: "", | ||
}, | ||
}, | ||
err: errors.New("required field machinename is empty"), | ||
wantErr: true, | ||
}, | ||
{ | ||
name: "ErrorEmptyKeyName", | ||
entries: []HVSockRegistryEntry{ | ||
{ | ||
KeyName: "", | ||
Purpose: Events, | ||
Port: 8888, | ||
MachineName: "test-machine", | ||
}, | ||
}, | ||
err: errors.New("required field keypath is empty"), | ||
wantErr: true, | ||
}, | ||
{ | ||
name: "ErrorCannotOpenRegistryEntry", | ||
entries: []HVSockRegistryEntry{ | ||
{ | ||
KeyName: "key", | ||
Purpose: Events, | ||
Port: 8888, | ||
MachineName: "test-machine", | ||
}, | ||
}, | ||
err: errors.New("cannot open my mocked registry key"), | ||
wantErr: true, | ||
registryOpenKeyMock: func(k registry.Key, path string, access uint32) (registry.Key, error) { | ||
return k, errors.New("cannot open my mocked registry key") | ||
}, | ||
}, | ||
{ | ||
name: "ErrorRegistryEntryExists", | ||
entries: []HVSockRegistryEntry{ | ||
{ | ||
KeyName: "key", | ||
Purpose: Events, | ||
Port: 8888, | ||
MachineName: "test-machine", | ||
}, | ||
}, | ||
err: fmt.Errorf("%q: %s", ErrVSockRegistryEntryExists, "key"), | ||
wantErr: true, | ||
registryOpenKeyMock: func(k registry.Key, path string, access uint32) (registry.Key, error) { | ||
return k, nil | ||
}, | ||
}, | ||
{ | ||
name: "ErrorRegistryEntryExists", | ||
entries: []HVSockRegistryEntry{ | ||
{ | ||
KeyName: "key", | ||
Purpose: Events, | ||
Port: 8888, | ||
MachineName: "test-machine", | ||
}, | ||
}, | ||
err: errors.New("error when opening parent key"), | ||
wantErr: true, | ||
registryOpenKeyMock: func(k registry.Key, path string, access uint32) (registry.Key, error) { | ||
if path == VsockRegistryPath { | ||
return k, errors.New("error when opening parent key") | ||
} | ||
return k, registry.ErrNotExist | ||
}, | ||
}, | ||
{ | ||
name: "ValidInput", | ||
entries: []HVSockRegistryEntry{ | ||
{ | ||
KeyName: "key", | ||
Purpose: Events, | ||
Port: 8888, | ||
MachineName: "test-machine", | ||
}, | ||
}, | ||
wantErr: false, | ||
registryOpenKeyMock: func(k registry.Key, path string, access uint32) (registry.Key, error) { | ||
if path == VsockRegistryPath { | ||
return k, nil | ||
} | ||
return k, registry.ErrNotExist | ||
}, | ||
windowsLaunchElevatedWaitWithWindowMode: func(exe, cwd, args string, windowMode int) error { | ||
resultTestScript["ValidInput"] = args | ||
return nil | ||
}, | ||
script: "New-Item -Path 'HKLM:\\SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Virtualization\\GuestCommunicationServices' -Name 'key'; New-ItemProperty -Path 'HKLM:\\SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Virtualization\\GuestCommunicationServices\\key' -Name 'Purpose' -Value 'Events' -PropertyType String; New-ItemProperty -Path 'HKLM:\\SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Virtualization\\GuestCommunicationServices\\key' -Name 'MachineName' -Value 'test-machine' -PropertyType String;", | ||
}, | ||
} | ||
|
||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
if tt.registryOpenKeyMock != nil { | ||
originalegistryOpenKey := rOpenKey | ||
defer func() { rOpenKey = originalegistryOpenKey }() | ||
rOpenKey = tt.registryOpenKeyMock | ||
} | ||
if tt.windowsLaunchElevatedWaitWithWindowMode != nil { | ||
originalLaunchElevatedWaitWithWindowMode := wLaunchElevatedWaitWithWindowMode | ||
defer func() { wLaunchElevatedWaitWithWindowMode = originalLaunchElevatedWaitWithWindowMode }() | ||
wLaunchElevatedWaitWithWindowMode = tt.windowsLaunchElevatedWaitWithWindowMode | ||
} | ||
err := ElevateAndAddEntries(tt.entries) | ||
|
||
if tt.wantErr { | ||
assert.Error(t, err) | ||
assert.EqualValues(t, err, tt.err) | ||
} else { | ||
assert.NoError(t, err) | ||
resultingScript := resultTestScript[tt.name] | ||
assert.EqualValues(t, resultingScript, tt.script) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestElevateAndRemoveEntries(t *testing.T) { | ||
expectedScript := "Remove-Item -Path 'HKLM:\\SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Virtualization\\GuestCommunicationServices\\key' -Force -Recurse;" | ||
script := "" | ||
originalLaunchElevatedWaitWithWindowMode := wLaunchElevatedWaitWithWindowMode | ||
defer func() { wLaunchElevatedWaitWithWindowMode = originalLaunchElevatedWaitWithWindowMode }() | ||
wLaunchElevatedWaitWithWindowMode = func(exe, cwd, args string, windowMode int) error { | ||
script = args | ||
return nil | ||
} | ||
|
||
ElevateAndRemoveEntries([]HVSockRegistryEntry{ | ||
{ | ||
KeyName: "key", | ||
Purpose: Events, | ||
Port: 8888, | ||
MachineName: "test-machine", | ||
}, | ||
}) | ||
assert.EqualValues(t, script, expectedScript) | ||
} |