Skip to content

Commit

Permalink
openapi3filter: Fix default arrays application for query parameters (g…
Browse files Browse the repository at this point in the history
…etkin#992)

Co-authored-by: Gildas Lebel <[email protected]>
  • Loading branch information
TheSadlig and GildasLebel authored Jul 27, 2024
1 parent cd0a337 commit cf9684e
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 3 deletions.
25 changes: 24 additions & 1 deletion openapi3filter/validate_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"sort"

"github.com/getkin/kin-openapi/openapi3"
Expand Down Expand Up @@ -103,6 +104,28 @@ func ValidateRequest(ctx context.Context, input *RequestValidationInput) (err er
return
}

// appendToQueryValues adds to query parameters each value in the provided slice
func appendToQueryValues[T any](q url.Values, parameterName string, v []T) {
for _, i := range v {
q.Add(parameterName, fmt.Sprintf("%v", i))
}
}

// populateDefaultQueryParameters populates default values inside query parameters, while ensuring types are respected
func populateDefaultQueryParameters(q url.Values, parameterName string, value any) {
switch t := value.(type) {
case []string:
appendToQueryValues(q, parameterName, t)
case []float64:
appendToQueryValues(q, parameterName, t)
case []int:
appendToQueryValues(q, parameterName, t)
default:
q.Add(parameterName, fmt.Sprintf("%v", value))
}

}

// ValidateParameter validates a parameter's value by JSON schema.
// The function returns RequestError with a ParseError cause when unable to parse a value.
// The function returns RequestError with ErrInvalidRequired cause when a value of a required parameter is not defined.
Expand Down Expand Up @@ -156,7 +179,7 @@ func ValidateParameter(ctx context.Context, input *RequestValidationInput, param
// Next check `parameter.Required && !found` will catch this.
case openapi3.ParameterInQuery:
q := req.URL.Query()
q.Add(parameter.Name, fmt.Sprintf("%v", value))
populateDefaultQueryParameters(q, parameter.Name, value)
req.URL.RawQuery = q.Encode()
case openapi3.ParameterInHeader:
req.Header.Add(parameter.Name, fmt.Sprintf("%v", value))
Expand Down
93 changes: 91 additions & 2 deletions openapi3filter/validate_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,6 @@ func TestValidateQueryParams(t *testing.T) {
},
},
},
//
//
}

for _, tc := range testCases {
Expand Down Expand Up @@ -569,3 +567,94 @@ paths:
})
require.Error(t, err)
}

var (
StringArraySchemaWithDefault = &openapi3.SchemaRef{
Value: &openapi3.Schema{
Type: &openapi3.Types{"array"},
Items: stringSchema,
Default: []string{"A", "B", "C"},
},
}
FloatArraySchemaWithDefault = &openapi3.SchemaRef{
Value: &openapi3.Schema{
Type: &openapi3.Types{"array"},
Items: numberSchema,
Default: []float64{1.5, 2.5, 3.5},
},
}
)

func TestValidateRequestDefault(t *testing.T) {
type testCase struct {
name string
param *openapi3.Parameter
query string
wantQuery map[string][]string
wantHeader map[string]any
}

testCases := []testCase{
{
name: "String Array In Query",
param: &openapi3.Parameter{
Name: "param", In: "query", Style: "form", Explode: explode,
Schema: StringArraySchemaWithDefault,
},
wantQuery: map[string][]string{
"param": {
"A",
"B",
"C",
},
},
},
{
name: "Float Array In Query",
param: &openapi3.Parameter{
Name: "param", In: "query", Style: "form", Explode: explode,
Schema: FloatArraySchemaWithDefault,
},
wantQuery: map[string][]string{
"param": {
"1.5",
"2.5",
"3.5",
},
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
info := &openapi3.Info{
Title: "MyAPI",
Version: "0.1",
}
doc := &openapi3.T{OpenAPI: "3.0.0", Info: info, Paths: openapi3.NewPaths()}
op := &openapi3.Operation{
OperationID: "test",
Parameters: []*openapi3.ParameterRef{{Value: tc.param}},
Responses: openapi3.NewResponses(),
}
doc.AddOperation("/test", http.MethodGet, op)
err := doc.Validate(context.Background())
require.NoError(t, err)
router, err := legacyrouter.NewRouter(doc)
require.NoError(t, err)

req, err := http.NewRequest(http.MethodGet, "http://test.org/test?"+tc.query, nil)
route, pathParams, err := router.FindRoute(req)
require.NoError(t, err)

input := &RequestValidationInput{Request: req, PathParams: pathParams, Route: route}

err = ValidateParameter(context.Background(), input, tc.param)
require.NoError(t, err)

for k, v := range tc.wantQuery {
require.Equal(t, v, input.Request.URL.Query()[k])
}
})
}
}

0 comments on commit cf9684e

Please sign in to comment.