Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
Signed-off-by: lstocchi <[email protected]>
  • Loading branch information
lstocchi committed Jan 23, 2025
1 parent 1c2dff0 commit 38f41e8
Show file tree
Hide file tree
Showing 2 changed files with 262 additions and 7 deletions.
19 changes: 12 additions & 7 deletions pkg/machine/hyperv/vsock/vsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
)

var ErrVSockRegistryEntryExists = errors.New("registry entry already exists")
var rOpenKey = registry.OpenKey

const (
// HvsockMachineName is the string identifier for the machine name in a registry entry
Expand Down Expand Up @@ -74,7 +75,7 @@ func toHVSockPurpose(p string) (HVSockPurpose, error) {
}

func openVSockRegistryEntry(entry string) (registry.Key, error) {
return registry.OpenKey(registry.LOCAL_MACHINE, entry, registry.QUERY_VALUE)
return rOpenKey(registry.LOCAL_MACHINE, entry, registry.QUERY_VALUE)
}

// HVSockRegistryEntry describes a registry entry used in Windows for HVSOCK implementations
Expand Down Expand Up @@ -147,13 +148,15 @@ func findOpenHVSockPort() (uint64, error) {
return 0, errors.New("unable to find a free port for hvsock use")
}

var hfindOpenHVSockPort = findOpenHVSockPort

// CreateHVSockRegistryEntry is a constructor to make an instance of a registry entry in Windows. After making the new
// object, you must call the add() method or AddHVSockRegistryEntries(...) to *actually* add it to the Windows registry.
func CreateHVSockRegistryEntry(machineName string, purpose HVSockPurpose) (*HVSockRegistryEntry, error) {
// a so-called wildcard entry ... everything from FACB -> 6D3 is MS special sauce
// for a " linux vm". this first segment is hexi for the hvsock port number
// 00000400-FACB-11E6-BD58-64006A7986D3
port, err := findOpenHVSockPort()
port, err := hfindOpenHVSockPort()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -184,15 +187,15 @@ func ElevateAndAddEntries(entries []HVSockRegistryEntry) error {
if exists {
return fmt.Errorf("%q: %s", ErrVSockRegistryEntryExists, entry.KeyName)
}
parentKey, err := registry.OpenKey(registry.LOCAL_MACHINE, VsockRegistryPath, registry.QUERY_VALUE)
parentKey, err := rOpenKey(registry.LOCAL_MACHINE, VsockRegistryPath, registry.QUERY_VALUE)
if err != nil {
return err
}
defer func() {
if err := parentKey.Close(); err != nil {
logrus.Error(err)
}
}()
if err != nil {
return err
}

// for each entry it adds a purpose and machineName property
registryPath := fmt.Sprintf("HKLM:\\%s", VsockRegistryPath)
Expand Down Expand Up @@ -229,6 +232,8 @@ func ElevateAndRemoveEntries(entries []HVSockRegistryEntry) error {
return launchElevated(script)
}

var wLaunchElevatedWaitWithWindowMode = windows.LaunchElevatedWaitWithWindowMode

func launchElevated(args string) error {
psPath, err := exec.LookPath("powershell.exe")
if err != nil {
Expand All @@ -240,7 +245,7 @@ func launchElevated(args string) error {
return err
}

return windows.LaunchElevatedWaitWithWindowMode(psPath, d, args, syscall.SW_HIDE)
return wLaunchElevatedWaitWithWindowMode(psPath, d, args, syscall.SW_HIDE)
}

func portToKeyName(port uint64) string {
Expand Down
250 changes: 250 additions & 0 deletions pkg/machine/hyperv/vsock/vsock_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
//go:build windows

package vsock

import (
"errors"
"fmt"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/sys/windows/registry"
)

func TestCreateHVSockRegistryEntry(t *testing.T) {
originalFindOpenHVSockPort := hfindOpenHVSockPort
defer func() { hfindOpenHVSockPort = originalFindOpenHVSockPort }()

tests := []struct {
name string
machineName string
purpose HVSockPurpose
wantPort uint64
wantErr bool
findOpenHVSockPortMock func() (uint64, error)
}{
{
name: "ValidInput",
machineName: "test-machine",
purpose: Network,
wantPort: 1111,
wantErr: false,
findOpenHVSockPortMock: func() (uint64, error) {
return 1111, nil
},
},
{
name: "ErrorFindPort",
machineName: "test-machine",
purpose: Events,
wantErr: true,
findOpenHVSockPortMock: func() (uint64, error) {
return 0, errors.New("error")
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hfindOpenHVSockPort = tt.findOpenHVSockPortMock
got, err := CreateHVSockRegistryEntry(tt.machineName, tt.purpose)

if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.wantPort, got.Port)
assert.Equal(t, tt.machineName, got.MachineName)
assert.Equal(t, tt.purpose, got.Purpose)
}
})
}
}

func TestElevateAndAddEntries(t *testing.T) {
resultTestScript := make(map[string]string)

tests := []struct {
name string
entries []HVSockRegistryEntry
err error
wantErr bool
registryOpenKeyMock func(k registry.Key, path string, access uint32) (registry.Key, error)
windowsLaunchElevatedWaitWithWindowMode func(exe string, cwd string, args string, windowMode int) error
script string
}{
{
name: "ErrorInvalidPort",
entries: []HVSockRegistryEntry{
{
KeyName: "key",
Purpose: Events,
Port: 0,
MachineName: "test-machine",
},
},
err: errors.New("port must be larger than 1"),
wantErr: true,
},
{
name: "ErrorInvalidPurpose",
entries: []HVSockRegistryEntry{
{
KeyName: "key",
Purpose: HVSockPurpose(999),
Port: 8888,
MachineName: "test-machine",
},
},
err: errors.New("required field purpose is empty"),
wantErr: true,
},
{
name: "ErrorEmptyMachineName",
entries: []HVSockRegistryEntry{
{
KeyName: "key",
Purpose: Events,
Port: 8888,
MachineName: "",
},
},
err: errors.New("required field machinename is empty"),
wantErr: true,
},
{
name: "ErrorEmptyKeyName",
entries: []HVSockRegistryEntry{
{
KeyName: "",
Purpose: Events,
Port: 8888,
MachineName: "test-machine",
},
},
err: errors.New("required field keypath is empty"),
wantErr: true,
},
{
name: "ErrorCannotOpenRegistryEntry",
entries: []HVSockRegistryEntry{
{
KeyName: "key",
Purpose: Events,
Port: 8888,
MachineName: "test-machine",
},
},
err: errors.New("cannot open my mocked registry key"),
wantErr: true,
registryOpenKeyMock: func(k registry.Key, path string, access uint32) (registry.Key, error) {
return k, errors.New("cannot open my mocked registry key")
},
},
{
name: "ErrorRegistryEntryExists",
entries: []HVSockRegistryEntry{
{
KeyName: "key",
Purpose: Events,
Port: 8888,
MachineName: "test-machine",
},
},
err: fmt.Errorf("%q: %s", ErrVSockRegistryEntryExists, "key"),
wantErr: true,
registryOpenKeyMock: func(k registry.Key, path string, access uint32) (registry.Key, error) {
return k, nil
},
},
{
name: "ErrorRegistryEntryExists",
entries: []HVSockRegistryEntry{
{
KeyName: "key",
Purpose: Events,
Port: 8888,
MachineName: "test-machine",
},
},
err: errors.New("error when opening parent key"),
wantErr: true,
registryOpenKeyMock: func(k registry.Key, path string, access uint32) (registry.Key, error) {
if path == VsockRegistryPath {
return k, errors.New("error when opening parent key")
}
return k, registry.ErrNotExist
},
},
{
name: "ValidInput",
entries: []HVSockRegistryEntry{
{
KeyName: "key",
Purpose: Events,
Port: 8888,
MachineName: "test-machine",
},
},
wantErr: false,
registryOpenKeyMock: func(k registry.Key, path string, access uint32) (registry.Key, error) {
if path == VsockRegistryPath {
return k, nil
}
return k, registry.ErrNotExist
},
windowsLaunchElevatedWaitWithWindowMode: func(exe, cwd, args string, windowMode int) error {
resultTestScript["ValidInput"] = args
return nil
},
script: "New-Item -Path 'HKLM:\\SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Virtualization\\GuestCommunicationServices' -Name 'key'; New-ItemProperty -Path 'HKLM:\\SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Virtualization\\GuestCommunicationServices\\key' -Name 'Purpose' -Value 'Events' -PropertyType String; New-ItemProperty -Path 'HKLM:\\SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Virtualization\\GuestCommunicationServices\\key' -Name 'MachineName' -Value 'test-machine' -PropertyType String;",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.registryOpenKeyMock != nil {
originalegistryOpenKey := rOpenKey
defer func() { rOpenKey = originalegistryOpenKey }()
rOpenKey = tt.registryOpenKeyMock
}
if tt.windowsLaunchElevatedWaitWithWindowMode != nil {
originalLaunchElevatedWaitWithWindowMode := wLaunchElevatedWaitWithWindowMode
defer func() { wLaunchElevatedWaitWithWindowMode = originalLaunchElevatedWaitWithWindowMode }()
wLaunchElevatedWaitWithWindowMode = tt.windowsLaunchElevatedWaitWithWindowMode
}
err := ElevateAndAddEntries(tt.entries)

if tt.wantErr {
assert.Error(t, err)
assert.EqualValues(t, err, tt.err)
} else {
assert.NoError(t, err)
resultingScript := resultTestScript[tt.name]
assert.EqualValues(t, resultingScript, tt.script)
}
})
}
}

func TestElevateAndRemoveEntries(t *testing.T) {
expectedScript := "Remove-Item -Path 'HKLM:\\SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion\\Virtualization\\GuestCommunicationServices\\key' -Force -Recurse;"
script := ""
originalLaunchElevatedWaitWithWindowMode := wLaunchElevatedWaitWithWindowMode
defer func() { wLaunchElevatedWaitWithWindowMode = originalLaunchElevatedWaitWithWindowMode }()
wLaunchElevatedWaitWithWindowMode = func(exe, cwd, args string, windowMode int) error {
script = args
return nil
}

ElevateAndRemoveEntries([]HVSockRegistryEntry{
{
KeyName: "key",
Purpose: Events,
Port: 8888,
MachineName: "test-machine",
},
})
assert.EqualValues(t, script, expectedScript)
}

0 comments on commit 38f41e8

Please sign in to comment.