Skip to content

Commit 93199f7

Browse files
authored
Add WithEnableSIGTERM option (#457)
* WIP adaption of my scratch code * implement as a handler Option * fix linter errors * io -> ioutil to keep older CI running * add test case, attempt to configure github action to install runtime interface emulator * fix truncated url * -L * please the race detector * contsrain testcase to go 1.15+ * please the linter * -v the tests * add test variant that checks that sigterm isn't enabled by default * Update tests.yml
1 parent 8bc331d commit 93199f7

File tree

6 files changed

+308
-1
lines changed

6 files changed

+308
-1
lines changed

.github/workflows/tests.yml

+5-1
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,15 @@ jobs:
2525

2626
- run: go version
2727

28+
- name: install lambda runtime interface emulator
29+
run: curl -L -o /usr/local/bin/aws-lambda-rie https://github.com/aws/aws-lambda-runtime-interface-emulator/releases/latest/download/aws-lambda-rie-x86_64
30+
- run: chmod +x /usr/local/bin/aws-lambda-rie
31+
2832
- name: Check out code into the Go module directory
2933
uses: actions/checkout@v2
3034

3135
- name: go test
32-
run: go test -race -coverprofile=coverage.txt -covermode=atomic ./...
36+
run: go test -v -race -coverprofile=coverage.txt -covermode=atomic ./...
3337

3438
- name: Upload coverage to Codecov
3539
uses: codecov/codecov-action@v2

lambda/extensions_api_client.go

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package lambda
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"io/ioutil"
9+
"net/http"
10+
)
11+
12+
const (
13+
headerExtensionName = "Lambda-Extension-Name"
14+
headerExtensionIdentifier = "Lambda-Extension-Identifier"
15+
extensionAPIVersion = "2020-01-01"
16+
)
17+
18+
type extensionAPIEventType string
19+
20+
const (
21+
extensionInvokeEvent extensionAPIEventType = "INVOKE" //nolint:deadcode,unused,varcheck
22+
extensionShutdownEvent extensionAPIEventType = "SHUTDOWN" //nolint:deadcode,unused,varcheck
23+
)
24+
25+
type extensionAPIClient struct {
26+
baseURL string
27+
httpClient *http.Client
28+
}
29+
30+
func newExtensionAPIClient(address string) *extensionAPIClient {
31+
client := &http.Client{
32+
Timeout: 0, // connections to the extensions API are never expected to time out
33+
}
34+
endpoint := "http://" + address + "/" + extensionAPIVersion + "/extension/"
35+
return &extensionAPIClient{
36+
baseURL: endpoint,
37+
httpClient: client,
38+
}
39+
}
40+
41+
func (c *extensionAPIClient) register(name string, events ...extensionAPIEventType) (string, error) {
42+
url := c.baseURL + "register"
43+
body, _ := json.Marshal(struct {
44+
Events []extensionAPIEventType `json:"events"`
45+
}{
46+
Events: events,
47+
})
48+
49+
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewReader(body))
50+
req.Header.Add(headerExtensionName, name)
51+
res, err := c.httpClient.Do(req)
52+
if err != nil {
53+
return "", fmt.Errorf("failed to register extension: %v", err)
54+
}
55+
defer res.Body.Close()
56+
_, _ = io.Copy(ioutil.Discard, res.Body)
57+
58+
if res.StatusCode != http.StatusOK {
59+
return "", fmt.Errorf("failed to register extension, got response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode))
60+
}
61+
62+
return res.Header.Get(headerExtensionIdentifier), nil
63+
}
64+
65+
type extensionEventResponse struct {
66+
EventType extensionAPIEventType
67+
// ... the rest not implemented
68+
}
69+
70+
func (c *extensionAPIClient) next(id string) (response extensionEventResponse, err error) {
71+
url := c.baseURL + "event/next"
72+
73+
req, _ := http.NewRequest(http.MethodGet, url, nil)
74+
req.Header.Add(headerExtensionIdentifier, id)
75+
res, err := c.httpClient.Do(req)
76+
if err != nil {
77+
err = fmt.Errorf("failed to get extension event: %v", err)
78+
return
79+
}
80+
defer res.Body.Close()
81+
_, _ = io.Copy(ioutil.Discard, res.Body)
82+
83+
if res.StatusCode != http.StatusOK {
84+
err = fmt.Errorf("failed to register extension, got response status: %d %s", res.StatusCode, http.StatusText(res.StatusCode))
85+
return
86+
}
87+
88+
err = json.NewDecoder(res.Body).Decode(&response)
89+
return
90+
}

lambda/handler.go

+25
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ type handlerOptions struct {
2323
jsonResponseEscapeHTML bool
2424
jsonResponseIndentPrefix string
2525
jsonResponseIndentValue string
26+
enableSIGTERM bool
27+
sigtermCallbacks []func()
2628
}
2729

2830
type Option func(*handlerOptions)
@@ -73,6 +75,26 @@ func WithSetIndent(prefix, indent string) Option {
7375
})
7476
}
7577

78+
// WithEnableSIGTERM enables SIGTERM behavior within the Lambda platform on container spindown.
79+
// SIGKILL will occur ~500ms after SIGTERM.
80+
// Optionally, an array of callback functions to run on SIGTERM may be provided.
81+
//
82+
// Usage:
83+
// lambda.StartWithOptions(
84+
// func (event any) (any error) {
85+
// return event, nil
86+
// },
87+
// lambda.WithEnableSIGTERM(func() {
88+
// log.Print("function container shutting down...")
89+
// })
90+
// )
91+
func WithEnableSIGTERM(callbacks ...func()) Option {
92+
return Option(func(h *handlerOptions) {
93+
h.sigtermCallbacks = append(h.sigtermCallbacks, callbacks...)
94+
h.enableSIGTERM = true
95+
})
96+
}
97+
7698
func validateArguments(handler reflect.Type) (bool, error) {
7799
handlerTakesContext := false
78100
if handler.NumIn() > 2 {
@@ -139,6 +161,9 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
139161
for _, option := range options {
140162
option(h)
141163
}
164+
if h.enableSIGTERM {
165+
enableSIGTERM(h.sigtermCallbacks)
166+
}
142167
h.Handler = reflectHandler(handlerFunc, h)
143168
return h
144169
}

lambda/sigterm.go

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
3+
package lambda
4+
5+
import (
6+
"log"
7+
"os"
8+
"os/signal"
9+
"syscall"
10+
)
11+
12+
// enableSIGTERM configures an optional list of sigtermHandlers to run on process shutdown.
13+
// This non-default behavior is enabled within Lambda using the extensions API.
14+
func enableSIGTERM(sigtermHandlers []func()) {
15+
// for fun, we'll also optionally register SIGTERM handlers
16+
if len(sigtermHandlers) > 0 {
17+
signaled := make(chan os.Signal, 1)
18+
signal.Notify(signaled, syscall.SIGTERM)
19+
go func() {
20+
<-signaled
21+
for _, f := range sigtermHandlers {
22+
f()
23+
}
24+
}()
25+
}
26+
27+
// detect if we're actually running within Lambda
28+
endpoint := os.Getenv("AWS_LAMBDA_RUNTIME_API")
29+
if endpoint == "" {
30+
log.Print("WARNING! AWS_LAMBDA_RUNTIME_API environment variable not found. Skipping attempt to register internal extension...")
31+
return
32+
}
33+
34+
// Now to do the AWS Lambda specific stuff.
35+
// The default Lambda behavior is for functions to get SIGKILL at the end of lifetime, or after a timeout.
36+
// Any use of the Lambda extension register API enables SIGTERM to be sent to the function process before the SIGKILL.
37+
// We'll register an extension that does not listen for any lifecycle events named "GoLangEnableSIGTERM".
38+
// The API will respond with an ID we need to pass in future requests.
39+
client := newExtensionAPIClient(endpoint)
40+
id, err := client.register("GoLangEnableSIGTERM")
41+
if err != nil {
42+
log.Printf("WARNING! Failed to register internal extension! SIGTERM events may not be enabled! err: %v", err)
43+
return
44+
}
45+
46+
// We didn't actually register for any events, but we need to call /next anyways to let the API know we're done initalizing.
47+
// Because we didn't register for any events, /next will never return, so we'll do this in a go routine that is doomed to stay blocked.
48+
go func() {
49+
_, err := client.next(id)
50+
log.Printf("WARNING! Reached expected unreachable code! Extension /next call expected to block forever! err: %v", err)
51+
}()
52+
53+
}

lambda/sigterm_test.go

+93
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//go:build go1.15
2+
// +build go1.15
3+
4+
package lambda
5+
6+
import (
7+
"io/ioutil"
8+
"net/http"
9+
"os"
10+
"os/exec"
11+
"path"
12+
"strings"
13+
"testing"
14+
"time"
15+
16+
"github.com/stretchr/testify/assert"
17+
"github.com/stretchr/testify/require"
18+
)
19+
20+
const (
21+
rieInvokeAPI = "http://localhost:8080/2015-03-31/functions/function/invocations"
22+
)
23+
24+
func TestEnableSigterm(t *testing.T) {
25+
if _, err := exec.LookPath("aws-lambda-rie"); err != nil {
26+
t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err)
27+
}
28+
29+
testDir := t.TempDir()
30+
31+
// compile our handler, it'll always run to timeout ensuring the SIGTERM is triggered by aws-lambda-rie
32+
handlerBuild := exec.Command("go", "build", "-o", path.Join(testDir, "sigterm.handler"), "./testdata/sigterm.go")
33+
handlerBuild.Stderr = os.Stderr
34+
handlerBuild.Stdout = os.Stderr
35+
require.NoError(t, handlerBuild.Run())
36+
37+
for name, opts := range map[string]struct {
38+
envVars []string
39+
assertLogs func(t *testing.T, logs string)
40+
}{
41+
"baseline": {
42+
assertLogs: func(t *testing.T, logs string) {
43+
assert.NotContains(t, logs, "Hello SIGTERM!")
44+
assert.NotContains(t, logs, "I've been TERMINATED!")
45+
},
46+
},
47+
"sigterm enabled": {
48+
envVars: []string{"ENABLE_SIGTERM=please"},
49+
assertLogs: func(t *testing.T, logs string) {
50+
assert.Contains(t, logs, "Hello SIGTERM!")
51+
assert.Contains(t, logs, "I've been TERMINATED!")
52+
},
53+
},
54+
} {
55+
t.Run(name, func(t *testing.T) {
56+
// run the runtime interface emulator, capture the logs for assertion
57+
cmd := exec.Command("aws-lambda-rie", "sigterm.handler")
58+
cmd.Env = append([]string{
59+
"PATH=" + testDir,
60+
"AWS_LAMBDA_FUNCTION_TIMEOUT=2",
61+
}, opts.envVars...)
62+
cmd.Stderr = os.Stderr
63+
stdout, err := cmd.StdoutPipe()
64+
require.NoError(t, err)
65+
var logs string
66+
done := make(chan interface{}) // closed on completion of log flush
67+
go func() {
68+
logBytes, err := ioutil.ReadAll(stdout)
69+
require.NoError(t, err)
70+
logs = string(logBytes)
71+
close(done)
72+
}()
73+
require.NoError(t, cmd.Start())
74+
t.Cleanup(func() { _ = cmd.Process.Kill() })
75+
76+
// give a moment for the port to bind
77+
time.Sleep(500 * time.Millisecond)
78+
79+
client := &http.Client{Timeout: 5 * time.Second} // http client timeout to prevent case from hanging on aws-lambda-rie
80+
resp, err := client.Post(rieInvokeAPI, "application/json", strings.NewReader("{}"))
81+
require.NoError(t, err)
82+
defer resp.Body.Close()
83+
body, err := ioutil.ReadAll(resp.Body)
84+
assert.NoError(t, err)
85+
assert.Equal(t, string(body), "Task timed out after 2.00 seconds")
86+
87+
require.NoError(t, cmd.Process.Kill()) // now ensure the logs are drained
88+
<-done
89+
t.Logf("stdout:\n%s", logs)
90+
opts.assertLogs(t, logs)
91+
})
92+
}
93+
}

lambda/testdata/sigterm.go

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"os"
7+
"os/signal"
8+
"syscall"
9+
"time"
10+
11+
"github.com/aws/aws-lambda-go/lambda"
12+
)
13+
14+
func init() {
15+
// conventional SIGTERM callback
16+
signaled := make(chan os.Signal, 1)
17+
signal.Notify(signaled, syscall.SIGTERM)
18+
go func() {
19+
<-signaled
20+
fmt.Println("I've been TERMINATED!")
21+
}()
22+
23+
}
24+
25+
func main() {
26+
// lambda option to enable sigterm, plus optional extra sigterm callbacks
27+
sigtermOption := lambda.WithEnableSIGTERM(func() {
28+
fmt.Println("Hello SIGTERM!")
29+
})
30+
handlerOptions := []lambda.Option{}
31+
if os.Getenv("ENABLE_SIGTERM") != "" {
32+
handlerOptions = append(handlerOptions, sigtermOption)
33+
}
34+
lambda.StartWithOptions(
35+
func(ctx context.Context) {
36+
deadline, _ := ctx.Deadline()
37+
<-time.After(time.Until(deadline) + time.Second)
38+
panic("unreachable line reached!")
39+
},
40+
handlerOptions...,
41+
)
42+
}

0 commit comments

Comments
 (0)