Skip to content

Commit

Permalink
feat(openapi): 支持默认值
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed Nov 26, 2024
1 parent 1170b3d commit 6758bf0
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 40 deletions.
3 changes: 1 addition & 2 deletions openapi/build_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package openapi

import (
"reflect"
"testing"

"github.com/issue9/assert/v4"
Expand Down Expand Up @@ -65,7 +64,7 @@ func TestComponents_build(t *testing.T) {
c.queries["q1"] = &Parameter{Name: "q1", Schema: &Schema{Type: TypeString}}
c.cookies["c1"] = &Parameter{Name: "c1", Schema: &Schema{Type: TypeNumber}}
c.headers["h1"] = &Parameter{Name: "h1", Schema: &Schema{Type: TypeBoolean}}
c.schemas["s1"] = NewSchema(reflect.TypeFor[int](), nil, nil)
c.schemas["s1"] = NewSchema(5, nil, nil)

r := c.build(p, d)
a.Equal(r.Parameters.Len(), 2).
Expand Down
38 changes: 22 additions & 16 deletions openapi/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,43 +119,46 @@ func (o *Operation) QueryRef(ref string, summary, description web.LocaleStringer
// QueryObject 从参数 o 中获取相应的查询参数
//
// 对于 obj 的要求与 [web.Context.QueryObject] 是相同的。
// 如果参数 obj 非空的,那么该非空字段同时也作为该查询参数的默认值。
// f 是对每个字段的修改,可以为空,其原型为
//
// func(p *Parameter)
//
// 可通过 p.Name 确定的参数名称
func (o *Operation) QueryObject(obj any, f func(*Parameter)) *Operation {
return o.queryObject(reflect.TypeOf(obj), f)
return o.queryObject(reflect.ValueOf(obj), f)
}

func (o *Operation) queryObject(t reflect.Type, f func(*Parameter)) *Operation {
for t.Kind() == reflect.Pointer {
t = t.Elem()
func (o *Operation) queryObject(v reflect.Value, f func(*Parameter)) *Operation {
for v.Kind() == reflect.Pointer {
v = v.Elem()
}

if t.Kind() != reflect.Struct {
if v.Kind() != reflect.Struct {
panic("t 必须得是 struct 类型")
}

t := v.Type()
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
ft := t.Field(i)
vt := v.Field(i)

if field.Anonymous {
o.queryObject(field.Type, f)
if ft.Anonymous {
o.queryObject(vt, f)
continue
}

if !field.IsExported() {
if !ft.IsExported() {
continue
}
name, _, _ := getTagName(field, query.Tag)
name, _, _ := getTagName(ft, query.Tag)
if name == "" {
name = field.Name
name = ft.Name
}

var desc web.LocaleStringer
if field.Tag != "" {
if c := field.Tag.Get(CommentTag); c != "" {
if ft.Tag != "" {
if c := ft.Tag.Get(CommentTag); c != "" {
desc = web.Phrase(c)
}
}
Expand All @@ -172,7 +175,10 @@ func (o *Operation) queryObject(t reflect.Type, f func(*Parameter)) *Operation {
}

p.Schema = &Schema{}
schemaFromType(nil, field.Type, true, "", p.Schema)
if !vt.IsZero() {
p.Schema.Default = vt.Interface()
}
schemaFromType(nil, ft.Type, true, "", p.Schema)
if !p.Schema.isBasicType() {
panic("不支持复杂类型")
}
Expand Down Expand Up @@ -236,7 +242,7 @@ func (o *Operation) CookieRef(ref string, summary, description web.LocaleStringe
func (o *Operation) Body(body any, ignorable bool, desc web.LocaleStringer, f func(*Request)) *Operation {
req := &Request{
Ignorable: ignorable,
Body: o.d.newSchema(reflect.TypeOf(body)),
Body: o.d.newSchema(body),
Description: desc,
}
if f != nil {
Expand Down Expand Up @@ -271,7 +277,7 @@ func (o *Operation) Response(status string, resp any, desc web.LocaleStringer, f
r := &Response{Description: desc}

if resp != nil {
r.Body = o.d.newSchema(reflect.TypeOf(resp))
r.Body = o.d.newSchema(resp)
}

if f != nil {
Expand Down
3 changes: 2 additions & 1 deletion openapi/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func TestDocument_API(t *testing.T) {
m := d.API(func(o *Operation) {
o.Header("h1", TypeString, nil, nil).
Tag("tag1").
QueryObject(&q{}, nil).
QueryObject(&q{Q3: 5}, nil).
Path("p1", TypeInteger, web.Phrase("lang"), nil).
Body(&object{}, true, web.Phrase("lang"), nil).
Response("200", 5, web.Phrase("desc"), nil).
Expand All @@ -162,6 +162,7 @@ func TestDocument_API(t *testing.T) {
Equal(o.Description, web.Phrase("lang")).
Length(o.Paths, 0).
Length(o.Queries, 3).
Equal(o.Queries[2].Schema.Default, 5).
NotNil(o.RequestBody.Body.Type, TypeObject).
Length(d.paths["/path/{p1}/abc"].Paths, 1)

Expand Down
3 changes: 1 addition & 2 deletions openapi/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package openapi

import (
"fmt"
"reflect"
"strings"

"github.com/issue9/web"
Expand Down Expand Up @@ -70,7 +69,7 @@ func WithResponse(resp *Response, status ...string) Option {
func WithProblemResponse() Option {
return WithResponse(&Response{
Ref: &Ref{Ref: "problem"},
Body: NewSchema(reflect.TypeOf(web.Problem{}), web.Phrase("problem response schema"), web.Phrase("problem response schema desc")),
Body: NewSchema(web.Problem{}, web.Phrase("problem response schema"), web.Phrase("problem response schema desc")),
Problem: true,
Description: web.Phrase("problem response"),
}, "4XX", "5XX")
Expand Down
24 changes: 17 additions & 7 deletions openapi/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type Schema struct {
Minimum int
Maximum int
Enum []any
Default any
}

type properties = orderedmap.OrderedMap[string, *renderer[schemaRenderer]]
Expand All @@ -80,21 +81,30 @@ type schemaRenderer struct {
Minimum int `json:"minimum,omitempty" yaml:"minimum,omitempty"`
Maximum int `json:"maximum,omitempty" yaml:"maximum,omitempty"`
Enum []any `json:"enum,omitempty" yaml:"enum,omitempty"`
Default any `json:"default,omitempty" yaml:"default,omitempty"`
}

func (d *Document) newSchema(t reflect.Type) *Schema {
s := &Schema{}
schemaFromType(d, t, true, "", s)
return s
func (d *Document) newSchema(v any) *Schema { return newSchema(d, v, nil, nil) }

// NewSchema 根据 v 生成 [Schema] 对象
//
// 如果 v 不是空值,那么 v 也将同时作为默认值出现在 [Schema] 中。
func NewSchema(v any, title, desc web.LocaleStringer) *Schema {
return newSchema(nil, v, title, desc)
}

// NewSchema 根据 [reflect.Type] 生成 [Schema] 对象
func NewSchema(t reflect.Type, title, desc web.LocaleStringer) *Schema {
func newSchema(d *Document, v any, title, desc web.LocaleStringer) *Schema {
s := &Schema{
Title: title,
Description: desc,
}
schemaFromType(nil, t, true, "", s)

rv := reflect.ValueOf(v)
if !rv.IsZero() {
s.Default = v
}

schemaFromType(d, rv.Type(), true, "", s)
return s
}

Expand Down
26 changes: 14 additions & 12 deletions openapi/schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package openapi

import (
"reflect"
"testing"
"time"

Expand All @@ -32,21 +31,23 @@ func TestDocument_newSchema(t *testing.T) {
ss := newServer(a)
d := New(ss, web.Phrase("desc"))

s := d.newSchema(reflect.TypeFor[int]())
s := d.newSchema(5)
a.Equal(s.Type, TypeInteger).
Nil(s.Ref)
Nil(s.Ref).
Equal(s.Default, 5)

s = d.newSchema(reflect.TypeFor[[]int]())
s = d.newSchema([]int{5, 6})
a.Equal(s.Type, TypeArray).
Equal(s.Items.Type, TypeInteger)
Equal(s.Items.Type, TypeInteger).
Equal(s.Default, []int{5, 6})

s = d.newSchema(reflect.TypeFor[map[string]float32]())
s = d.newSchema(map[string]float32{"1": 3.2})
a.Equal(s.Type, TypeObject).
Nil(s.Ref).
NotNil(s.AdditionalProperties).
Equal(s.AdditionalProperties.Type, TypeNumber)

s = d.newSchema(reflect.ValueOf(&object{}).Type())
s = d.newSchema(&object{})
a.Equal(s.Type, TypeObject).
NotZero(s.Ref.Ref).
Length(s.Properties, 3).
Expand All @@ -56,9 +57,10 @@ func TestDocument_newSchema(t *testing.T) {
Equal(s.Properties["Items"].Type, TypeArray).
NotZero(s.Properties["Items"].Items.Ref.Ref) // 引用了 object

s = d.newSchema(reflect.ValueOf(schemaObject1{}).Type())
s = d.newSchema(schemaObject1{})
a.Equal(s.Type, TypeObject).
NotZero(s.Ref.Ref).
Nil(s.Default).
Length(s.Properties, 7).
Equal(s.Properties["id"].Type, TypeInteger).
Equal(s.Properties["Root"].Type, TypeString).
Expand All @@ -68,7 +70,7 @@ func TestDocument_newSchema(t *testing.T) {
Equal(s.Properties["Z"].Type, TypeString).
Equal(s.Properties["Z"].Format, FormatDate)

s = d.newSchema(reflect.ValueOf(schemaObject2{}).Type())
s = d.newSchema(schemaObject2{})
a.Equal(s.Type, TypeObject).
NotZero(s.Ref.Ref).
Length(s.Properties, 7).
Expand All @@ -81,12 +83,12 @@ func TestDocument_newSchema(t *testing.T) {
func TestSchema_isBasicType(t *testing.T) {
a := assert.New(t, false)

s := NewSchema(reflect.TypeFor[int](), nil, nil)
s := NewSchema(5, nil, nil)
a.True(s.isBasicType())

s = NewSchema(reflect.TypeFor[object](), nil, nil)
s = NewSchema(object{}, nil, nil)
a.False(s.isBasicType())

s = NewSchema(reflect.TypeFor[[]string](), nil, nil)
s = NewSchema([]string{}, nil, nil)
a.True(s.isBasicType())
}

0 comments on commit 6758bf0

Please sign in to comment.