Skip to content

Commit 752114b

Browse files
authored
Add options lambdaurl.WithDetectContentType and lambda.WithContextValue (#516)
1 parent 1dca084 commit 752114b

File tree

6 files changed

+252
-31
lines changed

6 files changed

+252
-31
lines changed

.github/workflows/tests.yml

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ jobs:
88
name: run tests
99
runs-on: ubuntu-latest
1010
strategy:
11+
fail-fast: false
1112
matrix:
1213
go:
1314
- "1.21"

lambda/handler.go

+22
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type Handler interface {
2323
type handlerOptions struct {
2424
handlerFunc
2525
baseContext context.Context
26+
contextValues map[interface{}]interface{}
2627
jsonRequestUseNumber bool
2728
jsonRequestDisallowUnknownFields bool
2829
jsonResponseEscapeHTML bool
@@ -50,6 +51,23 @@ func WithContext(ctx context.Context) Option {
5051
})
5152
}
5253

54+
// WithContextValue adds a value to the handler context.
55+
// If a base context was set using WithContext, that base is used as the parent.
56+
//
57+
// Usage:
58+
//
59+
// lambda.StartWithOptions(
60+
// func (ctx context.Context) (string, error) {
61+
// return ctx.Value("foo"), nil
62+
// },
63+
// lambda.WithContextValue("foo", "bar")
64+
// )
65+
func WithContextValue(key interface{}, value interface{}) Option {
66+
return Option(func(h *handlerOptions) {
67+
h.contextValues[key] = value
68+
})
69+
}
70+
5371
// WithSetEscapeHTML sets the SetEscapeHTML argument on the underlying json encoder
5472
//
5573
// Usage:
@@ -211,13 +229,17 @@ func newHandler(handlerFunc interface{}, options ...Option) *handlerOptions {
211229
}
212230
h := &handlerOptions{
213231
baseContext: context.Background(),
232+
contextValues: map[interface{}]interface{}{},
214233
jsonResponseEscapeHTML: false,
215234
jsonResponseIndentPrefix: "",
216235
jsonResponseIndentValue: "",
217236
}
218237
for _, option := range options {
219238
option(h)
220239
}
240+
for k, v := range h.contextValues {
241+
h.baseContext = context.WithValue(h.baseContext, k, v)
242+
}
221243
if h.enableSIGTERM {
222244
enableSIGTERM(h.sigtermCallbacks)
223245
}

lambda/sigterm_test.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"os"
1010
"os/exec"
1111
"path"
12+
"strconv"
1213
"strings"
1314
"testing"
1415
"time"
@@ -17,10 +18,6 @@ import (
1718
"github.com/stretchr/testify/require"
1819
)
1920

20-
const (
21-
rieInvokeAPI = "http://localhost:8080/2015-03-31/functions/function/invocations"
22-
)
23-
2421
func TestEnableSigterm(t *testing.T) {
2522
if _, err := exec.LookPath("aws-lambda-rie"); err != nil {
2623
t.Skipf("%v - install from https://github.com/aws/aws-lambda-runtime-interface-emulator/", err)
@@ -34,6 +31,7 @@ func TestEnableSigterm(t *testing.T) {
3431
handlerBuild.Stdout = os.Stderr
3532
require.NoError(t, handlerBuild.Run())
3633

34+
portI := 0
3735
for name, opts := range map[string]struct {
3836
envVars []string
3937
assertLogs func(t *testing.T, logs string)
@@ -53,8 +51,12 @@ func TestEnableSigterm(t *testing.T) {
5351
},
5452
} {
5553
t.Run(name, func(t *testing.T) {
54+
portI += 1
55+
addr1 := "localhost:" + strconv.Itoa(8000+portI)
56+
addr2 := "localhost:" + strconv.Itoa(9000+portI)
57+
rieInvokeAPI := "http://" + addr1 + "/2015-03-31/functions/function/invocations"
5658
// run the runtime interface emulator, capture the logs for assertion
57-
cmd := exec.Command("aws-lambda-rie", "sigterm.handler")
59+
cmd := exec.Command("aws-lambda-rie", "--runtime-interface-emulator-address", addr1, "--runtime-api-address", addr2, "sigterm.handler")
5860
cmd.Env = append([]string{
5961
"PATH=" + testDir,
6062
"AWS_LAMBDA_FUNCTION_TIMEOUT=2",

lambdaurl/http_handler.go

+76-15
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,76 @@ import (
1818
"github.com/aws/aws-lambda-go/lambda"
1919
)
2020

21+
type detectContentTypeContextKey struct{}
22+
23+
// WithDetectContentType sets the behavior of content type detection when the Content-Type header is not already provided.
24+
// When true, the first Write call will pass the intial bytes to http.DetectContentType.
25+
// When false, and if no Content-Type is provided, no Content-Type will be sent back to Lambda,
26+
// and the Lambda Function URL will fallback to it's default.
27+
//
28+
// Note: The http.ResponseWriter passed to the handler is unbuffered.
29+
// This may result in different Content-Type headers in the Function URL response when compared to http.ListenAndServe.
30+
//
31+
// Usage:
32+
//
33+
// lambdaurl.Start(
34+
// http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
35+
// w.Write("<!DOCTYPE html><html></html>")
36+
// }),
37+
// lambdaurl.WithDetectContentType(true)
38+
// )
39+
func WithDetectContentType(detectContentType bool) lambda.Option {
40+
return lambda.WithContextValue(detectContentTypeContextKey{}, detectContentType)
41+
}
42+
2143
type httpResponseWriter struct {
44+
detectContentType bool
45+
header http.Header
46+
writer io.Writer
47+
once sync.Once
48+
ready chan<- header
49+
}
50+
51+
type header struct {
52+
code int
2253
header http.Header
23-
writer io.Writer
24-
once sync.Once
25-
status chan<- int
2654
}
2755

2856
func (w *httpResponseWriter) Header() http.Header {
57+
if w.header == nil {
58+
w.header = http.Header{}
59+
}
2960
return w.header
3061
}
3162

3263
func (w *httpResponseWriter) Write(p []byte) (int, error) {
33-
w.once.Do(func() { w.status <- http.StatusOK })
64+
w.writeHeader(http.StatusOK, p)
3465
return w.writer.Write(p)
3566
}
3667

3768
func (w *httpResponseWriter) WriteHeader(statusCode int) {
38-
w.once.Do(func() { w.status <- statusCode })
69+
w.writeHeader(statusCode, nil)
70+
}
71+
72+
func (w *httpResponseWriter) writeHeader(statusCode int, initialPayload []byte) {
73+
w.once.Do(func() {
74+
if w.detectContentType {
75+
if w.Header().Get("Content-Type") == "" {
76+
w.Header().Set("Content-Type", detectContentType(initialPayload))
77+
}
78+
}
79+
w.ready <- header{code: statusCode, header: w.header}
80+
})
81+
}
82+
83+
func detectContentType(p []byte) string {
84+
// http.DetectContentType returns "text/plain; charset=utf-8" for nil and zero-length byte slices.
85+
// This is a weird behavior, since otherwise it defaults to "application/octet-stream"! So we'll do that.
86+
// This differs from http.ListenAndServe, which set no Content-Type when the initial Flush body is empty.
87+
if len(p) == 0 {
88+
return "application/octet-stream"
89+
}
90+
return http.DetectContentType(p)
3991
}
4092

4193
type requestContextKey struct{}
@@ -46,11 +98,13 @@ func RequestFromContext(ctx context.Context) (*events.LambdaFunctionURLRequest,
4698
return req, ok
4799
}
48100

49-
// Wrap converts an http.Handler into a lambda request handler.
101+
// Wrap converts an http.Handler into a Lambda request handler.
102+
//
50103
// Only Lambda Function URLs configured with `InvokeMode: RESPONSE_STREAM` are supported with the returned handler.
51-
// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response`
104+
// The response body of the handler will conform to the content-type `application/vnd.awslambda.http-integration-response`.
52105
func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) {
53106
return func(ctx context.Context, request *events.LambdaFunctionURLRequest) (*events.LambdaFunctionURLStreamingResponse, error) {
107+
54108
var body io.Reader = strings.NewReader(request.Body)
55109
if request.IsBase64Encoded {
56110
body = base64.NewDecoder(base64.StdEncoding, body)
@@ -67,21 +121,28 @@ func Wrap(handler http.Handler) func(context.Context, *events.LambdaFunctionURLR
67121
for k, v := range request.Headers {
68122
httpRequest.Header.Add(k, v)
69123
}
70-
status := make(chan int) // Signals when it's OK to start returning the response body to Lambda
71-
header := http.Header{}
124+
125+
ready := make(chan header) // Signals when it's OK to start returning the response body to Lambda
72126
r, w := io.Pipe()
127+
responseWriter := &httpResponseWriter{writer: w, ready: ready}
128+
if detectContentType, ok := ctx.Value(detectContentTypeContextKey{}).(bool); ok {
129+
responseWriter.detectContentType = detectContentType
130+
}
73131
go func() {
74-
defer close(status)
132+
defer close(ready)
75133
defer w.Close() // TODO: recover and CloseWithError the any panic value once the runtime API client supports plumbing fatal errors through the reader
76-
handler.ServeHTTP(&httpResponseWriter{writer: w, header: header, status: status}, httpRequest)
134+
//nolint:errcheck
135+
defer responseWriter.Write(nil) // force default status, headers, content type detection, if none occured during the execution of the handler
136+
handler.ServeHTTP(responseWriter, httpRequest)
77137
}()
138+
header := <-ready
78139
response := &events.LambdaFunctionURLStreamingResponse{
79140
Body: r,
80-
StatusCode: <-status,
141+
StatusCode: header.code,
81142
}
82-
if len(header) > 0 {
83-
response.Headers = make(map[string]string, len(header))
84-
for k, v := range header {
143+
if len(header.header) > 0 {
144+
response.Headers = make(map[string]string, len(header.header))
145+
for k, v := range header.header {
85146
if k == "Set-Cookie" {
86147
response.Cookies = v
87148
} else {

0 commit comments

Comments
 (0)