diff --git a/collator.go b/collator.go index 300bf49e..04b9a742 100644 --- a/collator.go +++ b/collator.go @@ -150,7 +150,7 @@ func (c *Collator) mergeComponents(rv *ResourceVersion) error { inliner := NewInliner() for k, v := range rv.T.Components.Schemas { ref := "#/components/schemas/" + k - if current, ok := c.result.Components.Schemas[k]; ok && !componentsEqual(current, v) { + if current, ok := c.result.Components.Schemas[k]; ok && !ComponentsEqual(current, v) { inliner.AddRef(ref) } else { c.result.Components.Schemas[k] = v @@ -159,7 +159,7 @@ func (c *Collator) mergeComponents(rv *ResourceVersion) error { } for k, v := range rv.T.Components.Parameters { ref := "#/components/parameters/" + k - if current, ok := c.result.Components.Parameters[k]; ok && !componentsEqual(current, v) { + if current, ok := c.result.Components.Parameters[k]; ok && !ComponentsEqual(current, v) { inliner.AddRef(ref) } else { c.result.Components.Parameters[k] = v @@ -168,7 +168,7 @@ func (c *Collator) mergeComponents(rv *ResourceVersion) error { } for k, v := range rv.T.Components.Headers { ref := "#/components/headers/" + k - if current, ok := c.result.Components.Headers[k]; ok && !componentsEqual(current, v) { + if current, ok := c.result.Components.Headers[k]; ok && !ComponentsEqual(current, v) { inliner.AddRef(ref) } else { c.result.Components.Headers[k] = v @@ -177,7 +177,7 @@ func (c *Collator) mergeComponents(rv *ResourceVersion) error { } for k, v := range rv.T.Components.RequestBodies { ref := "#/components/requestBodies/" + k - if current, ok := c.result.Components.RequestBodies[k]; ok && !componentsEqual(current, v) { + if current, ok := c.result.Components.RequestBodies[k]; ok && !ComponentsEqual(current, v) { inliner.AddRef(ref) } else { c.result.Components.RequestBodies[k] = v @@ -186,7 +186,7 @@ func (c *Collator) mergeComponents(rv *ResourceVersion) error { } for k, v := range rv.T.Components.Responses { ref := "#/components/responses/" + k - if current, ok := c.result.Components.Responses[k]; ok && !componentsEqual(current, v) { + if current, ok := c.result.Components.Responses[k]; ok && !ComponentsEqual(current, v) { inliner.AddRef(ref) } else { c.result.Components.Responses[k] = v @@ -195,7 +195,7 @@ func (c *Collator) mergeComponents(rv *ResourceVersion) error { } for k, v := range rv.T.Components.SecuritySchemes { ref := "#/components/securitySchemas/" + k - if current, ok := c.result.Components.SecuritySchemes[k]; ok && !componentsEqual(current, v) { + if current, ok := c.result.Components.SecuritySchemes[k]; ok && !ComponentsEqual(current, v) { inliner.AddRef(ref) } else { c.result.Components.SecuritySchemes[k] = v @@ -204,7 +204,7 @@ func (c *Collator) mergeComponents(rv *ResourceVersion) error { } for k, v := range rv.T.Components.Examples { ref := "#/components/examples/" + k - if current, ok := c.result.Components.Examples[k]; ok && !componentsEqual(current, v) { + if current, ok := c.result.Components.Examples[k]; ok && !ComponentsEqual(current, v) { inliner.AddRef(ref) } else { c.result.Components.Examples[k] = v @@ -213,7 +213,7 @@ func (c *Collator) mergeComponents(rv *ResourceVersion) error { } for k, v := range rv.T.Components.Links { ref := "#/components/links/" + k - if current, ok := c.result.Components.Links[k]; ok && !componentsEqual(current, v) { + if current, ok := c.result.Components.Links[k]; ok && !ComponentsEqual(current, v) { inliner.AddRef(ref) } else { c.result.Components.Links[k] = v @@ -222,7 +222,7 @@ func (c *Collator) mergeComponents(rv *ResourceVersion) error { } for k, v := range rv.T.Components.Callbacks { ref := "#/components/callbacks/" + k - if current, ok := c.result.Components.Callbacks[k]; ok && !componentsEqual(current, v) { + if current, ok := c.result.Components.Callbacks[k]; ok && !ComponentsEqual(current, v) { inliner.AddRef(ref) } else { c.result.Components.Callbacks[k] = v @@ -258,7 +258,7 @@ var cmpComponents = cmp.Options{ }, cmp.Ignore()), } -func componentsEqual(x, y interface{}) bool { +func ComponentsEqual(x, y interface{}) bool { return cmp.Equal(x, y, cmpComponents) } diff --git a/collator_test.go b/collator_test.go index 260a9fc2..a4188814 100644 --- a/collator_test.go +++ b/collator_test.go @@ -17,7 +17,8 @@ func TestRefRemover(t *testing.T) { errDoc := resp400.Value.Content["application/vnd.api+json"].Schema c.Assert(err, qt.IsNil) c.Assert("{\"$ref\":\"../errors.yaml#/ErrorDocument\"}", qt.JSONEquals, errDoc) - vervet.RemoveRefs(errDoc) + err = vervet.RemoveRefs(errDoc) + c.Assert(err, qt.IsNil) //nolint:lll // acked c.Assert("{\"additionalProperties\":false,\"example\":{\"errors\":[{\"detail\":\"Permission denied for this "+ "resource\",\"status\":\"403\"}],\"jsonapi\":{\"version\":\"1.0\"}},\"properties\":{\"errors\":{\"example\":"+ diff --git a/document.go b/document.go index c5e0d9c8..92eb9709 100644 --- a/document.go +++ b/document.go @@ -80,7 +80,10 @@ func NewDocumentFile(specFile string) (_ *Document, returnErr error) { if err != nil { return nil, err } - newRefAliasResolver(&t).resolve() + err = newRefAliasResolver(&t).resolve() + if err != nil { + return nil, err + } l := openapi3.NewLoader() l.IsExternalRefsAllowed = true diff --git a/go.mod b/go.mod index e69ad7b6..821448e3 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,6 @@ require ( github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.0 github.com/manifoldco/promptui v0.9.0 - github.com/mitchellh/reflectwalk v1.0.2 github.com/olekukonko/tablewriter v0.0.5 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.17.0 diff --git a/go.sum b/go.sum index fa895389..4aa682f9 100644 --- a/go.sum +++ b/go.sum @@ -353,8 +353,6 @@ github.com/maxbrunsfeld/counterfeiter/v6 v6.7.0 h1:z0CfPybq3CxaJvrrpf7Gme1psZTqH github.com/maxbrunsfeld/counterfeiter/v6 v6.7.0/go.mod h1:RVP6/F85JyxTrbJxWIdKU2vlSvK48iCMnMXRkSz7xtg= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= -github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= -github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/moby/moby v24.0.7+incompatible h1:RrVT5IXBn85mRtFKP+gFwVLCcnNPZIgN3NVRJG9Le+4= github.com/moby/moby v24.0.7+incompatible/go.mod h1:fDXVQ6+S340veQPv35CzDahGBmHsiclFwfEygB/TWMc= github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk= diff --git a/inliner.go b/inliner.go index 1b329472..e64ee63f 100644 --- a/inliner.go +++ b/inliner.go @@ -11,58 +11,67 @@ type Inliner struct { refs map[string]struct{} } -func (in *Inliner) ProcessCallbackRef(ref *openapi3.CallbackRef) { +func (in *Inliner) ProcessCallbackRef(ref *openapi3.CallbackRef) error { if in.matched(ref.Ref) { - RemoveRefs(ref) + return RemoveRefs(ref) } + return nil } -func (in *Inliner) ProcessExampleRef(ref *openapi3.ExampleRef) { +func (in *Inliner) ProcessExampleRef(ref *openapi3.ExampleRef) error { if in.matched(ref.Ref) { - RemoveRefs(ref) + return RemoveRefs(ref) } + return nil } -func (in *Inliner) ProcessHeaderRef(ref *openapi3.HeaderRef) { +func (in *Inliner) ProcessHeaderRef(ref *openapi3.HeaderRef) error { if in.matched(ref.Ref) { - RemoveRefs(ref) + return RemoveRefs(ref) } + return nil } -func (in *Inliner) ProcessLinkRef(ref *openapi3.LinkRef) { +func (in *Inliner) ProcessLinkRef(ref *openapi3.LinkRef) error { if in.matched(ref.Ref) { - RemoveRefs(ref) + return RemoveRefs(ref) } + return nil } -func (in *Inliner) ProcessParameterRef(ref *openapi3.ParameterRef) { +func (in *Inliner) ProcessParameterRef(ref *openapi3.ParameterRef) error { if in.matched(ref.Ref) { - RemoveRefs(ref) + return RemoveRefs(ref) } + return nil } -func (in *Inliner) ProcessRequestBodyRef(ref *openapi3.RequestBodyRef) { +func (in *Inliner) ProcessRequestBodyRef(ref *openapi3.RequestBodyRef) error { if in.matched(ref.Ref) { - RemoveRefs(ref) + return RemoveRefs(ref) } + return nil } -func (in *Inliner) ProcessResponseRef(ref *openapi3.ResponseRef) { +func (in *Inliner) ProcessResponseRef(ref *openapi3.ResponseRef) error { if in.matched(ref.Ref) { - RemoveRefs(ref) + return RemoveRefs(ref) } + return nil } -func (in *Inliner) ProcessSchemaRef(ref *openapi3.SchemaRef) { +func (in *Inliner) ProcessSchemaRef(ref *openapi3.SchemaRef) error { if in.matched(ref.Ref) { - RemoveRefs(ref) + return RemoveRefs(ref) } + return nil } -func (in *Inliner) ProcessSecuritySchemeRef(ref *openapi3.SecuritySchemeRef) { +func (in *Inliner) ProcessSecuritySchemeRef(ref *openapi3.SecuritySchemeRef) error { if in.matched(ref.Ref) { - RemoveRefs(ref) + return RemoveRefs(ref) } + return nil } // NewInliner returns a new Inliner instance. @@ -81,9 +90,8 @@ func (in *Inliner) Inline(doc *openapi3.T) error { if len(in.refs) == 0 { return nil } - openapiwalker.ProcessRefs(doc, in) - return nil + return openapiwalker.ProcessRefs(doc, in) } func (in *Inliner) matched(ref string) bool { @@ -95,45 +103,54 @@ func (in *Inliner) matched(ref string) bool { // fragment. If the reference has already been resolved, this has the effect of // "inlining" the formerly referenced object when serializing the OpenAPI // document. -func RemoveRefs(target interface{}) { - openapiwalker.ProcessRefs(target, clearRefs{}) +func RemoveRefs(target interface{}) error { + return openapiwalker.ProcessRefs(target, clearRefs{}) } type clearRefs struct { } -func (c clearRefs) ProcessCallbackRef(ref *openapi3.CallbackRef) { +func (c clearRefs) ProcessCallbackRef(ref *openapi3.CallbackRef) error { ref.Ref = "" + return nil } -func (c clearRefs) ProcessExampleRef(ref *openapi3.ExampleRef) { +func (c clearRefs) ProcessExampleRef(ref *openapi3.ExampleRef) error { ref.Ref = "" + return nil } -func (c clearRefs) ProcessHeaderRef(ref *openapi3.HeaderRef) { +func (c clearRefs) ProcessHeaderRef(ref *openapi3.HeaderRef) error { ref.Ref = "" + return nil } -func (c clearRefs) ProcessLinkRef(ref *openapi3.LinkRef) { +func (c clearRefs) ProcessLinkRef(ref *openapi3.LinkRef) error { ref.Ref = "" + return nil } -func (c clearRefs) ProcessParameterRef(ref *openapi3.ParameterRef) { +func (c clearRefs) ProcessParameterRef(ref *openapi3.ParameterRef) error { ref.Ref = "" + return nil } -func (c clearRefs) ProcessRequestBodyRef(ref *openapi3.RequestBodyRef) { +func (c clearRefs) ProcessRequestBodyRef(ref *openapi3.RequestBodyRef) error { ref.Ref = "" + return nil } -func (c clearRefs) ProcessResponseRef(ref *openapi3.ResponseRef) { +func (c clearRefs) ProcessResponseRef(ref *openapi3.ResponseRef) error { ref.Ref = "" + return nil } -func (c clearRefs) ProcessSchemaRef(ref *openapi3.SchemaRef) { +func (c clearRefs) ProcessSchemaRef(ref *openapi3.SchemaRef) error { ref.Ref = "" + return nil } -func (c clearRefs) ProcessSecuritySchemeRef(ref *openapi3.SecuritySchemeRef) { +func (c clearRefs) ProcessSecuritySchemeRef(ref *openapi3.SecuritySchemeRef) error { ref.Ref = "" + return nil } diff --git a/internal/openapiwalker/walker.go b/internal/openapiwalker/walker.go index a1a25787..702c5fec 100644 --- a/internal/openapiwalker/walker.go +++ b/internal/openapiwalker/walker.go @@ -7,265 +7,415 @@ import ( ) type RefProcessor interface { - ProcessCallbackRef(ref *openapi3.CallbackRef) - ProcessExampleRef(ref *openapi3.ExampleRef) - ProcessHeaderRef(ref *openapi3.HeaderRef) - ProcessLinkRef(ref *openapi3.LinkRef) - ProcessParameterRef(ref *openapi3.ParameterRef) - ProcessRequestBodyRef(ref *openapi3.RequestBodyRef) - ProcessResponseRef(ref *openapi3.ResponseRef) - ProcessSchemaRef(ref *openapi3.SchemaRef) - ProcessSecuritySchemeRef(ref *openapi3.SecuritySchemeRef) + ProcessCallbackRef(ref *openapi3.CallbackRef) error + ProcessExampleRef(ref *openapi3.ExampleRef) error + ProcessHeaderRef(ref *openapi3.HeaderRef) error + ProcessLinkRef(ref *openapi3.LinkRef) error + ProcessParameterRef(ref *openapi3.ParameterRef) error + ProcessRequestBodyRef(ref *openapi3.RequestBodyRef) error + ProcessResponseRef(ref *openapi3.ResponseRef) error + ProcessSchemaRef(ref *openapi3.SchemaRef) error + ProcessSecuritySchemeRef(ref *openapi3.SecuritySchemeRef) error } // ProcessRefs visits all the documents and calls the RefProcessor for each ref encountered. // //nolint:gocyclo // needs to check each type in the kinopneapi lib -func ProcessRefs(data any, p RefProcessor) { +func ProcessRefs(data any, p RefProcessor) error { switch v := data.(type) { case nil: - return + return nil case *openapi3.T: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case *openapi3.Components: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case *openapi3.MediaType: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case *openapi3.Response: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case *openapi3.Parameter: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case *openapi3.RequestBody: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case openapi3.RequestBody: - ProcessRefs(v.Content, p) + return ProcessRefs(v.Content, p) case openapi3.T: - ProcessRefs(v.Components, p) - ProcessRefs(v.Info, p) - ProcessRefs(v.Paths, p) - ProcessRefs(v.Security, p) - ProcessRefs(v.Servers, p) - ProcessRefs(v.Tags, p) - ProcessRefs(v.ExternalDocs, p) + if err := ProcessRefs(v.Components, p); err != nil { + return err + } + if err := ProcessRefs(v.Info, p); err != nil { + return err + } + if err := ProcessRefs(v.Paths, p); err != nil { + return err + } + if err := ProcessRefs(v.Security, p); err != nil { + return err + } + if err := ProcessRefs(v.Servers, p); err != nil { + return err + } + if err := ProcessRefs(v.Tags, p); err != nil { + return err + } + if err := ProcessRefs(v.ExternalDocs, p); err != nil { + return err + } case openapi3.Components: - ProcessRefs(v.Schemas, p) - ProcessRefs(v.Parameters, p) - ProcessRefs(v.Headers, p) - ProcessRefs(v.RequestBodies, p) - ProcessRefs(v.Responses, p) - ProcessRefs(v.SecuritySchemes, p) - ProcessRefs(v.Examples, p) - ProcessRefs(v.Links, p) - ProcessRefs(v.Callbacks, p) + if err := ProcessRefs(v.Schemas, p); err != nil { + return err + } + if err := ProcessRefs(v.Parameters, p); err != nil { + return err + } + if err := ProcessRefs(v.Headers, p); err != nil { + return err + } + if err := ProcessRefs(v.RequestBodies, p); err != nil { + return err + } + if err := ProcessRefs(v.Responses, p); err != nil { + return err + } + if err := ProcessRefs(v.SecuritySchemes, p); err != nil { + return err + } + if err := ProcessRefs(v.Examples, p); err != nil { + return err + } + if err := ProcessRefs(v.Links, p); err != nil { + return err + } + if err := ProcessRefs(v.Callbacks, p); err != nil { + return err + } case openapi3.ResponseBodies: for _, ref := range v { - ProcessRefs(ref, p) + if err := ProcessRefs(ref, p); err != nil { + return err + } } case openapi3.RequestBodies: for _, ref := range v { - ProcessRefs(ref, p) + if err := ProcessRefs(ref, p); err != nil { + return err + } } case openapi3.SecurityRequirements: for _, requirement := range v { - ProcessRefs(requirement, p) + if err := ProcessRefs(requirement, p); err != nil { + return err + } } case openapi3.Response: - ProcessRefs(v.Headers, p) - ProcessRefs(v.Content, p) - ProcessRefs(v.Links, p) + if err := ProcessRefs(v.Headers, p); err != nil { + return err + } + if err := ProcessRefs(v.Content, p); err != nil { + return err + } + if err := ProcessRefs(v.Links, p); err != nil { + return err + } case openapi3.Links: for _, link := range v { - ProcessRefs(link, p) + if err := ProcessRefs(link, p); err != nil { + return err + } } case openapi3.Content: for _, mediaType := range v { - ProcessRefs(mediaType, p) + if err := ProcessRefs(mediaType, p); err != nil { + return err + } } case openapi3.ParametersMap: for _, ref := range v { - ProcessRefs(ref, p) + if err := ProcessRefs(ref, p); err != nil { + return err + } } case openapi3.Schemas: for _, schema := range v { - ProcessRefs(schema, p) + if err := ProcessRefs(schema, p); err != nil { + return err + } } case openapi3.SchemaRefs: for _, schema := range v { - ProcessRefs(schema, p) + if err := ProcessRefs(schema, p); err != nil { + return err + } } case openapi3.Headers: for _, header := range v { - ProcessRefs(header, p) + if err := ProcessRefs(header, p); err != nil { + return err + } } case openapi3.MediaType: - ProcessRefs(v.Schema, p) - ProcessRefs(v.Examples, p) + if err := ProcessRefs(v.Schema, p); err != nil { + return err + } + if err := ProcessRefs(v.Examples, p); err != nil { + return err + } case openapi3.Parameter: - ProcessRefs(v.Schema, p) - ProcessRefs(v.Content, p) - ProcessRefs(v.Examples, p) + if err := ProcessRefs(v.Schema, p); err != nil { + return err + } + if err := ProcessRefs(v.Content, p); err != nil { + return err + } + if err := ProcessRefs(v.Examples, p); err != nil { + return err + } case openapi3.Examples: for _, example := range v { - ProcessRefs(example, p) + if err := ProcessRefs(example, p); err != nil { + return err + } } case *openapi3.Schema: if v != nil { - ProcessRefs(*v, p) + if err := ProcessRefs(*v, p); err != nil { + return err + } } case openapi3.SecuritySchemes: for _, ref := range v { - ProcessRefs(ref, p) + if err := ProcessRefs(ref, p); err != nil { + return err + } } case openapi3.Callbacks: for _, ref := range v { - ProcessRefs(ref, p) + if err := ProcessRefs(ref, p); err != nil { + return err + } } case *openapi3.Paths: if v != nil { - ProcessRefs(*v, p) + if err := ProcessRefs(*v, p); err != nil { + return err + } } case openapi3.Paths: for _, path := range v.Map() { - ProcessRefs(path, p) + if err := ProcessRefs(path, p); err != nil { + return err + } } case openapi3.Schema: - ProcessRefs(v.Properties, p) - ProcessRefs(v.Items, p) - ProcessRefs(v.AllOf, p) - ProcessRefs(v.AnyOf, p) - ProcessRefs(v.OneOf, p) - ProcessRefs(v.Not, p) + if err := ProcessRefs(v.Properties, p); err != nil { + return err + } + if err := ProcessRefs(v.Items, p); err != nil { + return err + } + if err := ProcessRefs(v.AllOf, p); err != nil { + return err + } + if err := ProcessRefs(v.AnyOf, p); err != nil { + return err + } + if err := ProcessRefs(v.OneOf, p); err != nil { + return err + } + if err := ProcessRefs(v.Not, p); err != nil { + return err + } case *openapi3.PathItem: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case openapi3.PathItem: - ProcessRefs(v.Connect, p) - ProcessRefs(v.Delete, p) - ProcessRefs(v.Get, p) - ProcessRefs(v.Head, p) - ProcessRefs(v.Options, p) - ProcessRefs(v.Patch, p) - ProcessRefs(v.Post, p) - ProcessRefs(v.Put, p) - ProcessRefs(v.Trace, p) - ProcessRefs(v.Servers, p) - ProcessRefs(v.Parameters, p) + if err := ProcessRefs(v.Connect, p); err != nil { + return err + } + if err := ProcessRefs(v.Delete, p); err != nil { + return err + } + if err := ProcessRefs(v.Get, p); err != nil { + return err + } + if err := ProcessRefs(v.Head, p); err != nil { + return err + } + if err := ProcessRefs(v.Options, p); err != nil { + return err + } + if err := ProcessRefs(v.Patch, p); err != nil { + return err + } + if err := ProcessRefs(v.Post, p); err != nil { + return err + } + if err := ProcessRefs(v.Put, p); err != nil { + return err + } + if err := ProcessRefs(v.Trace, p); err != nil { + return err + } + if err := ProcessRefs(v.Servers, p); err != nil { + return err + } + if err := ProcessRefs(v.Parameters, p); err != nil { + return err + } case *openapi3.Operation: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case openapi3.Operation: - ProcessRefs(v.Parameters, p) - ProcessRefs(v.RequestBody, p) - ProcessRefs(v.Responses, p) - ProcessRefs(v.Callbacks, p) - ProcessRefs(v.Security, p) - ProcessRefs(v.Servers, p) - ProcessRefs(v.ExternalDocs, p) + if err := ProcessRefs(v.Parameters, p); err != nil { + return err + } + if err := ProcessRefs(v.RequestBody, p); err != nil { + return err + } + if err := ProcessRefs(v.Responses, p); err != nil { + return err + } + if err := ProcessRefs(v.Callbacks, p); err != nil { + return err + } + if err := ProcessRefs(v.Security, p); err != nil { + return err + } + if err := ProcessRefs(v.Servers, p); err != nil { + return err + } + if err := ProcessRefs(v.ExternalDocs, p); err != nil { + return err + } case *openapi3.Responses: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case openapi3.Responses: for _, ref := range v.Map() { - ProcessRefs(ref, p) + if err := ProcessRefs(ref, p); err != nil { + return err + } } case openapi3.Parameters: for _, parameter := range v { - ProcessRefs(parameter, p) + if err := ProcessRefs(parameter, p); err != nil { + return err + } } case *openapi3.Callback: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case openapi3.Callback: for _, pathItem := range v.Map() { - ProcessRefs(pathItem, p) + if err := ProcessRefs(pathItem, p); err != nil { + return err + } } case *openapi3.Example: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case *openapi3.Header: if v != nil { - ProcessRefs(*v, p) + return ProcessRefs(*v, p) } case openapi3.Header: - ProcessRefs(v.Parameter, p) + return ProcessRefs(v.Parameter, p) case *openapi3.CallbackRef: if v != nil { - p.ProcessCallbackRef(v) - ProcessRefs(v.Value, p) + if err := p.ProcessCallbackRef(v); err != nil { + return err + } + return ProcessRefs(v.Value, p) } case *openapi3.ExampleRef: if v != nil { - p.ProcessExampleRef(v) - ProcessRefs(v.Value, p) + if err := p.ProcessExampleRef(v); err != nil { + return err + } + return ProcessRefs(v.Value, p) } case *openapi3.HeaderRef: if v != nil { - p.ProcessHeaderRef(v) - ProcessRefs(v.Value, p) + if err := p.ProcessHeaderRef(v); err != nil { + return err + } + return ProcessRefs(v.Value, p) } case *openapi3.LinkRef: if v != nil { - p.ProcessLinkRef(v) - ProcessRefs(v.Value, p) + if err := p.ProcessLinkRef(v); err != nil { + return err + } + return ProcessRefs(v.Value, p) } case *openapi3.ParameterRef: if v != nil { - p.ProcessParameterRef(v) - ProcessRefs(v.Value, p) + if err := p.ProcessParameterRef(v); err != nil { + return err + } + return ProcessRefs(v.Value, p) } case *openapi3.RequestBodyRef: if v != nil { - p.ProcessRequestBodyRef(v) - ProcessRefs(v.Value, p) + if err := p.ProcessRequestBodyRef(v); err != nil { + return err + } + return ProcessRefs(v.Value, p) } case *openapi3.ResponseRef: if v != nil { - p.ProcessResponseRef(v) - ProcessRefs(v.Value, p) + if err := p.ProcessResponseRef(v); err != nil { + return err + } + return ProcessRefs(v.Value, p) } case *openapi3.SchemaRef: if v != nil { - p.ProcessSchemaRef(v) - ProcessRefs(v.Value, p) + if err := p.ProcessSchemaRef(v); err != nil { + return err + } + return ProcessRefs(v.Value, p) } case *openapi3.SecuritySchemeRef: if v != nil { - p.ProcessSecuritySchemeRef(v) - ProcessRefs(v.Value, p) + if err := p.ProcessSecuritySchemeRef(v); err != nil { + return err + } + return ProcessRefs(v.Value, p) } // no interesting nested fields case *openapi3.Info: @@ -290,4 +440,5 @@ func ProcessRefs(data any, p RefProcessor) { // be caught in tests panic(fmt.Sprintf("unhandled type %#v", v)) } + return nil } diff --git a/internal/simplebuild/build.go b/internal/simplebuild/build.go index 9b80bf2c..614082f6 100644 --- a/internal/simplebuild/build.go +++ b/internal/simplebuild/build.go @@ -21,6 +21,11 @@ func Build(ctx context.Context, project *config.Project, startDate vervet.Versio return nil } for _, apiConfig := range project.APIs { + if apiConfig.Output == nil { + fmt.Printf("No output specified for %s, skipping\n", apiConfig.Name) + continue + } + operations, err := LoadPaths(ctx, apiConfig) if err != nil { return err @@ -28,22 +33,32 @@ func Build(ctx context.Context, project *config.Project, startDate vervet.Versio for _, op := range operations { op.Annotate() } - - docs, err := operations.Build(startDate) + docs := operations.Build(startDate) + writer, err := NewWriter(*apiConfig.Output, appendOutputFiles) if err != nil { return err } - err = docs.ApplyOverlays(ctx, apiConfig.Overlays) - if err != nil { - return err - } + for _, doc := range docs { + err := doc.ApplyOverlays(ctx, apiConfig.Overlays) + if err != nil { + return err + } - if apiConfig.Output != nil { - err = docs.WriteOutputs(*apiConfig.Output, appendOutputFiles) + refResolver := NewRefResolver() + err = refResolver.ResolveRefs(doc.Doc) if err != nil { return err } + + err = writer.Write(doc) + if err != nil { + return err + } + } + err = writer.Finalize() + if err != nil { + return err } } return nil @@ -70,7 +85,7 @@ type VersionedDoc struct { } type DocSet []VersionedDoc -func (ops Operations) Build(startVersion vervet.Version) (DocSet, error) { +func (ops Operations) Build(startVersion vervet.Version) DocSet { versionDates := ops.VersionDates() versionDates = filterVersionByStartDate(versionDates, startVersion.Date) output := make(DocSet, len(versionDates)) @@ -79,20 +94,15 @@ func (ops Operations) Build(startVersion vervet.Version) (DocSet, error) { Doc: &openapi3.T{}, VersionDate: versionDate, } - refResolver := NewRefResolver(output[idx].Doc) for path, spec := range ops { op := spec.GetLatest(versionDate) if op == nil { continue } output[idx].Doc.AddOperation(path.Path, path.Method, op) - err := refResolver.Resolve(op) - if err != nil { - return nil, err - } } } - return output, nil + return output } func filterVersionByStartDate(dates []time.Time, startDate time.Time) []time.Time { diff --git a/internal/simplebuild/build_test.go b/internal/simplebuild/build_test.go index c8fc328d..37414e39 100644 --- a/internal/simplebuild/build_test.go +++ b/internal/simplebuild/build_test.go @@ -104,8 +104,7 @@ func TestBuild(t *testing.T) { ResourceName: "foo", }}, } - output, err := ops.Build(vervet.MustParseVersion("2024-01-01")) - c.Assert(err, qt.IsNil) + output := ops.Build(vervet.MustParseVersion("2024-01-01")) c.Assert(output[0].VersionDate, qt.Equals, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)) c.Assert(output[0].Doc.Paths.Value("/foo").Get, qt.IsNotNil) }) @@ -143,8 +142,7 @@ func TestBuild(t *testing.T) { ResourceName: "bar", }}, } - output, err := ops.Build(vervet.MustParseVersion("2024-01-01")) - c.Assert(err, qt.IsNil) + output := ops.Build(vervet.MustParseVersion("2024-01-01")) c.Assert(output[0].VersionDate, qt.Equals, version.Date) c.Assert(output[0].Doc.Paths.Value("/foo").Get, qt.Equals, getFoo) c.Assert(output[0].Doc.Paths.Value("/foo").Post, qt.Equals, postFoo) @@ -179,8 +177,7 @@ func TestBuild(t *testing.T) { ResourceName: "bar", }}, } - output, err := ops.Build(vervet.MustParseVersion("2024-01-01")) - c.Assert(err, qt.IsNil) + output := ops.Build(vervet.MustParseVersion("2024-01-01")) inputVersions := make([]time.Time, len(versions)) for idx, in := range versions { @@ -232,8 +229,7 @@ func TestBuild(t *testing.T) { ResourceName: "bar", }}, } - output, err := ops.Build(vervet.MustParseVersion("2024-01-01")) - c.Assert(err, qt.IsNil) + output := ops.Build(vervet.MustParseVersion("2024-01-01")) slices.SortFunc(output, compareDocs) @@ -283,8 +279,7 @@ func TestBuild(t *testing.T) { ResourceName: "bar", }}, } - output, err := ops.Build(vervet.MustParseVersion("2024-01-01")) - c.Assert(err, qt.IsNil) + output := ops.Build(vervet.MustParseVersion("2024-01-01")) slices.SortFunc(output, compareDocs) @@ -331,8 +326,7 @@ func TestBuild(t *testing.T) { ResourceName: "bar", }}, } - output, err := ops.Build(vervet.MustParseVersion("2024-01-02")) - c.Assert(err, qt.IsNil) + output := ops.Build(vervet.MustParseVersion("2024-01-02")) slices.SortFunc(output, compareDocs) @@ -382,8 +376,7 @@ func TestBuild(t *testing.T) { ResourceName: "bar", }}, } - output, err := ops.Build(vervet.MustParseVersion("2024-01-01")) - c.Assert(err, qt.IsNil) + output := ops.Build(vervet.MustParseVersion("2024-01-01")) slices.SortFunc(output, compareDocs) diff --git a/internal/simplebuild/output.go b/internal/simplebuild/output.go index 7b5cd7af..e709d206 100644 --- a/internal/simplebuild/output.go +++ b/internal/simplebuild/output.go @@ -17,110 +17,125 @@ import ( "github.com/snyk/vervet/v8/internal/files" ) -// Some services have a need to write specs to multiple destinations. This -// tends to happen in Typescript services in which we want to write specs to -// two places: -// - src/** for committing into git and ingesting into Backstage -// - dist/** for runtime module access to compiled specs. -// -// To maintain backwards compatibility we still allow a single path in the -// config file then normalise that here to an array. -func getOutputPaths(cfg config.Output) []string { - paths := cfg.Paths - if len(paths) == 0 && cfg.Path != "" { - paths = []string{cfg.Path} - } - return paths +type DocWriter struct { + cfg config.Output + paths []string + versionSpecFiles []string } -// WriteOutputs writes compiled specs to all directories specified by the given -// api config. Removes any existing builds if they are present. -func (docs DocSet) WriteOutputs(cfg config.Output, appendOutputFiles bool) error { +// NewWriter initialises any output paths, removing existing files and +// directories if they are present. +func NewWriter(cfg config.Output, appendOutputFiles bool) (*DocWriter, error) { paths := getOutputPaths(cfg) + toClear := paths + if appendOutputFiles { + // We treat the first path as the source of truth and copy the whole + // directory to the other paths in Finalize. + toClear = toClear[1:] + } - if !appendOutputFiles { - for _, dir := range paths { - err := os.RemoveAll(dir) - if err != nil { - return fmt.Errorf("clear output directory: %w", err) - } + for _, dir := range toClear { + err := os.RemoveAll(dir) + if err != nil { + return nil, fmt.Errorf("clear output directory: %w", err) } } - - err := docs.Write(paths[0], appendOutputFiles) + err := os.MkdirAll(paths[0], 0777) if err != nil { - return fmt.Errorf("write output files: %w", err) + return nil, fmt.Errorf("make output directory: %w", err) } - for _, dir := range paths[1:] { - err := files.CopyDir(dir, paths[0], true) - if err != nil { - return fmt.Errorf("copy outputs: %w", err) - } + versionSpecFiles, err := getExisingSpecFiles(paths[0]) + if err != nil { + return nil, fmt.Errorf("list existing files: %w", err) } - return nil + return &DocWriter{ + cfg: cfg, + paths: paths, + versionSpecFiles: versionSpecFiles, + }, nil } // Write writes compiled specs to a single directory in YAML and JSON formats. -// Unlike WriteOutputs this function assumes the destination directory does not -// already exist. -func (docs DocSet) Write(dir string, appendOutputFiles bool) error { - err := os.MkdirAll(dir, 0777) +// Call Finalize after to populate other directories. +func (out *DocWriter) Write(doc VersionedDoc) error { + // We write to the first directory then copy the entire directory + // afterwards + dir := out.paths[0] + + versionDir := path.Join(dir, doc.VersionDate.Format(time.DateOnly)) + err := os.MkdirAll(versionDir, 0755) if err != nil { - return err + return fmt.Errorf("make output directory: %w", err) + } + + jsonBuf, err := vervet.ToSpecJSON(doc.Doc) + if err != nil { + return fmt.Errorf("serialise spec to json: %w", err) + } + jsonSpecPath := path.Join(versionDir, "spec.json") + jsonEmbedPath, err := filepath.Rel(dir, jsonSpecPath) + if err != nil { + return fmt.Errorf("get relative output path: %w", err) } - existingFiles, err := getExisingSpecFiles(dir) + out.versionSpecFiles = append(out.versionSpecFiles, jsonEmbedPath) + err = os.WriteFile(jsonSpecPath, jsonBuf, 0644) if err != nil { - return fmt.Errorf("list existing files: %w", err) + return fmt.Errorf("write json file: %w", err) } + fmt.Println(jsonSpecPath) - versionSpecFiles := make([]string, 0, len(existingFiles)+len(docs)*2) - versionSpecFiles = append(versionSpecFiles, existingFiles...) - for _, doc := range docs { - versionDir := path.Join(dir, doc.VersionDate.Format(time.DateOnly)) - err = os.MkdirAll(versionDir, 0755) - if err != nil { - return fmt.Errorf("make output directory: %w", err) - } + yamlBuf, err := yaml.JSONToYAML(jsonBuf) + if err != nil { + return fmt.Errorf("convert spec to yaml: %w", err) + } + yamlBuf, err = vervet.WithGeneratedComment(yamlBuf) + if err != nil { + return fmt.Errorf("prepend yaml comment: %w", err) + } + yamlSpecPath := path.Join(versionDir, "spec.yaml") + yamlEmbedPath, err := filepath.Rel(dir, yamlSpecPath) + if err != nil { + return fmt.Errorf("get relative output path: %w", err) + } + out.versionSpecFiles = append(out.versionSpecFiles, yamlEmbedPath) + err = os.WriteFile(yamlSpecPath, yamlBuf, 0644) + if err != nil { + return fmt.Errorf("write yaml file: %w", err) + } + fmt.Println(yamlSpecPath) + return nil +} - jsonBuf, err := vervet.ToSpecJSON(doc.Doc) - if err != nil { - return fmt.Errorf("serialise spec to json: %w", err) - } - jsonSpecPath := path.Join(versionDir, "spec.json") - jsonEmbedPath, err := filepath.Rel(dir, jsonSpecPath) - if err != nil { - return fmt.Errorf("get relative output path: %w", err) - } - versionSpecFiles = append(versionSpecFiles, jsonEmbedPath) - err = os.WriteFile(jsonSpecPath, jsonBuf, 0644) +func (out *DocWriter) Finalize() error { + err := writeEmbedGo(out.paths[0], out.versionSpecFiles) + if err != nil { + return err + } + for _, dir := range out.paths[1:] { + err := files.CopyDir(dir, out.paths[0], true) if err != nil { - return fmt.Errorf("write json file: %w", err) + return fmt.Errorf("copy outputs: %w", err) } - fmt.Println(jsonSpecPath) + } + return nil +} - yamlBuf, err := yaml.JSONToYAML(jsonBuf) - if err != nil { - return fmt.Errorf("convert spec to yaml: %w", err) - } - yamlBuf, err = vervet.WithGeneratedComment(yamlBuf) - if err != nil { - return fmt.Errorf("prepend yaml comment: %w", err) - } - yamlSpecPath := path.Join(versionDir, "spec.yaml") - yamlEmbedPath, err := filepath.Rel(dir, yamlSpecPath) - if err != nil { - return fmt.Errorf("get relative output path: %w", err) - } - versionSpecFiles = append(versionSpecFiles, yamlEmbedPath) - err = os.WriteFile(yamlSpecPath, yamlBuf, 0644) - if err != nil { - return fmt.Errorf("write yaml file: %w", err) - } - fmt.Println(yamlSpecPath) +// Some services have a need to write specs to multiple destinations. This +// tends to happen in Typescript services in which we want to write specs to +// two places: +// - src/** for committing into git and ingesting into Backstage +// - dist/** for runtime module access to compiled specs. +// +// To maintain backwards compatibility we still allow a single path in the +// config file then normalise that here to an array. +func getOutputPaths(cfg config.Output) []string { + paths := cfg.Paths + if len(paths) == 0 && cfg.Path != "" { + paths = []string{cfg.Path} } - return writeEmbedGo(dir, versionSpecFiles) + return paths } func getExisingSpecFiles(dir string) ([]string, error) { diff --git a/internal/simplebuild/output_test.go b/internal/simplebuild/output_test.go index 45dfd08b..04d719d5 100644 --- a/internal/simplebuild/output_test.go +++ b/internal/simplebuild/output_test.go @@ -25,12 +25,11 @@ func TestDocSet_WriteOutputs(t *testing.T) { appendOutputFiles bool } tests := []struct { - name string - docs DocSet - args args - wantErr bool - assert func(*testing.T, args) - setup func(*testing.T, args) + name string + docs DocSet + args args + assert func(*testing.T, args) + setup func(*testing.T, args) }{ { name: "write the doc sets to outputs", @@ -45,7 +44,6 @@ func TestDocSet_WriteOutputs(t *testing.T) { Doc: testDoc, }, }, - wantErr: false, assert: func(t *testing.T, args args) { t.Helper() files, err := filepath.Glob(filepath.Join(args.cfg.Path, "*")) @@ -70,7 +68,6 @@ func TestDocSet_WriteOutputs(t *testing.T) { Doc: testDoc, }, }, - wantErr: false, setup: func(t *testing.T, args args) { t.Helper() err = os.WriteFile(path.Join(args.cfg.Path, "existing-file"), []byte("existing"), 0644) @@ -101,7 +98,6 @@ func TestDocSet_WriteOutputs(t *testing.T) { Doc: testDoc, }, }, - wantErr: false, setup: func(t *testing.T, args args) { t.Helper() err = os.WriteFile(path.Join(args.cfg.Path, "2024-02-01"), []byte("existing"), 0644) @@ -130,12 +126,17 @@ func TestDocSet_WriteOutputs(t *testing.T) { if tt.setup != nil { tt.setup(t, tt.args) } - if err := tt.docs.WriteOutputs(tt.args.cfg, tt.args.appendOutputFiles); (err != nil) != tt.wantErr { - t.Errorf("WriteOutputs() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.assert != nil { - tt.assert(t, tt.args) + + writer, err := NewWriter(tt.args.cfg, tt.args.appendOutputFiles) + c.Assert(err, qt.IsNil) + for _, doc := range tt.docs { + err = writer.Write(doc) + c.Assert(err, qt.IsNil) } + err = writer.Finalize() + c.Assert(err, qt.IsNil) + + tt.assert(t, tt.args) }) } } diff --git a/internal/simplebuild/overlays.go b/internal/simplebuild/overlays.go index b3e504de..6bc5ae97 100644 --- a/internal/simplebuild/overlays.go +++ b/internal/simplebuild/overlays.go @@ -11,18 +11,17 @@ import ( "github.com/snyk/vervet/v8/config" ) -func (docs DocSet) ApplyOverlays(ctx context.Context, cfgs []*config.Overlay) error { +func (doc VersionedDoc) ApplyOverlays(ctx context.Context, cfgs []*config.Overlay) error { + // TODO: cache overlays, err := loadOverlays(ctx, cfgs) if err != nil { return fmt.Errorf("load overlays: %w", err) } - for _, doc := range docs { - for _, overlay := range overlays { - // NB: Will overwrite any existing definitions without warning. - err := vervet.Merge(doc.Doc, overlay, true) - if err != nil { - return fmt.Errorf("apply overlay: %w", err) - } + for _, overlay := range overlays { + // NB: Will overwrite any existing definitions without warning. + err := vervet.Merge(doc.Doc, overlay, true) + if err != nil { + return fmt.Errorf("apply overlay: %w", err) } } diff --git a/internal/simplebuild/refs.go b/internal/simplebuild/refs.go index 7fd7b331..63fcfe12 100644 --- a/internal/simplebuild/refs.go +++ b/internal/simplebuild/refs.go @@ -3,10 +3,13 @@ package simplebuild import ( "fmt" "reflect" + "slices" "strings" "github.com/getkin/kin-openapi/openapi3" - "github.com/mitchellh/reflectwalk" + + "github.com/snyk/vervet/v8" + "github.com/snyk/vervet/v8/internal/openapiwalker" ) // Refs are an OpenAPI concept where you can define part of a spec then use a @@ -44,85 +47,247 @@ import ( // This class walks a given object and recursively copy any refs it finds back // into the document at the path they are referenced from. type refResolver struct { - doc *openapi3.T + doc *openapi3.T + renames map[string]string +} + +func NewRefResolver() refResolver { + return refResolver{renames: make(map[string]string)} +} + +func (rr *refResolver) copyToComponents(orignalRef string, component any) (string, error) { + newRef, err := rr.deref(orignalRef, reflect.ValueOf(component)) + if err != nil { + return "", err + } + if newRef != orignalRef { + rr.renames[newRef] = orignalRef + } + return newRef, nil +} + +func (rr *refResolver) ProcessCallbackRef(ref *openapi3.CallbackRef) error { + if ref.Ref == "" { + return nil + } + component := &openapi3.CallbackRef{ + Value: ref.Value, + } + var err error + ref.Ref, err = rr.copyToComponents(ref.Ref, component) + return err } -func NewRefResolver(doc *openapi3.T) refResolver { - return refResolver{doc: doc} +func (rr *refResolver) ProcessExampleRef(ref *openapi3.ExampleRef) error { + if ref.Ref == "" { + return nil + } + component := &openapi3.ExampleRef{ + Value: ref.Value, + } + var err error + ref.Ref, err = rr.copyToComponents(ref.Ref, component) + return err } -func (rr *refResolver) Resolve(from any) error { - return reflectwalk.Walk(from, rr) +func (rr *refResolver) ProcessHeaderRef(ref *openapi3.HeaderRef) error { + if ref.Ref == "" { + return nil + } + component := &openapi3.HeaderRef{ + Value: ref.Value, + } + var err error + ref.Ref, err = rr.copyToComponents(ref.Ref, component) + return err } -// Implements reflectwalk.StructWalker. This function is called for every -// struct found when walking. -func (rr *refResolver) Struct(v reflect.Value) error { - ref := v.FieldByName("Ref") - value := v.FieldByName("Value") - if !ref.IsValid() || !value.IsValid() { - // This isn't a openapi3.*Ref so nothing to do +func (rr *refResolver) ProcessLinkRef(ref *openapi3.LinkRef) error { + if ref.Ref == "" { return nil } - refLoc := ref.String() - if refLoc == "" { - // This ref has been inlined + component := &openapi3.LinkRef{ + Value: ref.Value, + } + var err error + ref.Ref, err = rr.copyToComponents(ref.Ref, component) + return err +} + +func (rr *refResolver) ProcessParameterRef(ref *openapi3.ParameterRef) error { + if ref.Ref == "" { return nil } - // Create a new *Ref object to avoid mutating the original - derefed := reflect.New(v.Type()) - reflect.Indirect(derefed).FieldByName("Value").Set(value) + component := &openapi3.ParameterRef{ + Value: ref.Value, + } + var err error + ref.Ref, err = rr.copyToComponents(ref.Ref, component) + return err +} - return rr.deref(refLoc, derefed) +func (rr *refResolver) ProcessRequestBodyRef(ref *openapi3.RequestBodyRef) error { + if ref.Ref == "" { + return nil + } + component := &openapi3.RequestBodyRef{ + Value: ref.Value, + } + var err error + ref.Ref, err = rr.copyToComponents(ref.Ref, component) + return err } -// Implements reflectwalk.StructWalker. We work on whole structs so there is -// nothing to do here. -func (rr *refResolver) StructField(sf reflect.StructField, v reflect.Value) error { - return nil +func (rr *refResolver) ProcessResponseRef(ref *openapi3.ResponseRef) error { + if ref.Ref == "" { + return nil + } + component := &openapi3.ResponseRef{ + Value: ref.Value, + } + var err error + ref.Ref, err = rr.copyToComponents(ref.Ref, component) + return err } -func (rr *refResolver) deref(ref string, value reflect.Value) error { +func (rr *refResolver) ProcessSchemaRef(ref *openapi3.SchemaRef) error { + if ref.Ref == "" { + return nil + } + component := &openapi3.SchemaRef{ + Value: ref.Value, + } + var err error + ref.Ref, err = rr.copyToComponents(ref.Ref, component) + return err +} + +func (rr *refResolver) ProcessSecuritySchemeRef(ref *openapi3.SecuritySchemeRef) error { + if ref.Ref == "" { + return nil + } + component := &openapi3.SecuritySchemeRef{ + Value: ref.Value, + } + var err error + ref.Ref, err = rr.copyToComponents(ref.Ref, component) + return err +} + +// ResolveRefs recursively finds all ref objects in the current documents paths +// and makes sure they are valid by copying the referenced component to the +// documents components section. +// +// WARNING: this will mutate references so if references are shared between +// documents make sure that any other documents are serialised before resolving +// refs. This method only ensures the current document is correct. +func (rr *refResolver) ResolveRefs(doc *openapi3.T) error { + // Refs use a full path eg #/components/schemas/..., to avoid having a + // special case at the top level we pass the entire document and trust the + // refs to not reference parts of the document they shouldn't. + rr.doc = doc + return openapiwalker.ProcessRefs(doc, rr) +} + +func (rr *refResolver) deref(ref string, value reflect.Value) (string, error) { path := strings.Split(ref, "/") if path[0] != "#" { // All refs should have been resolved to the local document when // loading so if we hit this case then we have not loaded the document // correctly. - return fmt.Errorf("external ref %s", ref) + return "", fmt.Errorf("external ref %s", ref) } + field := reflect.ValueOf(rr.doc) - // Need to forward declare err so field is not shadowed in the loop - var err error - for _, segment := range path[1:] { - // Maps are a special case since the key also needs to be created. - if field.Kind() == reflect.Map { - newValue := reflect.New(field.Type().Elem().Elem()) - field.SetMapIndex(reflect.ValueOf(segment), newValue) - field = newValue.Elem() - continue + newRef, err := deref(path[1:], field, value, rr.renames) + if err != nil { + return "", err + } + slices.Reverse(newRef) + newRefStr := fmt.Sprintf("#/%s", strings.Join(newRef, "/")) + return newRefStr, nil +} + +func deref(path []string, field, value reflect.Value, renames map[string]string) ([]string, error) { + if len(path) == 0 { + field.Set(value.Elem()) + return []string{}, nil + } + + newName := path[0] + nextField, err := getField(newName, field) + if err != nil { + return nil, fmt.Errorf("invalid ref: %w", err) + } + + // Lookup if we already have a component in the same document with the same + // name, if they conflict then we need to rename the current component + if len(path) == 1 { + // Name might have changed on previous documents but previous + // collisions are no longer present. Always start from the original + // name to make sure we aren't leaving unessisary gaps. + originalName, ok := renames[newName] + if ok { + newName = originalName + nextField, err = getField(newName, field) + if err != nil { + return nil, fmt.Errorf("invalid ref: %w", err) + } } - // else we assume we are working on a struct - field, err = getField(segment, field) - if err != nil { - return fmt.Errorf("invalid ref %s: %w", ref, err) + suffix := 0 + prevName := newName + // If the component is the same as the one we have already then it + // isn't a problem, we can merge them. + for !isZero(nextField) && !vervet.ComponentsEqual(nextField.Interface(), value.Interface()) { + newName = fmt.Sprintf("%s~%d", prevName, suffix) + nextField, err = getField(newName, field) + if err != nil { + return nil, fmt.Errorf("renaming ref: %w", err) + } + suffix += 1 } + } - // A lot of the openapi3.T fields are pointers so if this is the first - // time we have encountered an object of this type we need to create - // the container. - if field.Kind() == reflect.Map && field.IsZero() { - newValue := reflect.MakeMap(field.Type()) - field.Set(newValue) - } else if field.IsNil() { - newValue := reflect.New(field.Type().Elem()) - field.Set(newValue) + // If the container for the next layer doesn't exist then we have to create + // it. + if isZero(nextField) { + if field.Kind() == reflect.Map { + nextField = reflect.New(field.Type().Elem().Elem()) + field.SetMapIndex(reflect.ValueOf(newName), nextField) + } else { + var newValue reflect.Value + if nextField.Kind() == reflect.Map { + newValue = reflect.MakeMap(nextField.Type()) + } else { + newValue = reflect.New(nextField.Type().Elem()) + } + nextField.Set(newValue) } } - field.Set(value.Elem()) - return nil + if field.Kind() == reflect.Map { + nextField = nextField.Elem() + } + + newRef, err := deref(path[1:], nextField, value, renames) + return append(newRef, newName), err +} + +func isZero(field reflect.Value) bool { + if !field.IsValid() { + return true + } + if field.Kind() == reflect.Pointer { + return field.IsNil() + } + return field.IsZero() } func getField(tag string, object reflect.Value) (reflect.Value, error) { + if object.Kind() == reflect.Map { + fieldName := reflect.ValueOf(tag) + return object.MapIndex(fieldName), nil + } + reflectedObject := object.Type().Elem() if reflectedObject.Kind() != reflect.Struct { return reflect.Value{}, fmt.Errorf("object is not a struct") diff --git a/internal/simplebuild/refs_test.go b/internal/simplebuild/refs_test.go index bd41d69e..e150d7b8 100644 --- a/internal/simplebuild/refs_test.go +++ b/internal/simplebuild/refs_test.go @@ -1,6 +1,7 @@ package simplebuild_test import ( + "fmt" "testing" qt "github.com/frankban/quicktest" @@ -24,110 +25,122 @@ func TestResolveRefs(t *testing.T) { Paths: openapi3.NewPaths(openapi3.WithPath("/foo", &path)), } - rr := simplebuild.NewRefResolver(&doc) - err := rr.Resolve(path) + rr := simplebuild.NewRefResolver() + err := rr.ResolveRefs(&doc) c.Assert(err, qt.IsNil) c.Assert(doc.Components.Parameters["foo"].Value, qt.Equals, param) }) - c.Run("ignores refs on other parts of the doc", func(c *qt.C) { - param := &openapi3.Parameter{} - pathA := openapi3.PathItem{ + c.Run("recursively resolves components", func(c *qt.C) { + schema := &openapi3.Schema{} + param := &openapi3.Parameter{ + Schema: &openapi3.SchemaRef{ + Ref: "#/components/schemas/foo", + Value: schema, + }, + } + path := openapi3.PathItem{ Parameters: []*openapi3.ParameterRef{{ Ref: "#/components/parameters/foo", Value: param, }}, } - pathB := openapi3.PathItem{ + doc := openapi3.T{ + Paths: openapi3.NewPaths(openapi3.WithPath("/foo", &path)), + } + + rr := simplebuild.NewRefResolver() + err := rr.ResolveRefs(&doc) + c.Assert(err, qt.IsNil) + + c.Assert(doc.Components.Parameters["foo"].Value, qt.Equals, param) + c.Assert(doc.Components.Schemas["foo"].Value, qt.Equals, schema) + }) + + c.Run("ignores ref objects with no ref value", func(c *qt.C) { + param := &openapi3.Parameter{} + path := openapi3.PathItem{ Parameters: []*openapi3.ParameterRef{{ - Ref: "#/components/parameters/bar", Value: param, }}, } doc := openapi3.T{ - Paths: openapi3.NewPaths(openapi3.WithPath("/foo", &pathA), openapi3.WithPath("/bar", &pathB)), + Components: &openapi3.Components{}, + Paths: openapi3.NewPaths(openapi3.WithPath("/foo", &path)), } - rr := simplebuild.NewRefResolver(&doc) - err := rr.Resolve(pathA) + rr := simplebuild.NewRefResolver() + err := rr.ResolveRefs(&doc) c.Assert(err, qt.IsNil) - c.Assert(doc.Components.Parameters["bar"], qt.IsNil) + c.Assert(doc.Components.Parameters["foo"], qt.IsNil) }) - c.Run("merges refs from successive calls", func(c *qt.C) { - paramA := &openapi3.Parameter{} + c.Run("conflicting components get renamed", func(c *qt.C) { + paramA := &openapi3.Parameter{ + Name: "fooname", + } pathA := openapi3.PathItem{ Parameters: []*openapi3.ParameterRef{{ - Ref: "#/components/parameters/foo", + Ref: "#/components/parameters/fooo", Value: paramA, }}, } - paramB := &openapi3.Parameter{} + paramB := &openapi3.Parameter{ + Name: "barname", + } pathB := openapi3.PathItem{ Parameters: []*openapi3.ParameterRef{{ - Ref: "#/components/parameters/bar", + Ref: "#/components/parameters/fooo", Value: paramB, }}, } doc := openapi3.T{ - Paths: openapi3.NewPaths(openapi3.WithPath("/foo", &pathA), openapi3.WithPath("/bar", &pathB)), } - rr := simplebuild.NewRefResolver(&doc) - err := rr.Resolve(pathA) - c.Assert(err, qt.IsNil) - err = rr.Resolve(pathB) + rr := simplebuild.NewRefResolver() + err := rr.ResolveRefs(&doc) c.Assert(err, qt.IsNil) - c.Assert(doc.Components.Parameters["foo"].Value, qt.Equals, paramA) - c.Assert(doc.Components.Parameters["bar"].Value, qt.Equals, paramB) + c.Assert(doc.Paths.Value("/foo").Parameters[0].Ref, qt.Not(qt.Equals), doc.Paths.Value("/bar").Parameters[0].Ref) + c.Assert(doc.Components.Parameters, qt.HasLen, 2) }) - c.Run("recursively resolves components", func(c *qt.C) { - schema := &openapi3.Schema{} - param := &openapi3.Parameter{ - Schema: &openapi3.SchemaRef{ - Ref: "#/components/schemas/foo", - Value: schema, - }, + c.Run("comparable components get merged", func(c *qt.C) { + paramA := &openapi3.Parameter{ + Name: "fooname", } - path := openapi3.PathItem{ + pathA := openapi3.PathItem{ Parameters: []*openapi3.ParameterRef{{ - Ref: "#/components/parameters/foo", - Value: param, + Ref: "#/components/parameters/fooo", + Value: paramA, }}, } - doc := openapi3.T{ - Paths: openapi3.NewPaths(openapi3.WithPath("/foo", &path)), + paramB := &openapi3.Parameter{ + Name: "fooname", } - - rr := simplebuild.NewRefResolver(&doc) - err := rr.Resolve(path) - c.Assert(err, qt.IsNil) - - c.Assert(doc.Components.Parameters["foo"].Value, qt.Equals, param) - c.Assert(doc.Components.Schemas["foo"].Value, qt.Equals, schema) - }) - - c.Run("ignores ref objects with no ref value", func(c *qt.C) { - param := &openapi3.Parameter{} - path := openapi3.PathItem{ + pathB := openapi3.PathItem{ Parameters: []*openapi3.ParameterRef{{ - Value: param, + Ref: "#/components/parameters/fooo", + Value: paramB, }}, } doc := openapi3.T{ - Components: &openapi3.Components{}, - Paths: openapi3.NewPaths(openapi3.WithPath("/foo", &path)), + Paths: openapi3.NewPaths(openapi3.WithPath("/foo", &pathA), openapi3.WithPath("/bar", &pathB)), } - rr := simplebuild.NewRefResolver(&doc) - err := rr.Resolve(path) + rr := simplebuild.NewRefResolver() + err := rr.ResolveRefs(&doc) c.Assert(err, qt.IsNil) - c.Assert(doc.Components.Parameters["foo"], qt.IsNil) + out, _ := doc.MarshalJSON() + fmt.Println() + fmt.Println(string(out)) + fmt.Println() + + c.Assert(doc.Paths.Value("/foo").Parameters[0].Ref, qt.Equals, doc.Paths.Value("/bar").Parameters[0].Ref) + c.Assert(doc.Components.Parameters, qt.HasLen, 1) }) } diff --git a/merge.go b/merge.go index 536e504d..df203417 100644 --- a/merge.go +++ b/merge.go @@ -142,7 +142,7 @@ func mergeComponents(dst, src *openapi3.T, replace bool) error { func mergeMap[T any](dst, src map[string]T, replace bool) error { for k, v := range src { existing, exists := dst[k] - if exists && !replace && !componentsEqual(v, existing) { + if exists && !replace && !ComponentsEqual(v, existing) { return errors.New("conflicting component: " + k) } dst[k] = v diff --git a/ref_alias_resolver.go b/ref_alias_resolver.go index ff7b27b2..5983dd56 100644 --- a/ref_alias_resolver.go +++ b/ref_alias_resolver.go @@ -15,40 +15,49 @@ type refAliasResolver struct { refAliases map[string]string } -func (l *refAliasResolver) ProcessCallbackRef(ref *openapi3.CallbackRef) { +func (l *refAliasResolver) ProcessCallbackRef(ref *openapi3.CallbackRef) error { ref.Ref = l.resolveRefAlias(ref.Ref) + return nil } -func (l *refAliasResolver) ProcessExampleRef(ref *openapi3.ExampleRef) { +func (l *refAliasResolver) ProcessExampleRef(ref *openapi3.ExampleRef) error { ref.Ref = l.resolveRefAlias(ref.Ref) + return nil } -func (l *refAliasResolver) ProcessHeaderRef(ref *openapi3.HeaderRef) { +func (l *refAliasResolver) ProcessHeaderRef(ref *openapi3.HeaderRef) error { ref.Ref = l.resolveRefAlias(ref.Ref) + return nil } -func (l *refAliasResolver) ProcessLinkRef(ref *openapi3.LinkRef) { +func (l *refAliasResolver) ProcessLinkRef(ref *openapi3.LinkRef) error { ref.Ref = l.resolveRefAlias(ref.Ref) + return nil } -func (l *refAliasResolver) ProcessParameterRef(ref *openapi3.ParameterRef) { +func (l *refAliasResolver) ProcessParameterRef(ref *openapi3.ParameterRef) error { ref.Ref = l.resolveRefAlias(ref.Ref) + return nil } -func (l *refAliasResolver) ProcessRequestBodyRef(ref *openapi3.RequestBodyRef) { +func (l *refAliasResolver) ProcessRequestBodyRef(ref *openapi3.RequestBodyRef) error { ref.Ref = l.resolveRefAlias(ref.Ref) + return nil } -func (l *refAliasResolver) ProcessResponseRef(ref *openapi3.ResponseRef) { +func (l *refAliasResolver) ProcessResponseRef(ref *openapi3.ResponseRef) error { ref.Ref = l.resolveRefAlias(ref.Ref) + return nil } -func (l *refAliasResolver) ProcessSchemaRef(ref *openapi3.SchemaRef) { +func (l *refAliasResolver) ProcessSchemaRef(ref *openapi3.SchemaRef) error { ref.Ref = l.resolveRefAlias(ref.Ref) + return nil } -func (l *refAliasResolver) ProcessSecuritySchemeRef(ref *openapi3.SecuritySchemeRef) { +func (l *refAliasResolver) ProcessSecuritySchemeRef(ref *openapi3.SecuritySchemeRef) error { ref.Ref = l.resolveRefAlias(ref.Ref) + return nil } // newRefAliasResolver returns a new refAliasResolver. @@ -86,6 +95,6 @@ func (l *refAliasResolver) resolveRefAlias(ref string) string { } // resolve rewrites all references in the OpenAPI document to local references. -func (l *refAliasResolver) resolve() { - openapiwalker.ProcessRefs(l.doc, l) +func (l *refAliasResolver) resolve() error { + return openapiwalker.ProcessRefs(l.doc, l) }