Skip to content

Commit

Permalink
Add response type for Lambda Function URL Streaming Responses (#494)
Browse files Browse the repository at this point in the history
  • Loading branch information
bmoffatt authored Apr 11, 2023
1 parent 47e703d commit a660c21
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 0 deletions.
74 changes: 74 additions & 0 deletions events/lambda_function_urls.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

package events

import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
)

// LambdaFunctionURLRequest contains data coming from the HTTP request to a Lambda Function URL.
type LambdaFunctionURLRequest struct {
Version string `json:"version"` // Version is expected to be `"2.0"`
Expand Down Expand Up @@ -59,3 +67,69 @@ type LambdaFunctionURLResponse struct {
IsBase64Encoded bool `json:"isBase64Encoded"`
Cookies []string `json:"cookies"`
}

// LambdaFunctionURLStreamingResponse models the response to a Lambda Function URL when InvokeMode is RESPONSE_STREAM.
// If the InvokeMode of the Function URL is BUFFERED (default), use LambdaFunctionURLResponse instead.
//
// Example:
//
// lambda.Start(func() (*events.LambdaFunctionURLStreamingResponse, error) {
// return &events.LambdaFunctionURLStreamingResponse{
// StatusCode: 200,
// Headers: map[string]string{
// "Content-Type": "text/html",
// },
// Body: strings.NewReader("<html><body>Hello World!</body></html>"),
// }, nil
// })
type LambdaFunctionURLStreamingResponse struct {
prelude *bytes.Buffer

StatusCode int
Headers map[string]string
Body io.Reader
Cookies []string
}

func (r *LambdaFunctionURLStreamingResponse) Read(p []byte) (n int, err error) {
if r.prelude == nil {
if r.StatusCode == 0 {
r.StatusCode = http.StatusOK
}
b, err := json.Marshal(struct {
StatusCode int `json:"statusCode"`
Headers map[string]string `json:"headers,omitempty"`
Cookies []string `json:"cookies,omitempty"`
}{
StatusCode: r.StatusCode,
Headers: r.Headers,
Cookies: r.Cookies,
})
if err != nil {
return 0, err
}
r.prelude = bytes.NewBuffer(append(b, 0, 0, 0, 0, 0, 0, 0, 0))
}
if r.prelude.Len() > 0 {
return r.prelude.Read(p)
}
if r.Body == nil {
return 0, io.EOF
}
return r.Body.Read(p)
}

func (r *LambdaFunctionURLStreamingResponse) Close() error {
if closer, ok := r.Body.(io.ReadCloser); ok {
return closer.Close()
}
return nil
}

func (r *LambdaFunctionURLStreamingResponse) MarshalJSON() ([]byte, error) {
return nil, errors.New("not json")
}

func (r *LambdaFunctionURLStreamingResponse) ContentType() string {
return "application/vnd.awslambda.http-integration-response"
}
92 changes: 92 additions & 0 deletions events/lambda_function_urls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ package events

import (
"encoding/json"
"errors"
"io/ioutil" //nolint: staticcheck
"net/http"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestLambdaFunctionURLResponseMarshaling(t *testing.T) {
Expand Down Expand Up @@ -55,3 +59,91 @@ func TestLambdaFunctionURLRequestMarshaling(t *testing.T) {

assert.JSONEq(t, string(inputJSON), string(outputJSON))
}

func TestLambdaFunctionURLStreamingResponseMarshaling(t *testing.T) {
for _, test := range []struct {
name string
response *LambdaFunctionURLStreamingResponse
expectedHead string
expectedBody string
}{
{
"empty",
&LambdaFunctionURLStreamingResponse{},
`{"statusCode":200}`,
"",
},
{
"just the status code",
&LambdaFunctionURLStreamingResponse{
StatusCode: http.StatusTeapot,
},
`{"statusCode":418}`,
"",
},
{
"status and headers and cookies and body",
&LambdaFunctionURLStreamingResponse{
StatusCode: http.StatusTeapot,
Headers: map[string]string{"hello": "world"},
Cookies: []string{"cookies", "are", "yummy"},
Body: strings.NewReader(`<html>Hello Hello</html>`),
},
`{"statusCode":418, "headers":{"hello":"world"}, "cookies":["cookies","are","yummy"]}`,
`<html>Hello Hello</html>`,
},
} {
t.Run(test.name, func(t *testing.T) {
response, err := ioutil.ReadAll(test.response)
require.NoError(t, err)
sep := "\x00\x00\x00\x00\x00\x00\x00\x00"
responseParts := strings.Split(string(response), sep)
require.Len(t, responseParts, 2)
head := string(responseParts[0])
body := string(responseParts[1])
assert.JSONEq(t, test.expectedHead, head)
assert.Equal(t, test.expectedBody, body)
assert.NoError(t, test.response.Close())
})
}
}

type readCloser struct {
closed bool
err error
reader *strings.Reader
}

func (r *readCloser) Read(p []byte) (int, error) {
return r.reader.Read(p)
}

func (r *readCloser) Close() error {
r.closed = true
return r.err
}

func TestLambdaFunctionURLStreamingResponsePropogatesInnerClose(t *testing.T) {
for _, test := range []struct {
name string
closer *readCloser
err error
}{
{
"closer no err",
&readCloser{},
nil,
},
{
"closer with err",
&readCloser{err: errors.New("yolo")},
errors.New("yolo"),
},
} {
t.Run(test.name, func(t *testing.T) {
response := &LambdaFunctionURLStreamingResponse{Body: test.closer}
assert.Equal(t, test.err, response.Close())
assert.True(t, test.closer.closed)
})
}
}

0 comments on commit a660c21

Please sign in to comment.