Skip to content

Commit

Permalink
better ctrlc handling
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelvigee committed May 12, 2024
1 parent 20dd371 commit 8368b6f
Show file tree
Hide file tree
Showing 13 changed files with 233 additions and 100 deletions.
92 changes: 48 additions & 44 deletions bootstrap/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,72 +5,76 @@ import (
"fmt"
"github.com/hephbuild/heph/log/log"
"github.com/hephbuild/heph/targetrun"
"github.com/hephbuild/heph/utils/xcontext"
"github.com/hephbuild/heph/worker2"
"go.uber.org/multierr"
"io"
"os"
"sync"
)

func printErrTargetFailed(err error) bool {
var lerr targetrun.TargetFailed
if errors.As(err, &lerr) {
log.Errorf("%v failed: %v", lerr.Target.Addr, lerr.Err)
func PrintHumanError(err error) {
errs := worker2.CollectRootErrors(err)
skippedCount := 0

logFile := lerr.LogFile
if logFile != "" {
info, _ := os.Stat(logFile)
if info != nil && info.Size() > 0 {
fmt.Fprintln(log.Writer())
f, err := os.Open(logFile)
if err == nil {
_, _ = io.Copy(log.Writer(), f)
f.Close()
fmt.Fprintln(log.Writer())
}
log.Errorf("The log file can be found at %v", logFile)
}
}
var contextCanceledOnce sync.Once

for _, err := range multierr.Errors(lerr.Err) {
log.Error(err)
logError := func(err error) {
var sigCause xcontext.SignalCause
if errors.As(err, &sigCause) {
contextCanceledOnce.Do(func() {
log.Error(sigCause)
})
return
}

return true
log.Error(err)
}

return false
}
for _, err := range errs {
var terr targetrun.TargetFailed
if errors.As(err, &terr) {
var sigCause xcontext.SignalCause
if errors.As(err, &sigCause) {
contextCanceledOnce.Do(func() {
log.Error(sigCause)
})
return
}

func PrintHumanError(err error) {
errs := worker2.CollectRootErrors(err)
skippedCount := 0
skipSpacing := true
log.Errorf("%v failed: %v", terr.Target.Addr, terr.Err)

separate := func() {
if skipSpacing {
skipSpacing = false
} else {
fmt.Fprintln(log.Writer())
}
}
logFile := terr.LogFile
if logFile != "" {
info, _ := os.Stat(logFile)
if info != nil && info.Size() > 0 {
fmt.Fprintln(log.Writer())
f, err := os.Open(logFile)
if err == nil {
_, _ = io.Copy(log.Writer(), f)
f.Close()
fmt.Fprintln(log.Writer())
}
log.Errorf("The log file can be found at %v", logFile)
}
}

for _, err := range multierr.Errors(terr.Err) {
logError(err)
}

for _, err := range errs {
if printErrTargetFailed(err) {
// Printed !
continue
}

var jerr worker2.Error
if errors.As(err, &jerr) && jerr.Skipped() {
skippedCount++
skipSpacing = true
log.Debugf("skipped: %v", jerr)
} else {
for _, err := range multierr.Errors(err) {
skipSpacing = true
separate()
log.Error(err)
}
continue
}

for _, err := range multierr.Errors(err) {
logError(err)
}
}

Expand Down
11 changes: 2 additions & 9 deletions cmd/heph/entrypoint.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package main

import (
"context"
"errors"
"github.com/hephbuild/heph/bootstrap"
"github.com/hephbuild/heph/utils/xcontext"
"github.com/hephbuild/heph/utils/xerrors"
"github.com/hephbuild/heph/utils/xrand"
"github.com/hephbuild/heph/vfssimple"
"os"
Expand All @@ -19,11 +18,6 @@ func execute() error {
err := rootCmd.ExecuteContext(ctx)
postRun(err)
if err != nil {
// Handle ctrlc gracefuly
if ctx.Err() != nil {
return context.Cause(ctx)
}

return err
}

Expand All @@ -35,8 +29,7 @@ func Execute() {

if err := execute(); err != nil {
exitCode := 1
var eerr bootstrap.ErrorWithExitCode
if errors.As(err, &eerr) {
if eerr, ok := xerrors.As[bootstrap.ErrorWithExitCode](err); ok {
exitCode = eerr.ExitCode
// This is required in case ErrorWithExitCode does not have an Err set, just an ExitCode
err = eerr.Err
Expand Down
4 changes: 4 additions & 0 deletions lcache/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ func UncompressedPathFromArtifact(ctx context.Context, target graph.Targeter, ar
}

if err != nil {
if cerr := ctx.Err(); cerr != nil {
return "", cerr
}

return "", fmt.Errorf("ungz: cp: %w", err)
}

Expand Down
3 changes: 3 additions & 0 deletions lcache/artifacts_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,9 @@ func GenArtifact(ctx context.Context, dir string, a ArtifactWithProducer, compre
}
return nil
}
if cerr := ctx.Err(); cerr != nil {
return cerr
}
return err
}

Expand Down
12 changes: 6 additions & 6 deletions sandbox/sandbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,15 @@ func (c *Cmd) Run() error {
go c.watchSoftContext()

err := c.Cmd.Run()
if cerr := c.SoftContext.Err(); cerr != nil {
if err != nil {
err = fmt.Errorf("%v: %w", err, cerr)
} else {
err = cerr
if err != nil {
if cerr := c.SoftContext.Err(); cerr != nil {
return fmt.Errorf("%w: %w", cerr, err)
}

return err
}

return err
return nil
}

func Exec(cfg ExecConfig) *Cmd {
Expand Down
6 changes: 0 additions & 6 deletions targetrun/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,6 @@ func (e *Runner) Run(ctx context.Context, rr Request, iocfg sandbox.IOConfig, tr
return nil, nil
}

if cerr := ctx.Err(); cerr != nil {
if !errors.Is(err, cerr) {
err = fmt.Errorf("%w: %v", cerr, err)
}
}

err := fmt.Errorf("exec: %w", err)

if iocfg.Stdin == os.Stdin {
Expand Down
5 changes: 2 additions & 3 deletions utils/flock/flock.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package flock

import (
"errors"
"fmt"
"github.com/hephbuild/heph/utils/xerrors"
"golang.org/x/sys/unix"
"os"
"syscall"
Expand Down Expand Up @@ -62,8 +62,7 @@ func Flunlock(f *os.File) error {
}

func IsErrWouldBlock(err error) bool {
var errno unix.Errno
if ok := errors.As(err, &errno); ok && errno == unix.EWOULDBLOCK {
if errno, ok := xerrors.As[unix.Errno](err); ok && errno == unix.EWOULDBLOCK {
return true
}

Expand Down
27 changes: 24 additions & 3 deletions utils/tar/tar.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,19 @@ func UntarPath(ctx context.Context, in, to string, o UntarOptions) (err error) {
return UntarContext(ctx, tarf, to, o)
}

func UntarContext(ctx context.Context, in io.ReadCloser, to string, o UntarOptions) (err error) {
func UntarContext(ctx context.Context, in io.ReadCloser, to string, o UntarOptions) error {
cancel := xio.ContextCloser(ctx, in)
defer cancel()

return Untar(in, to, o)
err := Untar(in, to, o)
if err != nil {
if cerr := ctx.Err(); cerr != nil {
return cerr
}
return err
}

return nil
}

// See https://unix.stackexchange.com/a/557487
Expand Down Expand Up @@ -294,6 +302,10 @@ func UntarList(ctx context.Context, in io.ReadCloser, listPath string, progresss
return nil
})
if err != nil {
if cerr := ctx.Err(); cerr != nil {
return nil, cerr
}

return nil, err
}

Expand All @@ -310,7 +322,16 @@ func WalkPath(ctx context.Context, path string, fs ...func(*tar.Header, *tar.Rea
cancel := xio.ContextCloser(ctx, tarf)
defer cancel()

return Walk(tarf, fs...)
err = Walk(tarf, fs...)
if err != nil {
if cerr := ctx.Err(); cerr != nil {
return cerr
}

return err
}

return nil
}

func Walk(tarf io.Reader, fs ...func(*tar.Header, *tar.Reader) error) error {
Expand Down
62 changes: 57 additions & 5 deletions utils/xcontext/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package xcontext

import (
"context"
"errors"
"fmt"
"github.com/hephbuild/heph/log/log"
"github.com/hephbuild/heph/utils/ads"
Expand Down Expand Up @@ -43,11 +44,51 @@ func newSoftCancelState() *state {
return s
}

type causeErr struct {
err, cause error
}

func (c causeErr) Is(target error) bool {
return errors.Is(c.err, target) || errors.Is(c.cause, target)
}

func (c causeErr) As(target any) bool {
return errors.As(c.err, target) || errors.As(c.cause, target)
}

func (c causeErr) Error() string {
return c.cause.Error()
}

type causeCtx struct {
context.Context
}

func (c causeCtx) Err() error {
err := c.Context.Err()
if err == nil {
return nil
}

cause := context.Cause(c.Context)
if cause == nil || err == cause {
return err
}

return causeErr{err, cause}
}

func WithCancelCause(parent context.Context) (context.Context, CancelFunc) {
ctx, cancel := context.WithCancelCause(parent)

return causeCtx{ctx}, cancel
}

// New returns one context that will be canceled by soft cancel first, the second one will act as a force cancel
// both inherit values from their parents
func (a *state) New(parent context.Context) (context.Context, context.Context, CancelFunc) {
scctx, scancel := context.WithCancelCause(parent)
hcctx, hcancel := context.WithCancelCause(context.Background())
scctx, scancel := WithCancelCause(parent)
hcctx, hcancel := WithCancelCause(context.Background())

hctx := CancellableContext{
Parent: parent,
Expand Down Expand Up @@ -145,8 +186,19 @@ func Cancel(ctx context.Context) {
const stuckTimeout = 5 * time.Second
const forceTimeout = 1 * time.Second

type SignalCause struct {
sig os.Signal
}

func (s SignalCause) Error() string {
if s.sig == nil {
return "signal: !!missing!!"
}
return fmt.Sprintf("signal: %v", s.sig.String())
}

func BootstrapSoftCancel() (context.Context, CancelFunc) {
ctx, cancel := context.WithCancelCause(context.Background())
ctx, cancel := WithCancelCause(context.Background())

sigCh := make(chan os.Signal)
signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
Expand All @@ -155,7 +207,7 @@ func BootstrapSoftCancel() (context.Context, CancelFunc) {

go func() {
sig := <-sigCh
cancel(fmt.Errorf(sig.String()))
cancel(SignalCause{sig})
if sc.has() {
hardCanceled := false
go func() {
Expand All @@ -168,7 +220,7 @@ func BootstrapSoftCancel() (context.Context, CancelFunc) {
sig := <-sigCh
hardCanceled = true
log.Warnf("Forcing cancellation...")
sc.hardCancel(fmt.Errorf(sig.String()))
sc.hardCancel(SignalCause{sig})
select {
// Wait for soft cancel to all be unregistered, should be fast, unless something is stuck
case <-sc.wait():
Expand Down
Loading

0 comments on commit 8368b6f

Please sign in to comment.