Skip to content

Commit 2692f43

Browse files
authored
openapi3: allow YAML-marshaling invalid specs (getkin#977)
* openapi3: allow YAML-marshaling invalid specs Signed-off-by: Pierre Fenoll <[email protected]> * fixes Signed-off-by: Pierre Fenoll <[email protected]> --------- Signed-off-by: Pierre Fenoll <[email protected]>
1 parent 4144c56 commit 2692f43

8 files changed

+93
-6
lines changed

.github/docs/openapi3.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ type Info struct {
645645
func (info Info) MarshalJSON() ([]byte, error)
646646
MarshalJSON returns the JSON encoding of Info.
647647

648-
func (info Info) MarshalYAML() (any, error)
648+
func (info *Info) MarshalYAML() (any, error)
649649
MarshalYAML returns the YAML encoding of Info.
650650

651651
func (info *Info) UnmarshalJSON(data []byte) error

maps.sh

+7
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,16 @@ EOF
155155

156156

157157
maplike_UnMarsh() {
158+
if [[ "$type" != '*'* ]]; then
159+
echo "TODO: impl non-pointer receiver YAML Marshaler"
160+
exit 2
161+
fi
158162
cat <<EOF >>"$maplike"
159163
// MarshalYAML returns the YAML encoding of ${type#'*'}.
160164
func (${name} ${type}) MarshalYAML() (any, error) {
165+
if ${name} == nil {
166+
return nil, nil
167+
}
161168
m := make(map[string]any, ${name}.Len()+len(${name}.Extensions))
162169
for k, v := range ${name}.Extensions {
163170
m[k] = v

openapi3/info.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ func (info Info) MarshalJSON() ([]byte, error) {
2929
}
3030

3131
// MarshalYAML returns the YAML encoding of Info.
32-
func (info Info) MarshalYAML() (any, error) {
32+
func (info *Info) MarshalYAML() (any, error) {
33+
if info == nil {
34+
return nil, nil
35+
}
3336
m := make(map[string]any, 6+len(info.Extensions))
3437
for k, v := range info.Extensions {
3538
m[k] = v

openapi3/issue972_test.go

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package openapi3
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
"gopkg.in/yaml.v3"
8+
)
9+
10+
func TestIssue972(t *testing.T) {
11+
type testcase struct {
12+
spec string
13+
validationErrorContains string
14+
}
15+
16+
base := `
17+
openapi: 3.0.2
18+
paths: {}
19+
components: {}
20+
`
21+
22+
for _, tc := range []testcase{{
23+
spec: base,
24+
validationErrorContains: "invalid info: must be an object",
25+
}, {
26+
spec: base + `
27+
info:
28+
title: "Hello World REST APIs"
29+
version: "1.0"
30+
`,
31+
}, {
32+
spec: base + `
33+
info: null
34+
`,
35+
validationErrorContains: "invalid info: must be an object",
36+
}, {
37+
spec: base + `
38+
info: {}
39+
`,
40+
validationErrorContains: "invalid info: value of version must be a non-empty string",
41+
}, {
42+
spec: base + `
43+
info:
44+
title: "Hello World REST APIs"
45+
`,
46+
validationErrorContains: "invalid info: value of version must be a non-empty string",
47+
}} {
48+
t.Logf("spec: %s", tc.spec)
49+
50+
loader := &Loader{}
51+
doc, err := loader.LoadFromData([]byte(tc.spec))
52+
assert.NoError(t, err)
53+
assert.NotNil(t, doc)
54+
55+
err = doc.Validate(loader.Context)
56+
if e := tc.validationErrorContains; e != "" {
57+
assert.ErrorContains(t, err, e)
58+
} else {
59+
assert.NoError(t, err)
60+
}
61+
62+
txt, err := yaml.Marshal(doc)
63+
assert.NoError(t, err)
64+
assert.NotEmpty(t, txt)
65+
}
66+
}

openapi3/loader.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -1140,8 +1140,7 @@ func (loader *Loader) resolvePathItemRef(doc *T, pathItem *PathItem, documentPat
11401140
*pathItem = p
11411141
} else {
11421142
var resolved PathItem
1143-
doc, documentPath, err = loader.resolveComponent(doc, ref, documentPath, &resolved)
1144-
if err != nil {
1143+
if doc, documentPath, err = loader.resolveComponent(doc, ref, documentPath, &resolved); err != nil {
11451144
if err == errMUSTPathItem {
11461145
return nil
11471146
}

openapi3/maplike.go

+9
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ func (responses Responses) JSONLookup(token string) (any, error) {
7878

7979
// MarshalYAML returns the YAML encoding of Responses.
8080
func (responses *Responses) MarshalYAML() (any, error) {
81+
if responses == nil {
82+
return nil, nil
83+
}
8184
m := make(map[string]any, responses.Len()+len(responses.Extensions))
8285
for k, v := range responses.Extensions {
8386
m[k] = v
@@ -206,6 +209,9 @@ func (callback Callback) JSONLookup(token string) (any, error) {
206209

207210
// MarshalYAML returns the YAML encoding of Callback.
208211
func (callback *Callback) MarshalYAML() (any, error) {
212+
if callback == nil {
213+
return nil, nil
214+
}
209215
m := make(map[string]any, callback.Len()+len(callback.Extensions))
210216
for k, v := range callback.Extensions {
211217
m[k] = v
@@ -334,6 +340,9 @@ func (paths Paths) JSONLookup(token string) (any, error) {
334340

335341
// MarshalYAML returns the YAML encoding of Paths.
336342
func (paths *Paths) MarshalYAML() (any, error) {
343+
if paths == nil {
344+
return nil, nil
345+
}
337346
m := make(map[string]any, paths.Len()+len(paths.Extensions))
338347
for k, v := range paths.Extensions {
339348
m[k] = v

openapi3/openapi3.go

+3
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ func (doc *T) MarshalJSON() ([]byte, error) {
6666

6767
// MarshalYAML returns the YAML encoding of T.
6868
func (doc *T) MarshalYAML() (any, error) {
69+
if doc == nil {
70+
return nil, nil
71+
}
6972
m := make(map[string]any, 4+len(doc.Extensions))
7073
for k, v := range doc.Extensions {
7174
m[k] = v

openapi3/schema_formats_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,13 @@ func TestNumberFormats(t *testing.T) {
169169
}
170170
DefineNumberFormatValidator("lessThan10", NewCallbackValidator(func(value float64) error {
171171
if value >= 10 {
172-
return fmt.Errorf("not less than 10")
172+
return errors.New("not less than 10")
173173
}
174174
return nil
175175
}))
176176
DefineIntegerFormatValidator("odd", NewCallbackValidator(func(value int64) error {
177177
if value%2 == 0 {
178-
return fmt.Errorf("not odd")
178+
return errors.New("not odd")
179179
}
180180
return nil
181181
}))

0 commit comments

Comments
 (0)