diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..e1e5e60 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,8 @@ +linters: + enable: + - asciicheck + - bodyclose + - dupl + - errorlint + - exportloopref + - funlen \ No newline at end of file diff --git a/client_cluster.go b/client_cluster.go index 9f300b2..8d4ed10 100644 --- a/client_cluster.go +++ b/client_cluster.go @@ -1,8 +1,6 @@ package ovirtclient import ( - "fmt" - ovirtsdk4 "github.com/ovirt/go-ovirt" ) @@ -25,12 +23,12 @@ type Cluster interface { func convertSDKCluster(sdkCluster *ovirtsdk4.Cluster) (Cluster, error) { id, ok := sdkCluster.Id() if !ok { - return nil, fmt.Errorf("failed to fetch ID for cluster") + return nil, newError(EFieldMissing, "failed to fetch ID for cluster") } name, ok := sdkCluster.Name() if !ok { - return nil, fmt.Errorf("failed to fetch name for cluster %s", id) + return nil, newError(EFieldMissing, "failed to fetch name for cluster %s", id) } return &cluster{ id: id, diff --git a/client_cluster_get.go b/client_cluster_get.go index b9b444e..1b713eb 100644 --- a/client_cluster_get.go +++ b/client_cluster_get.go @@ -1,21 +1,26 @@ package ovirtclient -import ( - "fmt" -) - func (o *oVirtClient) GetCluster(id string) (cluster Cluster, err error) { response, err := o.conn.SystemService().ClustersService().ClusterService(id).Get().Send() if err != nil { - return nil, fmt.Errorf("failed to fetch cluster ID %s (%w)", id, err) + return nil, wrap(err, "failed to fetch cluster ID %s", id) } sdkCluster, ok := response.Cluster() if !ok { - return nil, fmt.Errorf("no cluster returned when getting cluster ID %s", id) + return nil, newError( + ENotFound, + "no cluster returned when getting cluster ID %s", + id, + ) } cluster, err = convertSDKCluster(sdkCluster) if err != nil { - return nil, fmt.Errorf("failed to convert cluster %s (%w)", id, err) + return nil, wrap( + err, + EBug, + "failed to convert cluster %s", + id, + ) } return cluster, nil } diff --git a/client_cluster_list.go b/client_cluster_list.go index 78edf77..9598eb7 100644 --- a/client_cluster_list.go +++ b/client_cluster_list.go @@ -1,26 +1,23 @@ package ovirtclient -import ( - "fmt" -) - func (o *oVirtClient) ListClusters() ([]Cluster, error) { clustersResponse, err := o.conn.SystemService().ClustersService().List().Send() if err != nil { - return nil, fmt.Errorf( - "failed to list oVirt clusters (%w)", + return nil, wrap( err, + EUnidentified, + "failed to list oVirt clusters", ) } sdkClusters, ok := clustersResponse.Clusters() if !ok { - return nil, fmt.Errorf("no clusters returned from clusters list API call") + return []Cluster{}, nil } clusters := make([]Cluster, len(sdkClusters.Slice())) for i, sdkCluster := range sdkClusters.Slice() { clusters[i], err = convertSDKCluster(sdkCluster) if err != nil { - return nil, fmt.Errorf("failed to convert cluster during cluster listing item %d (%w)", i, err) + return nil, wrap(err, EBug, "failed to convert cluster during cluster listing item %d", i) } } return clusters, nil diff --git a/client_disk.go b/client_disk.go index 1802eaa..bad46fc 100644 --- a/client_disk.go +++ b/client_disk.go @@ -2,14 +2,13 @@ package ovirtclient import ( "context" - "fmt" "io" ovirtsdk4 "github.com/ovirt/go-ovirt" ) type DiskClient interface { - // UploadImage uploads an image file into a disk. The actual upload takes place in the + // StartImageUpload uploads an image file into a disk. The actual upload takes place in the // background and can be tracked using the returned UploadImageProgress object. // // Parameters are as follows: @@ -57,22 +56,31 @@ type DiskClient interface { // ListDisks lists all disks. ListDisks() ([]Disk, error) - // GetDisk fetches a disk with a specific ID from the + // GetDisk fetches a disk with a specific ID from the oVirt Engine. GetDisk(diskID string) (Disk, error) // RemoveDisk removes a disk with a specific ID. - RemoveDisk(diskID string) error + RemoveDisk(ctx context.Context, diskID string) error } +// UploadImageResult represents the completed image upload. type UploadImageResult interface { + // Disk returns the disk that has been created as the result of the image upload. Disk() Disk + // CorrelationID returns the opaque correlation ID for the upload. CorrelationID() string } +// Disk is a disk in oVirt. type Disk interface { + // ID is the unique ID for this disk. ID() string + // Alias is the name for this disk set by the user. Alias() string + // ProvisionedSize is the size visible to the virtual machine. ProvisionedSize() uint64 + // Format is the format of the image. Format() ImageFormat + // StorageDomainID is the ID of the storage system used for this disk. StorageDomainID() string } @@ -104,7 +112,7 @@ const ( func convertSDKDisk(sdkDisk *ovirtsdk4.Disk) (Disk, error) { id, ok := sdkDisk.Id() if !ok { - return nil, fmt.Errorf("disk does not contain an ID") + return nil, newError(EFieldMissing, "disk does not contain an ID") } var storageDomainID string if sdkStorageDomain, ok := sdkDisk.StorageDomain(); ok { @@ -118,19 +126,19 @@ func convertSDKDisk(sdkDisk *ovirtsdk4.Disk) (Disk, error) { } } if storageDomainID == "" { - return nil, fmt.Errorf("failed to find a valid storage domain ID for disk %s", id) + return nil, newError(EFieldMissing, "failed to find a valid storage domain ID for disk %s", id) } alias, ok := sdkDisk.Alias() if !ok { - return nil, fmt.Errorf("disk %s does not contain an alias", id) + return nil, newError(EFieldMissing, "disk %s does not contain an alias", id) } provisionedSize, ok := sdkDisk.ProvisionedSize() if !ok { - return nil, fmt.Errorf("disk %s does not contain a provisioned size", id) + return nil, newError(EFieldMissing, "disk %s does not contain a provisioned size", id) } format, ok := sdkDisk.Format() if !ok { - return nil, fmt.Errorf("disk %s has no format field", id) + return nil, newError(EFieldMissing, "disk %s has no format field", id) } return &disk{ id: id, diff --git a/client_disk_get.go b/client_disk_get.go index 217cdb4..a73f7b2 100644 --- a/client_disk_get.go +++ b/client_disk_get.go @@ -1,21 +1,32 @@ package ovirtclient -import ( - "fmt" -) - func (o *oVirtClient) GetDisk(diskID string) (Disk, error) { response, err := o.conn.SystemService().DisksService().DiskService(diskID).Get().Send() if err != nil { - return nil, fmt.Errorf("failed to fetch disk %s (%w)", diskID, err) + return nil, wrap( + err, + EUnidentified, + "failed to fetch disk %s", + diskID, + ) } sdkDisk, ok := response.Disk() if !ok { - return nil, fmt.Errorf("disk %s response did not contain a disk (%w)", diskID, err) + return nil, wrap( + err, + ENotFound, + "disk %s response did not contain a disk", + diskID, + ) } disk, err := convertSDKDisk(sdkDisk) if err != nil { - return nil, fmt.Errorf("failed to convert disk %s (%w)", diskID, err) + return nil, wrap( + err, + EBug, + "failed to convert disk %s", + diskID, + ) } return disk, nil } diff --git a/client_disk_list.go b/client_disk_list.go index f573713..13f5229 100644 --- a/client_disk_list.go +++ b/client_disk_list.go @@ -1,23 +1,23 @@ package ovirtclient -import ( - "fmt" -) - func (o *oVirtClient) ListDisks() ([]Disk, error) { response, err := o.conn.SystemService().DisksService().List().Send() if err != nil { - return nil, fmt.Errorf("failed to list disks (%w)", err) + return nil, wrap( + err, + EUnidentified, + "failed to list disks", + ) } sdkDisks, ok := response.Disks() if !ok { - return nil, fmt.Errorf("disk list response does not contain disks") + return []Disk{}, nil } result := make([]Disk, len(sdkDisks.Slice())) for i, sdkDisk := range sdkDisks.Slice() { disk, err := convertSDKDisk(sdkDisk) if err != nil { - return nil, fmt.Errorf("failed to convert disk %d (%w)", i, err) + return nil, wrap(err, EBug, "failed to convert disk item %d", i) } result[i] = disk } diff --git a/client_disk_remove.go b/client_disk_remove.go index 84878d4..7ab7daf 100644 --- a/client_disk_remove.go +++ b/client_disk_remove.go @@ -1,12 +1,34 @@ package ovirtclient import ( - "fmt" + "context" + "time" ) -func (o *oVirtClient) RemoveDisk(diskID string) error { - if _, err := o.conn.SystemService().DisksService().DiskService(diskID).Remove().Send(); err != nil { - return fmt.Errorf("failed to remove disk %s (%w)", diskID, err) +func (o *oVirtClient) RemoveDisk(ctx context.Context, diskID string) error { + var lastError EngineError + for { + _, err := o.conn.SystemService().DisksService().DiskService(diskID).Remove().Send() + if err == nil { + return err + } + lastError = wrap( + err, + EUnidentified, + "failed to remove disk %s", + diskID, + ) + if !lastError.CanAutoRetry() { + return lastError + } + select { + case <-ctx.Done(): + return wrap( + lastError, + ETimeout, + "timeout while removing disk", + ) + case <-time.After(10 * time.Second): + } } - return nil } diff --git a/client_disk_uploadimage.go b/client_disk_uploadimage.go index deeb2d4..cf76161 100644 --- a/client_disk_uploadimage.go +++ b/client_disk_uploadimage.go @@ -56,7 +56,11 @@ func (o *oVirtClient) StartImageUpload( disk, err := o.createDiskForUpload(storageDomainID, alias, format, qcowSize, sparse, cancel) if err != nil { - return nil, err + return nil, wrap( + err, + EUnidentified, + "failed to create disk for image upload", + ) } return o.createProgress(alias, qcowSize, size, bufReader, storageDomainID, sparse, newCtx, cancel, disk) @@ -72,7 +76,12 @@ func (o *oVirtClient) createDiskForUpload( ) (*ovirtsdk4.Disk, error) { storageDomain, err := ovirtsdk4.NewStorageDomainBuilder().Id(storageDomainID).Build() if err != nil { - panic(fmt.Errorf("bug: failed to build storage domain object from storage domain ID: %s", storageDomainID)) + return nil, wrap( + err, + EBug, + "failed to build storage domain object from storage domain ID: %s", + storageDomainID, + ) } diskBuilder := ovirtsdk4.NewDiskBuilder(). Alias(alias). @@ -84,13 +93,13 @@ func (o *oVirtClient) createDiskForUpload( disk, err := diskBuilder.Build() if err != nil { cancel() - return nil, fmt.Errorf( - //nolint:govet - "failed to build disk with alias %s, format %s, provisioned and initial size %d (%w)", + return nil, wrap( + err, + EBug, + "failed to build disk with alias %s, format %s, provisioned and initial size %d", alias, format, qcowSize, - err, ) } return disk, nil @@ -108,6 +117,7 @@ func (o *oVirtClient) createProgress( disk *ovirtsdk4.Disk, ) (UploadImageProgress, error) { progress := &uploadImageProgress{ + cli: o, correlationID: fmt.Sprintf("image_transfer_%s", alias), uploadedBytes: 0, cowSize: qcowSize, @@ -131,6 +141,7 @@ func (o *oVirtClient) createProgress( } type uploadImageProgress struct { + cli *oVirtClient uploadedBytes uint64 cowSize uint64 size uint64 @@ -146,7 +157,7 @@ type uploadImageProgress struct { done chan struct{} // lock is a lock that prevents race conditions during the upload process. lock *sync.Mutex - // cancel is the cancel function for the context. Is is called to ensure that the context is properly canceled. + // cancel is the cancel function for the context. HasCode is called to ensure that the context is properly canceled. cancel context.CancelFunc // err holds the error that happened during the upload. It can be queried using the Err() method. err error @@ -174,7 +185,7 @@ func (u *uploadImageProgress) Disk() Disk { } disk, err := convertSDKDisk(sdkDisk) if err != nil { - panic(fmt.Errorf("bug: failed to convert disk (%w)", err)) + panic(wrap(err, EBug, "bug: failed to convert disk")) } return disk } @@ -203,7 +214,7 @@ func (u *uploadImageProgress) Done() <-chan struct{} { func (u *uploadImageProgress) Read(p []byte) (n int, err error) { select { case <-u.ctx.Done(): - return 0, fmt.Errorf("timeout while uploading image") + return 0, newError(ETimeout, "timeout while uploading image") default: } n, err = u.reader.Read(p) @@ -272,7 +283,7 @@ func (u *uploadImageProgress) removeDisk() { disk := u.disk if disk != nil { if id, ok := u.disk.Id(); ok { - _ = u.client.RemoveDisk(id) + _ = u.client.RemoveDisk(u.ctx, id) } } } @@ -284,7 +295,7 @@ func (u *uploadImageProgress) finalizeUpload( finalizeRequest.Query("correlation_id", u.correlationID) _, err := finalizeRequest.Send() if err != nil { - return fmt.Errorf("failed to finalize image upload (%w)", err) + return wrap(err, EUnidentified, "failed to finalize image upload") } return nil } @@ -292,16 +303,16 @@ func (u *uploadImageProgress) finalizeUpload( func (u *uploadImageProgress) uploadImage(transferURL *url.URL) error { putRequest, err := http.NewRequest(http.MethodPut, transferURL.String(), u) if err != nil { - return fmt.Errorf("failed to create HTTP request (%w)", err) + return wrap(err, EUnidentified, "failed to create HTTP request") } putRequest.Header.Add("content-type", "application/octet-stream") putRequest.ContentLength = int64(u.size) response, err := u.httpClient.Do(putRequest) if err != nil { - return fmt.Errorf("failed to upload image (%w)", err) + return wrap(err, EUnidentified, "failed to upload image") } if err := response.Body.Close(); err != nil { - return fmt.Errorf("failed to close response body while uploading image (%w)", err) + return wrap(err, EUnidentified, "failed to close response body while uploading image") } return nil } @@ -316,7 +327,7 @@ func (u *uploadImageProgress) findTransferURL(transfer *ovirtsdk4.ImageTransfer) } if len(tryURLs) == 0 { - return nil, fmt.Errorf("neither a transfer URL nor a proxy URL was returned from the oVirt Engine") + return nil, newError(EBug, "neither a transfer URL nor a proxy URL was returned from the oVirt Engine") } var foundTransferURL *url.URL @@ -324,7 +335,7 @@ func (u *uploadImageProgress) findTransferURL(transfer *ovirtsdk4.ImageTransfer) for _, transferURL := range tryURLs { transferURL, err := url.Parse(transferURL) if err != nil { - lastError = fmt.Errorf("failer to parse transfer URL %s (%w)", transferURL, err) + lastError = wrap(err, EUnidentified, "failed to parse transfer URL %s", transferURL) continue } @@ -339,14 +350,14 @@ func (u *uploadImageProgress) findTransferURL(transfer *ovirtsdk4.ImageTransfer) if err == nil { statusCode := res.StatusCode if err := res.Body.Close(); err != nil { - lastError = fmt.Errorf("failed to close response body in options request (%w)", err) + lastError = wrap(err, EUnidentified, "failed to close response body in options request") } else { if statusCode == 200 { foundTransferURL = transferURL lastError = nil break } else { - lastError = fmt.Errorf("non-200 status code returned from URL %s (%d)", hostUrl, res.StatusCode) + lastError = newError(EConnection, "non-200 status code returned from URL %s (%d)", hostUrl, res.StatusCode) } } } else { @@ -357,7 +368,7 @@ func (u *uploadImageProgress) findTransferURL(transfer *ovirtsdk4.ImageTransfer) } } if foundTransferURL == nil { - return nil, fmt.Errorf("failed to find transfer URL (last error: %w)", lastError) + return nil, wrap(lastError, EUnidentified, "failed to find transfer URL") } return foundTransferURL, nil } @@ -368,7 +379,7 @@ func (u *uploadImageProgress) createDisk() (string, *ovirtsdk4.DiskService, erro addResp, err := addDiskRequest.Send() if err != nil { diskAlias, _ := u.disk.Alias() - return "", nil, fmt.Errorf("failed to create disk, alias: %s (%w)", diskAlias, err) + return "", nil, wrap(err, EUnidentified, "failed to create disk, alias: %s", diskAlias) } diskID := addResp.MustDisk().MustId() diskService := u.conn.SystemService().DisksService().DiskService(diskID) @@ -380,6 +391,7 @@ func (u *uploadImageProgress) setupImageTransfer(diskID string) ( *ovirtsdk4.ImageTransferService, error, ) { + var lastError EngineError imageTransfersService := u.conn.SystemService().ImageTransfersService() image := ovirtsdk4.NewImageBuilder().Id(diskID).MustBuild() transfer := ovirtsdk4. @@ -392,27 +404,33 @@ func (u *uploadImageProgress) setupImageTransfer(diskID string) ( Query("correlation_id", u.correlationID) transferRes, err := transferReq.Send() if err != nil { - return nil, nil, fmt.Errorf("failed to start image transfer (%w)", err) + return nil, nil, wrap(err, EUnidentified, "failed to start image transfer") } transfer = transferRes.MustImageTransfer() transferService := imageTransfersService.ImageTransferService(transfer.MustId()) for { - req, lastError := transferService.Get().Send() - if lastError == nil { + req, err := transferService.Get().Send() + if err == nil { if req.MustImageTransfer().MustPhase() == ovirtsdk4.IMAGETRANSFERPHASE_TRANSFERRING { break } else { - lastError = fmt.Errorf( + lastError = newError( + EPending, "image transfer is in phase %s instead of transferring", req.MustImageTransfer().MustPhase(), ) } + } else { + lastError = wrap(err, EUnidentified, "failed to get image transfer for disk %s", diskID) + if !lastError.CanAutoRetry() { + return nil, nil, lastError + } } select { case <-time.After(time.Second * 5): case <-u.ctx.Done(): - return nil, nil, fmt.Errorf("timeout while waiting for image transfer (last error was: %w)", lastError) + return nil, nil, wrap(lastError, ETimeout, "timeout while waiting for image transfer") } } return transfer, transferService, nil @@ -425,12 +443,12 @@ func (u *uploadImageProgress) waitForDiskOk(diskService *ovirtsdk4.DiskService) if err == nil { disk, ok := req.Disk() if !ok { - return fmt.Errorf("the disk was removed after upload, probably not supported") + return newError(EUnsupported, "the disk was removed after upload, probably not supported") } if disk.MustStatus() == ovirtsdk4.DISKSTATUS_OK { - return nil + return u.cli.waitForJobFinished(u.ctx, u.correlationID) } else { - lastError = fmt.Errorf("disk status is %s, not ok", disk.MustStatus()) + lastError = newError(EPending, "disk status is %s, not ok", disk.MustStatus()) } u.disk = disk } else { @@ -439,7 +457,7 @@ func (u *uploadImageProgress) waitForDiskOk(diskService *ovirtsdk4.DiskService) select { case <-time.After(5 * time.Second): case <-u.ctx.Done(): - return fmt.Errorf("timeout while waiting for disk to be ok after upload (last error: %w)", lastError) + return wrap(lastError, ETimeout, "timeout while waiting for disk to be ok after upload") } } } diff --git a/client_disk_uploadimage_test.go b/client_disk_uploadimage_test.go index a53f7e2..bec505f 100644 --- a/client_disk_uploadimage_test.go +++ b/client_disk_uploadimage_test.go @@ -46,7 +46,7 @@ func TestImageUploadDiskCreated(t *testing.T) { if err != nil { t.Fatal(fmt.Errorf("failed to fetch disk after image upload (%w)", err)) } - if err := client.RemoveDisk(disk.ID()); err != nil { - t.Fatal(fmt.Errorf("failed to remove disk (%w)", err)) + if err := client.RemoveDisk(context.Background(), disk.ID()); err != nil { + t.Fatal(err) } } diff --git a/client_host.go b/client_host.go index 4810284..99e4745 100644 --- a/client_host.go +++ b/client_host.go @@ -1,8 +1,6 @@ package ovirtclient import ( - "fmt" - ovirtsdk4 "github.com/ovirt/go-ovirt" ) @@ -43,19 +41,19 @@ const ( func convertSDKHost(sdkHost *ovirtsdk4.Host) (Host, error) { id, ok := sdkHost.Id() if !ok { - return nil, fmt.Errorf("returned host did not contain an ID") + return nil, newError(EFieldMissing, "returned host did not contain an ID") } status, ok := sdkHost.Status() if !ok { - return nil, fmt.Errorf("returned host did not contain a status") + return nil, newError(EFieldMissing, "returned host did not contain a status") } sdkCluster, ok := sdkHost.Cluster() if !ok { - return nil, fmt.Errorf("returned host did not contain a cluster") + return nil, newError(EFieldMissing, "returned host did not contain a cluster") } clusterID, ok := sdkCluster.Id() if !ok { - return nil, fmt.Errorf("failed to fetch cluster ID from host %s", id) + return nil, newError(EFieldMissing, "failed to fetch cluster ID from host %s", id) } return &host{ id: id, diff --git a/client_host_get.go b/client_host_get.go index 92fc03b..7215a41 100644 --- a/client_host_get.go +++ b/client_host_get.go @@ -1,21 +1,31 @@ package ovirtclient -import ( - "fmt" -) - func (o *oVirtClient) GetHost(id string) (Host, error) { response, err := o.conn.SystemService().HostsService().HostService(id).Get().Send() if err != nil { - return nil, fmt.Errorf("failed to fetch host %s (%w)", id, err) + return nil, wrap( + err, + EUnidentified, + "failed to fetch host %s", + id, + ) } sdkHost, ok := response.Host() if !ok { - return nil, fmt.Errorf("API response contained no host") + return nil, wrap( + err, + ENotFound, + "host %s response did not contain a host", + ) } host, err := convertSDKHost(sdkHost) if err != nil { - return nil, fmt.Errorf("failed to convert host object (%w)", err) + return nil, wrap( + err, + EBug, + "failed to convert host %s", + id, + ) } return host, nil } diff --git a/client_host_list.go b/client_host_list.go index 6656fe5..f4ad245 100644 --- a/client_host_list.go +++ b/client_host_list.go @@ -1,23 +1,19 @@ package ovirtclient -import ( - "fmt" -) - func (o *oVirtClient) ListHosts() ([]Host, error) { response, err := o.conn.SystemService().HostsService().List().Send() if err != nil { - return nil, fmt.Errorf("failed to list hosts (%w)", err) + return nil, wrap(err, EUnidentified, "failed to list hosts") } sdkHosts, ok := response.Hosts() if !ok { - return nil, fmt.Errorf("host list response didn't contain hosts") + return []Host{}, nil } result := make([]Host, len(sdkHosts.Slice())) for i, sdkHost := range sdkHosts.Slice() { result[i], err = convertSDKHost(sdkHost) if err != nil { - return nil, fmt.Errorf("failed to convert host %d in listing (%w)", i, err) + return nil, wrap(err, EBug, "failed to convert host item %d", i) } } return result, nil diff --git a/client_storagedomain.go b/client_storagedomain.go index 49be5fe..d29c0e3 100644 --- a/client_storagedomain.go +++ b/client_storagedomain.go @@ -1,20 +1,22 @@ package ovirtclient import ( - "fmt" - ovirtsdk4 "github.com/ovirt/go-ovirt" ) // StorageDomainClient contains the portion of the goVirt API that deals with storage domains. type StorageDomainClient interface { + // ListStorageDomains lists all storage domains. ListStorageDomains() ([]StorageDomain, error) + // GetStorageDomain returns a single storage domain, or an error if the storage domain could not be found. GetStorageDomain(id string) (StorageDomain, error) } // StorageDomain represents a storage domain returned from the oVirt Engine API. type StorageDomain interface { + // ID is the unique identified for the storage system connected to oVirt. ID() string + // Name is the user-given name for the storage domain. Name() string // Available returns the number of available bytes on the storage domain Available() uint64 @@ -55,11 +57,11 @@ const ( func convertSDKStorageDomain(sdkStorageDomain *ovirtsdk4.StorageDomain) (StorageDomain, error) { id, ok := sdkStorageDomain.Id() if !ok { - return nil, fmt.Errorf("failed to fetch ID of storage domain") + return nil, newError(EFieldMissing, "failed to fetch ID of storage domain") } name, ok := sdkStorageDomain.Name() if !ok { - return nil, fmt.Errorf("failed to fetch name of storage domain") + return nil, newError(EFieldMissing, "failed to fetch name of storage domain") } available, ok := sdkStorageDomain.Available() if !ok { @@ -67,14 +69,14 @@ func convertSDKStorageDomain(sdkStorageDomain *ovirtsdk4.StorageDomain) (Storage available = 0 } if available < 0 { - return nil, fmt.Errorf("invalid available bytes returned from storage domain: %d", available) + return nil, newError(EBug, "invalid available bytes returned from storage domain: %d", available) } // It is OK for the storage domain status to not be present if the external status is present. status, _ := sdkStorageDomain.Status() // It is OK for the storage domain external status to not be present if the status is present. externalStatus, _ := sdkStorageDomain.ExternalStatus() if status == "" && externalStatus == "" { - return nil, fmt.Errorf("neither the status nor the external status is set for storage domain %s", id) + return nil, newError(EFieldMissing, "neither the status nor the external status is set for storage domain %s", id) } return &storageDomain{ diff --git a/client_storagedomain_get.go b/client_storagedomain_get.go index 58508ce..1f1019a 100644 --- a/client_storagedomain_get.go +++ b/client_storagedomain_get.go @@ -1,21 +1,17 @@ package ovirtclient -import ( - "fmt" -) - func (o *oVirtClient) GetStorageDomain(id string) (storageDomain StorageDomain, err error) { response, err := o.conn.SystemService().StorageDomainsService().StorageDomainService(id).Get().Send() if err != nil { - return nil, fmt.Errorf("failed to get storage domain %s (%w)", id, err) + return nil, wrap(err, EUnidentified, "failed to get storage domain %s", id) } sdkStorageDomain, ok := response.StorageDomain() if !ok { - return nil, fmt.Errorf("response did not contain a storage domain") + return nil, newError(ENotFound, "response did not contain a storage domain") } storageDomain, err = convertSDKStorageDomain(sdkStorageDomain) if err != nil { - return nil, fmt.Errorf("failed to convert storage domain (%w)", err) + return nil, wrap(err, EUnidentified, "failed to convert storage domain") } return storageDomain, nil } diff --git a/client_storagedomain_list.go b/client_storagedomain_list.go index 8300ec4..c5cdfd9 100644 --- a/client_storagedomain_list.go +++ b/client_storagedomain_list.go @@ -1,23 +1,19 @@ package ovirtclient -import ( - "fmt" -) - func (o *oVirtClient) ListStorageDomains() (storageDomains []StorageDomain, err error) { response, err := o.conn.SystemService().StorageDomainsService().List().Send() if err != nil { - return nil, fmt.Errorf("failed to list storage domains (%w)", err) + return nil, wrap(err, EUnidentified, "failed to list storage domains") } sdkStorageDomains, ok := response.StorageDomains() if !ok { - return nil, fmt.Errorf("API call did not return storage domains in response") + return []StorageDomain{}, nil } storageDomains = make([]StorageDomain, len(sdkStorageDomains.Slice())) for i, sdkStorageDomain := range sdkStorageDomains.Slice() { storageDomain, err := convertSDKStorageDomain(sdkStorageDomain) if err != nil { - return nil, fmt.Errorf("failed to convert storage domain %d in listing (%w)", i, err) + return nil, wrap(err, EBug, "failed to convert storage domain %d in listing", i) } storageDomains[i] = storageDomain } diff --git a/client_template.go b/client_template.go index aef1a3d..780b10e 100644 --- a/client_template.go +++ b/client_template.go @@ -1,8 +1,6 @@ package ovirtclient import ( - "fmt" - ovirtsdk4 "github.com/ovirt/go-ovirt" ) @@ -20,15 +18,15 @@ type Template interface { func convertSDKTemplate(sdkTemplate *ovirtsdk4.Template) (Template, error) { id, ok := sdkTemplate.Id() if !ok { - return nil, fmt.Errorf("template does not contain ID") + return nil, newError(EFieldMissing, "template does not contain ID") } name, ok := sdkTemplate.Name() if !ok { - return nil, fmt.Errorf("template does not contain a name") + return nil, newError(EFieldMissing, "template does not contain a name") } description, ok := sdkTemplate.Description() if !ok { - return nil, fmt.Errorf("template does not contain a description") + return nil, newError(EFieldMissing, "template does not contain a description") } return &template{ id: id, diff --git a/client_template_get.go b/client_template_get.go index f0320ec..0c6a398 100644 --- a/client_template_get.go +++ b/client_template_get.go @@ -1,21 +1,17 @@ package ovirtclient -import ( - "fmt" -) - func (o *oVirtClient) GetTemplate(id string) (Template, error) { response, err := o.conn.SystemService().TemplatesService().TemplateService(id).Get().Send() if err != nil { - return nil, fmt.Errorf("failed to fetch template %s (%w)", id, err) + return nil, wrap(err, EUnidentified, "failed to fetch template %s", id) } sdkTemplate, ok := response.Template() if !ok { - return nil, fmt.Errorf("API response contained no host") + return nil, newError(ENotFound, "API response contained no template") } template, err := convertSDKTemplate(sdkTemplate) if err != nil { - return nil, fmt.Errorf("failed to convert template object (%w)", err) + return nil, wrap(err, EBug, "failed to convert template object") } return template, nil } diff --git a/client_template_list.go b/client_template_list.go index 6f04f13..0ac6792 100644 --- a/client_template_list.go +++ b/client_template_list.go @@ -1,23 +1,19 @@ package ovirtclient -import ( - "fmt" -) - func (o *oVirtClient) ListTemplates() ([]Template, error) { response, err := o.conn.SystemService().TemplatesService().List().Send() if err != nil { - return nil, fmt.Errorf("failed to list templates (%w)", err) + return nil, wrap(err, EUnidentified, "failed to list templates") } sdkTemplates, ok := response.Templates() if !ok { - return nil, fmt.Errorf("host list response didn't contain hosts") + return []Template{}, nil } result := make([]Template, len(sdkTemplates.Slice())) for i, sdkTemplate := range sdkTemplates.Slice() { result[i], err = convertSDKTemplate(sdkTemplate) if err != nil { - return nil, fmt.Errorf("failed to convert host %d in listing (%w)", i, err) + return nil, wrap(err, EBug, "failed to convert host %d in listing", i) } } return result, nil diff --git a/client_util.go b/client_util.go new file mode 100644 index 0000000..728fe52 --- /dev/null +++ b/client_util.go @@ -0,0 +1,58 @@ +package ovirtclient + +import ( + "context" + "fmt" + "time" + + ovirtsdk "github.com/ovirt/go-ovirt" +) + +// waitForJobFinished waits for a job to truly finish. This is especially important when disks are involved as their +// status changes to OK prematurely. +// +// correlationID is a query parameter assigned to a job before it is sent to the ovirt engine, it must be unique and +// under 30 chars. To set a correlationID add `Query("correlation_id", correlationID)` to the engine API call, for +// example: +// +// correlationID := fmt.Sprintf("image_transfer_%s", utilrand.String(5)) +// conn. +// SystemService(). +// DisksService(). +// DiskService(diskId). +// Update(). +// Query("correlation_id", correlationID). +// Send() +func (o *oVirtClient) waitForJobFinished(ctx context.Context, correlationID string) error { + var lastError EngineError + for { + jobResp, err := o.conn.SystemService().JobsService().List().Search(fmt.Sprintf("correlation_id=%s", correlationID)).Send() + if err == nil { + if jobSlice, ok := jobResp.Jobs(); ok { + if len(jobSlice.Slice()) == 0 { + return nil + } + for _, job := range jobSlice.Slice() { + if status, _ := job.Status(); status != ovirtsdk.JOBSTATUS_STARTED { + return nil + } + } + } + lastError = newError(EPending, "job for correlation ID %s still pending", correlationID) + } else { + realErr := wrap(err, EUnidentified, "failed to list jobs for correlation ID %s", correlationID) + if !realErr.CanAutoRetry() { + return realErr + } + lastError = realErr + } + select { + case <-time.After(5 * time.Second): + case <-ctx.Done(): + return wrap( + lastError, + ETimeout, + "timeout while waiting for job with correlation_id %s to finish", correlationID) + } + } +} diff --git a/client_vm.go b/client_vm.go index b4ce0f3..f87cc9f 100644 --- a/client_vm.go +++ b/client_vm.go @@ -2,7 +2,6 @@ package ovirtclient import ( "context" - "fmt" ) type VMClient interface { @@ -19,13 +18,13 @@ type VMClient interface { // is 0. If the parameters are guaranteed to be non-zero MustNewVMCPUTopo should be used. func NewVMCPUTopo(cores uint, threads uint, sockets uint) (VMCPUTopo, error) { if cores == 0 { - return nil, fmt.Errorf("BUG: cores cannot be zero") + return nil, newError(EBadArgument, "cores cannot be zero") } if threads == 0 { - return nil, fmt.Errorf("BUG: threads cannot be zero") + return nil, newError(EBadArgument, "threads cannot be zero") } if sockets == 0 { - return nil, fmt.Errorf("BUG: sockets cannot be zero") + return nil, newError(EBadArgument, "sockets cannot be zero") } return &vmCPUTopo{ cores: cores, diff --git a/credential_shift_test.go b/credential_shift_test.go new file mode 100644 index 0000000..09e6a81 --- /dev/null +++ b/credential_shift_test.go @@ -0,0 +1,72 @@ +package ovirtclient_test + +import ( + "crypto/x509" + "errors" + "fmt" + "testing" + + ovirtsdk4 "github.com/ovirt/go-ovirt" + ovirtclient "github.com/ovirt/go-ovirt-client" +) + +func TestCredentialChangeAfterSetup(t *testing.T) { + // Real CA is the CA we will use in the server + realCAPrivKey, realCACert, realCABytes, err := createCA() + if err != nil { + t.Fatalf("failed to create real CA (%v)", err) + } + + serverPrivKey, serverCert, err := createSignedCert( + []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + realCAPrivKey, + realCACert, + ) + if err != nil { + t.Fatalf("failed to create server certificate (%v)", err) + } + + port := getNextFreePort() + + srv, err := newTestServer(port, serverCert, serverPrivKey, &unauthorizedHandler{}) + if err != nil { + t.Fatal(err) + } + if err := srv.Start(); err != nil { + t.Fatal(err) + } + defer srv.Stop() + + logger := ovirtclient.NewGoTestLogger(t) + conn, err := ovirtclient.NewWithVerify( + fmt.Sprintf("https://127.0.0.1:%d", port), + "nonexistent@internal", + "invalid-password-for-testing-purposes", + "", + realCABytes, + false, + nil, + logger, + func(connection *ovirtsdk4.Connection) error { + // Disable connection check on setup to simulate a credential shift after the connection + // has been established. + return nil + }, + ) + if err != nil { + t.Fatalf("failed to set up connection (%v)", err) + } + + _, err = conn.ListStorageDomains() + if err == nil { + t.Fatalf("listing storage domains did not result in an error") + } + var e ovirtclient.EngineError + if errors.As(err, &e) { + if e.Code() != ovirtclient.EAccessDenied { + t.Fatalf("the returned error was not an EAccessDenied (%v)", err) + } + } else { + t.Fatalf("the returned error was not an EngineError (%v)", err) + } +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..d282f58 --- /dev/null +++ b/errors.go @@ -0,0 +1,217 @@ +package ovirtclient + +import ( + "errors" + "fmt" + "strings" + + ovirtsdk "github.com/ovirt/go-ovirt" +) + +// ErrorCode is a code that can be used to identify error types. These errors are identified on a best effort basis +// from the underlying oVirt connection. +type ErrorCode string + +// EAccessDenied signals that the provided credentials for the oVirt engine were incorrect. +const EAccessDenied ErrorCode = "access_denied" + +// ENotAnOVirtEngine signals that the server did not respond with a proper oVirt response. +// the cre +const ENotAnOVirtEngine ErrorCode = "not_ovirt_engine" + +// ETLSError signals that the provided CA certificate did not match the server that was attempted to connect. +const ETLSError ErrorCode = "tls_error" + +// ENotFound signals that the resource requested was not found. +const ENotFound ErrorCode = "not_found" + +// EBug signals an error that should never happen. Please report this. +const EBug ErrorCode = "bug" + +// EConnection signals a problem with the connection. +const EConnection ErrorCode = "connection" + +// EPending signals that the client library is still waiting for an action to be completed. +const EPending ErrorCode = "pending" + +// ETimeout signals that the client library has timed out waiting for an action to be completed. +const ETimeout ErrorCode = "timeout" + +// EFieldMissing indicates that the oVirt API did not return a specific field. This is most likely a bug, please report +// it. +const EFieldMissing ErrorCode = "field_missing" + +// EBadArgument indicates that an input parameter was incorrect. +const EBadArgument ErrorCode = "bad_argument" + +// EFileReadFailed indicates that reading a local file failed. +const EFileReadFailed ErrorCode = "file_read_failed" + +// EUnidentified is an unidentified oVirt error. When passed to the wrap() function this error code will cause the +// wrap function to look at the wrapped error and either fetch the error code from that error, or identify the error +// from its text. +// +// If you see this error type in a log please report this error so we can add an error code for it. +const EUnidentified ErrorCode = "generic_error" + +// EUnsupported signals that an action is not supported. This can indicate a disk format or a combination of parameters. +const EUnsupported ErrorCode = "unsupported" + +// CanAutoRetry returns false if the given error code is permanent and an automatic retry should not be attempted. +func (e ErrorCode) CanAutoRetry() bool { + switch e { + case EAccessDenied: + return false + case ENotAnOVirtEngine: + return false + case ETLSError: + return false + case ENotFound: + return false + case EBug: + return false + case EUnsupported: + return false + case EFieldMissing: + return false + default: + return true + } +} + +// EngineError is an error representation for errors received while interacting with the oVirt engine. +// +// Usage: +// +// if err != nil { +// var realErr ovirtclient.EngineError +// if errors.As(err, &realErr) { +// // deal with EngineError +// } else { +// // deal with other errors +// } +// } +type EngineError interface { + error + + // String returns the string representation for this error. + String() string + // HasCode returns true if the current error, or any preceding error has the specified error code. + HasCode(ErrorCode) bool + // Code returns an error code for the failure. + Code() ErrorCode + // Unwrap returns the underlying error + Unwrap() error + // CanAutoRetry returns false if an automatic retry should not be attempted. + CanAutoRetry() bool +} + +type engineError struct { + message string + code ErrorCode + cause error +} + +func (e *engineError) HasCode(code ErrorCode) bool { + if e.code == code { + return true + } + if cause := e.Unwrap(); cause != nil { + var causeE EngineError + if errors.As(cause, &causeE) { + return causeE.HasCode(code) + } + } + return false +} + +func (e *engineError) String() string { + return fmt.Sprintf("%s: %s", e.code, e.message) +} + +func (e *engineError) Error() string { + return e.message +} + +func (e *engineError) Code() ErrorCode { + return e.code +} + +func (e *engineError) Unwrap() error { + return e.cause +} + +func (e *engineError) CanAutoRetry() bool { + return e.code.CanAutoRetry() +} + +func newError(code ErrorCode, format string, args ...interface{}) EngineError { + return &engineError{ + message: fmt.Sprintf(format, args...), + code: code, + } +} + +// wrap wraps an error, adding an error code and message in the process. The wrapped error is added +// to the message automatically in Go style. If the passed error code is EUnidentified or not an EngineError +// this function will attempt to identify the error deeper. +func wrap(err error, code ErrorCode, format string, args ...interface{}) EngineError { + realArgs := append(args, err) + if code == EUnidentified { + var realErr EngineError + if errors.As(err, &realErr) { + code = realErr.Code() + } else { + if e := realIdentify(err); e != nil { + err = e + code = e.Code() + } + } + } + realMessage := fmt.Sprintf(fmt.Sprintf("%s (%v)", format, "(%v)"), realArgs...) + return &engineError{ + message: realMessage, + code: code, + cause: err, + } +} + +// identify attempts to identify the reason for the error and create a structure accordingly. If it fails to identify +// the reason it will return nil. +// +// Usage: +// +// if err != nil { +// if wrappedError := identify(err); wrappedError != nil { +// return wrappedError +// } +// // Handle unknown error here +// } +func identify(err error) error { + return realIdentify(err) +} + +func realIdentify(err error) EngineError { + var authErr *ovirtsdk.AuthError + var notFoundErr *ovirtsdk.NotFoundError + switch { + case errors.As(err, &authErr): + fallthrough + case strings.Contains(err.Error(), "access_denied"): + return wrap(err, EAccessDenied, "access denied, check your credentials") + case strings.Contains(err.Error(), "parse non-array sso with response"): + return wrap(err, + ENotAnOVirtEngine, "invalid credentials, or the URL does not point to an oVirt Engine, check your settings") + case strings.Contains(err.Error(), "server gave HTTP response to HTTPS client"): + return wrap(err, + ENotAnOVirtEngine, "the server gave a HTTP response to a HTTPS client, check if your URL is correct") + case strings.Contains(err.Error(), "tls"): + fallthrough + case strings.Contains(err.Error(), "x509"): + return wrap(err, ETLSError, "TLS error, check your CA certificate settings") + case errors.As(err, ¬FoundErr): + return wrap(err, ENotFound, "the requested resource was not found") + default: + return nil + } +} diff --git a/go.mod b/go.mod index f00e151..e9471f2 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,6 @@ go 1.14 require ( github.com/google/uuid v1.3.0 - github.com/ovirt/go-ovirt v0.0.0-20210308100159-ac0bcbc88d7c + github.com/ovirt/go-ovirt v0.0.0-20210715091347-08d263954de7 github.com/stretchr/testify v1.7.0 // indirect ) diff --git a/go.sum b/go.sum index 57415bb..ad75453 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/ovirt/go-ovirt v0.0.0-20210308100159-ac0bcbc88d7c h1:2SbYZedeIawU8sGFnohfrdEcEMBA4V8SDouG6hly+H4= -github.com/ovirt/go-ovirt v0.0.0-20210308100159-ac0bcbc88d7c/go.mod h1:fLDxPk1Sf64DBYtwIYxrnx3gPZ1q0xPdWdI1y9vxUaw= +github.com/ovirt/go-ovirt v0.0.0-20210715091347-08d263954de7 h1:zaF8FIXz+hfsrDMhbYjZepy5BH90wfNufZGLHoOIjSk= +github.com/ovirt/go-ovirt v0.0.0-20210715091347-08d263954de7/go.mod h1:Zkdj9/rW6eyuw0uOeEns6O3pP5G2ak+bI/tgkQ/tEZI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/mock.go b/mock.go index 8c88bc1..7903169 100644 --- a/mock.go +++ b/mock.go @@ -2,7 +2,6 @@ package ovirtclient import ( "context" - "fmt" "sync" "github.com/google/uuid" @@ -53,17 +52,17 @@ func (m *mockClient) GetDisk(diskID string) (Disk, error) { if disk, ok := m.disks[diskID]; ok { return disk, nil } - return nil, fmt.Errorf("disk with ID %s not found", diskID) + return nil, newError(ENotFound, "disk with ID %s not found", diskID) } -func (m *mockClient) RemoveDisk(diskID string) error { +func (m *mockClient) RemoveDisk(_ context.Context, diskID string) error { m.lock.Lock() defer m.lock.Unlock() if _, ok := m.disks[diskID]; ok { delete(m.disks, diskID) return nil } - return fmt.Errorf("disk with ID %s not found", diskID) + return newError(ENotFound, "disk with ID %s not found", diskID) } func (m *mockClient) CreateVM( @@ -97,7 +96,7 @@ func (m *mockClient) GetCluster(id string) (Cluster, error) { if c, ok := m.clusters[id]; ok { return c, nil } - return nil, fmt.Errorf("cluster with ID %s not found", id) + return nil, newError(ENotFound, "cluster with ID %s not found", id) } func (m *mockClient) ListStorageDomains() ([]StorageDomain, error) { @@ -120,7 +119,7 @@ func (m *mockClient) GetStorageDomain(id string) (StorageDomain, error) { if s, ok := m.storageDomains[id]; ok { return s, nil } - return nil, fmt.Errorf("storage domain with ID %s not found", id) + return nil, newError(ENotFound, "storage domain with ID %s not found", id) } func (m *mockClient) ListHosts() ([]Host, error) { @@ -143,7 +142,7 @@ func (m *mockClient) GetHost(id string) (Host, error) { if h, ok := m.hosts[id]; ok { return h, nil } - return nil, fmt.Errorf("host with ID %s not found", id) + return nil, newError(ENotFound, "host with ID %s not found", id) } func (m *mockClient) ListTemplates() ([]Template, error) { @@ -166,5 +165,5 @@ func (m *mockClient) GetTemplate(id string) (Template, error) { if t, ok := m.templates[id]; ok { return t, nil } - return nil, fmt.Errorf("template with ID %s not found", id) + return nil, newError(ENotFound, "template with ID %s not found", id) } diff --git a/mock_disk_uploadimage.go b/mock_disk_uploadimage.go index 65464c1..b3e3ef1 100644 --- a/mock_disk_uploadimage.go +++ b/mock_disk_uploadimage.go @@ -19,10 +19,10 @@ func (m *mockClient) StartImageUpload( m.lock.Lock() defer m.lock.Unlock() if alias == "" { - return nil, fmt.Errorf("alias cannot be empty") + return nil, newError(EBadArgument, "alias cannot be empty") } if _, ok := m.storageDomains[storageDomainID]; !ok { - return nil, fmt.Errorf("storage domain with ID %s not found", storageDomainID) + return nil, newError(ENotFound, "storage domain with ID %s not found", storageDomainID) } bufReader := bufio.NewReaderSize(reader, qcowHeaderSize) diff --git a/new.go b/new.go index 3867dc5..36190b4 100644 --- a/new.go +++ b/new.go @@ -1,12 +1,14 @@ package ovirtclient import ( + "context" "crypto/tls" "crypto/x509" - "fmt" + "errors" "io/ioutil" "net/http" "strings" + "time" ovirtsdk4 "github.com/ovirt/go-ovirt" ) @@ -21,15 +23,31 @@ func New( insecure bool, extraHeaders map[string]string, logger Logger, +) (ClientWithLegacySupport, error) { + return NewWithVerify(url, username, password, caFile, caCert, insecure, extraHeaders, logger, testConnection) +} + +// NewWithVerify allows customizing the verification function for the connection. Alternatively, a nil can be passed to +// disable connection verification. +func NewWithVerify( + url string, + username string, + password string, + caFile string, + caCert []byte, + insecure bool, + extraHeaders map[string]string, + logger Logger, + verify func(connection *ovirtsdk4.Connection) error, ) (ClientWithLegacySupport, error) { if err := validateURL(url); err != nil { - return nil, fmt.Errorf("invalid URL: %s (%w)", url, err) + return nil, wrap(err, EBadArgument, "invalid URL: %s", url) } if err := validateUsername(username); err != nil { - return nil, fmt.Errorf("invalid username: %s (%w)", username, err) + return nil, wrap(err, "invalid username: %s", username) } if caFile == "" && len(caCert) == 0 && !insecure { - return nil, fmt.Errorf("one of caFile, caCert, or insecure must be provided") + return nil, newError(EBadArgument, "one of caFile, caCert, or insecure must be provided") } connBuilder := ovirtsdk4.NewConnectionBuilder(). @@ -45,12 +63,12 @@ func New( conn, err := connBuilder.Build() if err != nil { - return nil, fmt.Errorf("failed to create underlying oVirt connection (%w)", err) + return nil, wrap(err, EUnidentified, "failed to create underlying oVirt connection") } tlsConfig, err := createTLSConfig(caFile, caCert, insecure) if err != nil { - return nil, fmt.Errorf("failed to create TLS configuration (%w)", err) + return nil, wrap(err, ETLSError, "failed to create TLS configuration") } httpClient := http.Client{ @@ -59,6 +77,12 @@ func New( }, } + if verify != nil { + if err := verify(conn); err != nil { + return nil, err + } + } + return &oVirtClient{ conn: conn, httpClient: httpClient, @@ -67,6 +91,36 @@ func New( }, nil } +func testConnection(conn *ovirtsdk4.Connection) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + for { + lastError := conn.SystemService().Connection().Test() + if lastError == nil { + break + } + if err := identify(lastError); err != nil { + var realErr EngineError + // This will always be an engine error + _ = errors.As(err, &realErr) + if !realErr.CanAutoRetry() { + return err + } + lastError = err + } + select { + case <-time.After(time.Second): + case <-ctx.Done(): + return wrap( + lastError, + ETimeout, + "timeout while attempting to create connection", + ) + } + } + return nil +} + func createTLSConfig( caFile string, caCert []byte, @@ -98,16 +152,17 @@ func createTLSConfig( } if len(caCert) != 0 { if ok := certPool.AppendCertsFromPEM(caCert); !ok { - return nil, fmt.Errorf("the provided CA certificate is not a valid certificate in PEM format") + return nil, newError(EBadArgument, "the provided CA certificate is not a valid certificate in PEM format") } } if caFile != "" { pemData, err := ioutil.ReadFile(caFile) if err != nil { - return nil, fmt.Errorf("failed to read CA certificate from file %s (%w)", caFile, err) + return nil, wrap(err, EFileReadFailed, "failed to read CA certificate from file %s", caFile) } if ok := certPool.AppendCertsFromPEM(pemData); !ok { - return nil, fmt.Errorf( + return nil, newError( + ETLSError, "the provided CA certificate is not a valid certificate in PEM format in file %s", caFile, ) @@ -121,20 +176,20 @@ func validateUsername(username string) error { usernameParts := strings.SplitN(username, "@", 2) //nolint:gomnd if len(usernameParts) != 2 { - return fmt.Errorf("username must contain exactly one @ sign (format should be admin@internal)") + return newError(EBadArgument, "username must contain exactly one @ sign (format should be admin@internal)") } if len(usernameParts[0]) == 0 { - return fmt.Errorf("no user supplied before @ sign in username (format should be admin@internal)") + return newError(EBadArgument, "no user supplied before @ sign in username (format should be admin@internal)") } if len(usernameParts[1]) == 0 { - return fmt.Errorf("no scope supplied after @ sign in username (format should be admin@internal)") + return newError(EBadArgument, "no scope supplied after @ sign in username (format should be admin@internal)") } return nil } func validateURL(url string) error { if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { - return fmt.Errorf("URL must start with http:// or https://") + return newError(EBadArgument, "URL must start with http:// or https://") } return nil } diff --git a/new_test.go b/new_test.go new file mode 100644 index 0000000..4dc9752 --- /dev/null +++ b/new_test.go @@ -0,0 +1,141 @@ +package ovirtclient_test + +import ( + "crypto/x509" + "errors" + "fmt" + "os" + "testing" + + ovirtclient "github.com/ovirt/go-ovirt-client" +) + +func TestInvalidCredentials(t *testing.T) { + url, caFile, caCert, insecure, err := getConnectionParametersForLiveTesting() + if err != nil { + t.Skipf("⚠ Skipping test: no live credentials provided.") + return + } + logger := ovirtclient.NewGoTestLogger(t) + _, err = ovirtclient.New( + url, + "nonexistent@internal", + "invalid-password-for-testing-purposes", + caFile, + []byte(caCert), + insecure, + nil, + logger, + ) + if err == nil { + t.Fatal("no error returned from New on invalid credentials") + } + + var e ovirtclient.EngineError + if errors.As(err, &e) { + if e.Code() != ovirtclient.EAccessDenied { + t.Fatalf("the returned error was not an access denied error EAccessDenied (%v)", err) + } + } else { + t.Fatalf("the returned error was not an EngineError (%v)", err) + } +} + +func TestBadURL(t *testing.T) { + logger := ovirtclient.NewGoTestLogger(t) + _, err := ovirtclient.New( + "https://example.com", + "nonexistent@internal", + "invalid-password-for-testing-purposes", + "", + nil, + true, + nil, + logger, + ) + if err == nil { + t.Fatal("no error returned from New on invalid URL") + } + + var e ovirtclient.EngineError + if errors.As(err, &e) { + if e.Code() != ovirtclient.ENotAnOVirtEngine { + t.Fatalf("the returned error was not an ENotAnOVirtEngine (%v)", err) + } + } else { + t.Fatalf("the returned error was not an EngineError (%v)", err) + } +} + +func TestBadTLS(t *testing.T) { + // False CA is the CA we will give to the client + _, _, falseCACertBytes, err := createCA() + if err != nil { + t.Fatalf("failed to create false CA (%v)", err) + } + + // Real CA is the CA we will use in the server + realCAPrivKey, realCACert, _, err := createCA() + if err != nil { + t.Fatalf("failed to create real CA (%v)", err) + } + + serverPrivKey, serverCert, err := createSignedCert( + []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + realCAPrivKey, + realCACert, + ) + if err != nil { + t.Fatalf("failed to create server certificate (%v)", err) + } + + port := getNextFreePort() + + srv, err := newTestServer(port, serverCert, serverPrivKey, &noopHandler{}) + if err != nil { + t.Fatal(err) + } + if err := srv.Start(); err != nil { + t.Fatal(err) + } + defer srv.Stop() + + logger := ovirtclient.NewGoTestLogger(t) + _, err = ovirtclient.New( + fmt.Sprintf("https://127.0.0.1:%d", port), + "nonexistent@internal", + "invalid-password-for-testing-purposes", + "", + falseCACertBytes, + false, + nil, + logger, + ) + + if err == nil { + t.Fatal("no error returned from New on invalid URL") + } + + var e ovirtclient.EngineError + if errors.As(err, &e) { + if e.Code() != ovirtclient.ETLSError { + t.Fatalf("the returned error was not an ETLSError (%v)", err) + } + } else { + t.Fatalf("the returned error was not an EngineError (%v)", err) + } +} + +func getConnectionParametersForLiveTesting() (string, string, string, bool, error) { + url := os.Getenv("OVIRT_URL") + if url == "" { + return "", "", "", false, fmt.Errorf("the OVIRT_URL environment variable must not be empty") + } + caFile := os.Getenv("OVIRT_CAFILE") + caCert := os.Getenv("OVIRT_CA_CERT") + insecure := os.Getenv("OVIRT_INSECURE") != "" + if caFile == "" && caCert == "" && !insecure { + return "", "", "", false, fmt.Errorf("one of OVIRT_CAFILE, OVIRT_CA_CERT, or OVIRT_INSECURE must be set") + } + return url, caFile, caCert, insecure, nil +} diff --git a/util_qcow.go b/util_qcow.go index 1e22a42..c502476 100644 --- a/util_qcow.go +++ b/util_qcow.go @@ -3,7 +3,6 @@ package ovirtclient import ( "bufio" "encoding/binary" - "fmt" ) func extractQCOWParameters(fileSize uint64, bufReader *bufio.Reader) ( @@ -15,7 +14,7 @@ func extractQCOWParameters(fileSize uint64, bufReader *bufio.Reader) ( qcowSize := fileSize header, err := bufReader.Peek(qcowHeaderSize) if err != nil { - return "", 0, fmt.Errorf("failed to read QCOW header (%w)", err) + return "", 0, wrap(err, EBadArgument, "failed to read QCOW header") } isQCOW := string(header[0:len(qcowMagicBytes)]) == qcowMagicBytes if !isQCOW { diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..75d1491 --- /dev/null +++ b/util_test.go @@ -0,0 +1,171 @@ +package ovirtclient_test + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/http" + "sync" + "time" +) + +var nextFreePort = 8080 +var nextFreePortLock = &sync.Mutex{} + +func getNextFreePort() int { + nextFreePortLock.Lock() + defer nextFreePortLock.Unlock() + + port := nextFreePort + nextFreePort++ + return port +} + +func createCA() (*rsa.PrivateKey, *x509.Certificate, []byte, error) { + ca := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"ACME, Inc"}, + Country: []string{"US"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(10, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + caPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to create private key (%w)", err) + } + caCert, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivateKey.PublicKey, caPrivateKey) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to create CA certificate (%w)", err) + } + caPEM := new(bytes.Buffer) + if err := pem.Encode( + caPEM, + &pem.Block{ + Type: "CERTIFICATE", + Bytes: caCert, + }, + ); err != nil { + return nil, nil, nil, fmt.Errorf("failed to encode CA cert (%w)", err) + } + return caPrivateKey, ca, caPEM.Bytes(), nil +} + +func createSignedCert(usage []x509.ExtKeyUsage, caPrivateKey *rsa.PrivateKey, caCertificate *x509.Certificate) ( + []byte, + []byte, + error, +) { + cert := &x509.Certificate{ + SerialNumber: big.NewInt(1658), + Subject: pkix.Name{ + Organization: []string{"ACME, Inc"}, + Country: []string{"US"}, + }, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(0, 0, 1), + SubjectKeyId: []byte{1}, + ExtKeyUsage: usage, + KeyUsage: x509.KeyUsageDigitalSignature, + } + certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, err + } + certBytes, err := x509.CreateCertificate( + rand.Reader, + cert, + caCertificate, + &certPrivKey.PublicKey, + caPrivateKey, + ) + if err != nil { + return nil, nil, err + } + certPrivKeyPEM := new(bytes.Buffer) + if err := pem.Encode(certPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), + }); err != nil { + return nil, nil, err + } + certPEM := new(bytes.Buffer) + if err := pem.Encode(certPEM, + &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}, + ); err != nil { + return nil, nil, err + } + return certPrivKeyPEM.Bytes(), certPEM.Bytes(), nil +} + +func newTestServer(port int, serverCert []byte, serverPrivKey []byte, handler http.Handler) (*testServer, error) { + cert, err := tls.X509KeyPair(serverCert, serverPrivKey) + if err != nil { + return nil, fmt.Errorf("failed to create key pair (%w)", err) + } + srv := &http.Server{ + Addr: fmt.Sprintf("127.0.0.1:%d", port), + Handler: handler, + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{ + cert, + }, + }, + } + return &testServer{ + srv: srv, + port: port, + }, nil +} + +type testServer struct { + srv *http.Server + port int +} + +func (t *testServer) Start() error { + ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", t.port)) + if err != nil { + return fmt.Errorf("failed to start test server (%w)", err) + } + + go func() { + _ = t.srv.ServeTLS(ln, "", "") + }() + return nil +} + +func (t *testServer) Stop() { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + _ = t.srv.Shutdown(ctx) +} + +type noopHandler struct { +} + +func (t *noopHandler) ServeHTTP(writer http.ResponseWriter, _ *http.Request) { + writer.WriteHeader(200) +} + +type unauthorizedHandler struct { +} + +func (u *unauthorizedHandler) ServeHTTP(writer http.ResponseWriter, _ *http.Request) { + writer.WriteHeader(401) + _, _ = writer.Write([]byte("ErrorUnauthorized")) +}