-
Notifications
You must be signed in to change notification settings - Fork 558
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* WIP adaption of my scratch code * implement as a handler Option * fix linter errors * io -> ioutil to keep older CI running * add test case, attempt to configure github action to install runtime interface emulator * fix truncated url * -L * please the race detector * contsrain testcase to go 1.15+ * please the linter * -v the tests * add test variant that checks that sigterm isn't enabled by default * Update tests.yml
- Loading branch information
Showing
6 changed files
with
308 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
package lambda | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"io/ioutil" | ||
"net/http" | ||
) | ||
|
||
const ( | ||
headerExtensionName = "Lambda-Extension-Name" | ||
headerExtensionIdentifier = "Lambda-Extension-Identifier" | ||
extensionAPIVersion = "2020-01-01" | ||
) | ||
|
||
type extensionAPIEventType string | ||
|
||
const ( | ||
extensionInvokeEvent extensionAPIEventType = "INVOKE" //nolint:deadcode,unused,varcheck | ||
extensionShutdownEvent extensionAPIEventType = "SHUTDOWN" //nolint:deadcode,unused,varcheck | ||
) | ||
|
||
type extensionAPIClient struct { | ||
baseURL string | ||
httpClient *http.Client | ||
} | ||
|
||
func newExtensionAPIClient(address string) *extensionAPIClient { | ||
client := &http.Client{ | ||
Timeout: 0, // connections to the extensions API are never expected to time out | ||
} | ||
endpoint := "http://" + address + "/" + extensionAPIVersion + "/extension/" | ||
return &extensionAPIClient{ | ||
baseURL: endpoint, | ||
httpClient: client, | ||
} | ||
} | ||
|
||
func (c *extensionAPIClient) register(name string, events ...extensionAPIEventType) (string, error) { | ||
url := c.baseURL + "register" | ||
body, _ := json.Marshal(struct { | ||
Events []extensionAPIEventType `json:"events"` | ||
}{ | ||
Events: events, | ||
}) | ||
|
||
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewReader(body)) | ||
req.Header.Add(headerExtensionName, name) | ||
res, err := c.httpClient.Do(req) | ||
if err != nil { | ||
return "", fmt.Errorf("failed to register extension: %v", err) | ||
} | ||
defer res.Body.Close() | ||
_, _ = io.Copy(ioutil.Discard, res.Body) | ||
|
||
if res.StatusCode != http.StatusOK { | ||
return "", fmt.Errorf("failed to register extension, got response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode)) | ||
} | ||
|
||
return res.Header.Get(headerExtensionIdentifier), nil | ||
} | ||
|
||
type extensionEventResponse struct { | ||
EventType extensionAPIEventType | ||
// ... the rest not implemented | ||
} | ||
|
||
func (c *extensionAPIClient) next(id string) (response extensionEventResponse, err error) { | ||
url := c.baseURL + "event/next" | ||
|
||
req, _ := http.NewRequest(http.MethodGet, url, nil) | ||
req.Header.Add(headerExtensionIdentifier, id) | ||
res, err := c.httpClient.Do(req) | ||
if err != nil { | ||
err = fmt.Errorf("failed to get extension event: %v", err) | ||
return | ||
} | ||
defer res.Body.Close() | ||
_, _ = io.Copy(ioutil.Discard, res.Body) | ||
|
||
if res.StatusCode != http.StatusOK { | ||
err = fmt.Errorf("failed to register extension, got response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode)) | ||
return | ||
} | ||
|
||
err = json.NewDecoder(res.Body).Decode(&response) | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
|
||
package lambda | ||
|
||
import ( | ||
"log" | ||
"os" | ||
"os/signal" | ||
"syscall" | ||
) | ||
|
||
// enableSIGTERM configures an optional list of sigtermHandlers to run on process shutdown. | ||
// This non-default behavior is enabled within Lambda using the extensions API. | ||
func enableSIGTERM(sigtermHandlers []func()) { | ||
// for fun, we'll also optionally register SIGTERM handlers | ||
if len(sigtermHandlers) > 0 { | ||
signaled := make(chan os.Signal, 1) | ||
signal.Notify(signaled, syscall.SIGTERM) | ||
go func() { | ||
<-signaled | ||
for _, f := range sigtermHandlers { | ||
f() | ||
} | ||
}() | ||
} | ||
|
||
// detect if we're actually running within Lambda | ||
endpoint := os.Getenv("AWS_LAMBDA_RUNTIME_API") | ||
if endpoint == "" { | ||
log.Print("WARNING! AWS_LAMBDA_RUNTIME_API environment variable not found. Skipping attempt to register internal extension...") | ||
return | ||
} | ||
|
||
// Now to do the AWS Lambda specific stuff. | ||
// The default Lambda behavior is for functions to get SIGKILL at the end of lifetime, or after a timeout. | ||
// Any use of the Lambda extension register API enables SIGTERM to be sent to the function process before the SIGKILL. | ||
// We'll register an extension that does not listen for any lifecycle events named "GoLangEnableSIGTERM". | ||
// The API will respond with an ID we need to pass in future requests. | ||
client := newExtensionAPIClient(endpoint) | ||
id, err := client.register("GoLangEnableSIGTERM") | ||
if err != nil { | ||
log.Printf("WARNING! Failed to register internal extension! SIGTERM events may not be enabled! err: %v", err) | ||
return | ||
} | ||
|
||
// We didn't actually register for any events, but we need to call /next anyways to let the API know we're done initalizing. | ||
// Because we didn't register for any events, /next will never return, so we'll do this in a go routine that is doomed to stay blocked. | ||
go func() { | ||
_, err := client.next(id) | ||
log.Printf("WARNING! Reached expected unreachable code! Extension /next call expected to block forever! err: %v", err) | ||
}() | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
//go:build go1.15 | ||
// +build go1.15 | ||
|
||
package lambda | ||
|
||
import ( | ||
"io/ioutil" | ||
"net/http" | ||
"os" | ||
"os/exec" | ||
"path" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
const ( | ||
rieInvokeAPI = "http://localhost:8080/2015-03-31/functions/function/invocations" | ||
) | ||
|
||
func TestEnableSigterm(t *testing.T) { | ||
if _, err := exec.LookPath("aws-lambda-rie"); err != nil { | ||
t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err) | ||
} | ||
|
||
testDir := t.TempDir() | ||
|
||
// compile our handler, it'll always run to timeout ensuring the SIGTERM is triggered by aws-lambda-rie | ||
handlerBuild := exec.Command("go", "build", "-o", path.Join(testDir, "sigterm.handler"), "./testdata/sigterm.go") | ||
handlerBuild.Stderr = os.Stderr | ||
handlerBuild.Stdout = os.Stderr | ||
require.NoError(t, handlerBuild.Run()) | ||
|
||
for name, opts := range map[string]struct { | ||
envVars []string | ||
assertLogs func(t *testing.T, logs string) | ||
}{ | ||
"baseline": { | ||
assertLogs: func(t *testing.T, logs string) { | ||
assert.NotContains(t, logs, "Hello SIGTERM!") | ||
assert.NotContains(t, logs, "I've been TERMINATED!") | ||
}, | ||
}, | ||
"sigterm enabled": { | ||
envVars: []string{"ENABLE_SIGTERM=please"}, | ||
assertLogs: func(t *testing.T, logs string) { | ||
assert.Contains(t, logs, "Hello SIGTERM!") | ||
assert.Contains(t, logs, "I've been TERMINATED!") | ||
}, | ||
}, | ||
} { | ||
t.Run(name, func(t *testing.T) { | ||
// run the runtime interface emulator, capture the logs for assertion | ||
cmd := exec.Command("aws-lambda-rie", "sigterm.handler") | ||
cmd.Env = append([]string{ | ||
"PATH=" + testDir, | ||
"AWS_LAMBDA_FUNCTION_TIMEOUT=2", | ||
}, opts.envVars...) | ||
cmd.Stderr = os.Stderr | ||
stdout, err := cmd.StdoutPipe() | ||
require.NoError(t, err) | ||
var logs string | ||
done := make(chan interface{}) // closed on completion of log flush | ||
go func() { | ||
logBytes, err := ioutil.ReadAll(stdout) | ||
require.NoError(t, err) | ||
logs = string(logBytes) | ||
close(done) | ||
}() | ||
require.NoError(t, cmd.Start()) | ||
t.Cleanup(func() { _ = cmd.Process.Kill() }) | ||
|
||
// give a moment for the port to bind | ||
time.Sleep(500 * time.Millisecond) | ||
|
||
client := &http.Client{Timeout: 5 * time.Second} // http client timeout to prevent case from hanging on aws-lambda-rie | ||
resp, err := client.Post(rieInvokeAPI, "application/json", strings.NewReader("{}")) | ||
require.NoError(t, err) | ||
defer resp.Body.Close() | ||
body, err := ioutil.ReadAll(resp.Body) | ||
assert.NoError(t, err) | ||
assert.Equal(t, string(body), "Task timed out after 2.00 seconds") | ||
|
||
require.NoError(t, cmd.Process.Kill()) // now ensure the logs are drained | ||
<-done | ||
t.Logf("stdout:\n%s", logs) | ||
opts.assertLogs(t, logs) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"os" | ||
"os/signal" | ||
"syscall" | ||
"time" | ||
|
||
"github.com/aws/aws-lambda-go/lambda" | ||
) | ||
|
||
func init() { | ||
// conventional SIGTERM callback | ||
signaled := make(chan os.Signal, 1) | ||
signal.Notify(signaled, syscall.SIGTERM) | ||
go func() { | ||
<-signaled | ||
fmt.Println("I've been TERMINATED!") | ||
}() | ||
|
||
} | ||
|
||
func main() { | ||
// lambda option to enable sigterm, plus optional extra sigterm callbacks | ||
sigtermOption := lambda.WithEnableSIGTERM(func() { | ||
fmt.Println("Hello SIGTERM!") | ||
}) | ||
handlerOptions := []lambda.Option{} | ||
if os.Getenv("ENABLE_SIGTERM") != "" { | ||
handlerOptions = append(handlerOptions, sigtermOption) | ||
} | ||
lambda.StartWithOptions( | ||
func(ctx context.Context) { | ||
deadline, _ := ctx.Deadline() | ||
<-time.After(time.Until(deadline) + time.Second) | ||
panic("unreachable line reached!") | ||
}, | ||
handlerOptions..., | ||
) | ||
} |