diff --git a/README.md b/README.md index 7cc75e3..1341e7c 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ Easy, fast and type-safe dependency injection for Go. * [Installation](#installation) * [Building the Container](#building-the-container) * [Configuring Services](#configuring-services) + + [arguments](#arguments) + [error](#error) + [import](#import) + [interface](#interface) @@ -14,6 +15,9 @@ Easy, fast and type-safe dependency injection for Go. + [type](#type) * [Using Services](#using-services) * [Unit Testing](#unit-testing) + * [Practical Examples](#practical-examples) + + [Mocking the Clock](#mocking-the-clock) + + [Mocking Runtime Dependencies](#mocking-runtime-dependencies) ## Installation @@ -68,6 +72,14 @@ References to other services and variables will be substituted automatically: - `@{SendEmail}` will inject the service named `SendEmail`. - `${DB_PASS}` will inject the environment variable `DB_PASS`. +### arguments + +If `arguments` is provided the service will be turned into a `func` so it can be +used as a factory. + +There is a full example in +[Mocking Runtime Dependencies](#mocking-runtime-dependencies). + ### error If `returns` provides two arguments (where the second one is the error) you must @@ -180,7 +192,7 @@ func main() { should create a new container: ```go -container := &Container{} +container := NewContainer() ``` Unit tests can make any modifications to the new container, including overriding @@ -192,7 +204,7 @@ func TestCustomerWelcome_Welcome(t *testing.T) { emailer.On("Send", "bob@smith.com", "Welcome", "Hi, Bob!").Return(nil) - container := &Container{} + container := NewContainer() container.SendEmail = emailer welcomer := container.GetCustomerWelcome() @@ -201,3 +213,146 @@ func TestCustomerWelcome_Welcome(t *testing.T) { emailer.AssertExpectations(t) } ``` + +## Practical Examples + +### Mocking the Clock + +Code that relies on time needs to be deterministic to be testable. Extracting +the clock as a service allows the whole time environment to be predictable for +all services. It also has the added benefit that `Sleep()` is free when running +unit tests. + +Here is a service, `WhatsTheTime`, that needs to use the current time: + +```yml +services: + Clock: + interface: github.com/jonboulle/clockwork.Clock + returns: clockwork.NewRealClock() + + WhatsTheTime: + type: '*WhatsTheTime' + properties: + clock: '@{Clock}' +``` + +`WhatsTheTime` can now use this clock the same way you would use the `time` +package: + +```go +import ( + "github.com/jonboulle/clockwork" + "time" +) + +type WhatsTheTime struct { + clock clockwork.Clock +} + +func (t *WhatsTheTime) InRFC1123() string { + return t.clock.Now().Format(time.RFC1123) +} +``` + +The unit test can substitute a fake clock for all services: + +```go +func TestWhatsTheTime_InRFC1123(t *testing.T) { + container := NewContainer() + container.Clock = clockwork.NewFakeClock() + + actual := container.GetWhatsTheTime().InRFC1123() + assert.Equal(t, "Wed, 04 Apr 1984 00:00:00 UTC", actual) +} +``` + +### Mocking Runtime Dependencies + +One situation that is tricky to write tests for is when you have the +instantiation inside a service because it needs some runtime state. + +Let's say you have a HTTP client that signs a request before sending it. The +signer can only be instantiated with the request, so we can't use traditional +injection: + +```go +type HTTPSignerClient struct{} + +func (c *HTTPSignerClient) Do(req *http.Request) (*http.Response, error) { + signer := NewSigner(req) + req.Headers.Set("Authorization", signer.Auth()) + + return http.DefaultClient.Do(req) +} +``` + +The `Signer` is not deterministic because it relies on the time: + +```go +type Signer struct { + req *http.Request +} + +func NewSigner(req *http.Request) *Signer { + return &Signer{req: req} +} + +// Produces something like "Mon Jan 2 15:04:05 2006 POST" +func (signer *Signer) Auth() string { + return time.Now().Format(time.ANSIC) + " " + signer.req.Method +} +``` + +Unlike mocking the clock (as in the previous tutorial) this time we need to keep +the logic of the signer, but verify the URL path sent to the signer. Of course, +we could manipulate or entirely replace the signer as well. + +Services can have `arguments` which turns them into factories. For example: + +```yml +services: + Signer: + type: '*Signer' + scope: prototype # Create a new Signer each time + arguments: # Define the dependencies at runtime. + req: '*http.Request' + returns: NewSigner(req) # Setup code can reference the runtime dependencies. + + HTTPSignerClient: + type: '*HTTPSignerClient' + properties: + CreateSigner: '@{Signer}' # Looks like a regular service, right? +``` + +Dingo has transformed the service into a factory, using a function: + +```go +type HTTPSignerClient struct { + CreateSigner func(req *http.Request) *Signer +} + +func (c *HTTPSignerClient) Do(req *http.Request) (*http.Response, error) { + signer := c.CreateSigner(req) + req.Headers.Set("Authorization", signer.Auth()) + + return http.DefaultClient.Do(req) +} +``` + +Under test we can control this factory like any other service: + +```go +func TestHTTPSignerClient_Do(t *testing.T) { + container := NewContainer() + container.Signer = func(req *http.Request) *Signer { + assert.Equals(t, req.URL.Path, "/foo") + + return NewSigner(req) + } + + client := container.GetHTTPSignerClient() + _, err := client.Do(http.NewRequest("GET", "/foo", nil)) + assert.NoError(t, err) +} +``` diff --git a/arguments.go b/arguments.go new file mode 100644 index 0000000..58b236a --- /dev/null +++ b/arguments.go @@ -0,0 +1,28 @@ +package main + +import ( + "fmt" + "sort" +) + +type Arguments map[string]Type + +// Names returns all of the argument names sorted. +func (args Arguments) Names() (names []string) { + for arg := range args { + names = append(names, arg) + } + + sort.Strings(names) + + return +} + +func (args Arguments) GoArguments() (ss []string) { + for _, argName := range args.Names() { + ss = append(ss, fmt.Sprintf("%s %s", argName, + args[argName].LocalEntityType())) + } + + return +} diff --git a/arguments_test.go b/arguments_test.go new file mode 100644 index 0000000..dafdbbd --- /dev/null +++ b/arguments_test.go @@ -0,0 +1,54 @@ +package main + +import ( + "github.com/elliotchance/testify-stats/assert" + "testing" +) + +var argumentTests = map[string]struct { + Arguments Arguments + Names []string + GoArguments []string +}{ + "Nil": { + Arguments: nil, + Names: nil, + GoArguments: nil, + }, + "Empty": { + Arguments: map[string]Type{}, + Names: nil, + GoArguments: nil, + }, + "One": { + Arguments: map[string]Type{"foo": "int"}, + Names: []string{"foo"}, + GoArguments: []string{"foo int"}, + }, + "ArgumentsAlwaysSortedByName": { + Arguments: map[string]Type{"foo": "int", "bar": "*float64"}, + Names: []string{"bar", "foo"}, + GoArguments: []string{"bar *float64", "foo int"}, + }, + "RemovePackageName": { + Arguments: map[string]Type{"req": "*net/http.Request"}, + Names: []string{"req"}, + GoArguments: []string{"req *http.Request"}, + }, +} + +func TestArguments_Names(t *testing.T) { + for testName, test := range argumentTests { + t.Run(testName, func(t *testing.T) { + assert.Equal(t, test.Names, test.Arguments.Names()) + }) + } +} + +func TestArguments_GoArguments(t *testing.T) { + for testName, test := range argumentTests { + t.Run(testName, func(t *testing.T) { + assert.Equal(t, test.GoArguments, test.Arguments.GoArguments()) + }) + } +} diff --git a/dingotest/dingo.go b/dingotest/dingo.go index e01cdea..f44b317 100644 --- a/dingotest/dingo.go +++ b/dingotest/dingo.go @@ -2,25 +2,52 @@ package dingotest import ( go_sub_pkg "github.com/elliotchance/dingo/dingotest/go-sub-pkg" + "github.com/jonboulle/clockwork" + "net/http" "os" - time "time" + "time" ) type Container struct { - AFunc func(int, int) (bool, bool) - CustomerWelcome *CustomerWelcome - OtherPkg *go_sub_pkg.Person - OtherPkg2 go_sub_pkg.Greeter - OtherPkg3 *go_sub_pkg.Person - SendEmail EmailSender - SendEmailError *SendEmail - SomeEnv *string - WithEnv1 *SendEmail - WithEnv2 *SendEmail + AFunc func(int, int) (bool, bool) + Clock clockwork.Clock + CustomerWelcome *CustomerWelcome + DependsOnTime func(ParsedTime time.Time) time.Time + HTTPSignerClient *HTTPSignerClient + Now func() time.Time + OtherPkg *go_sub_pkg.Person + OtherPkg2 go_sub_pkg.Greeter + OtherPkg3 *go_sub_pkg.Person + ParsedTime func(value string) time.Time + SendEmail EmailSender + SendEmailError *SendEmail + Signer func(req *http.Request) *Signer + SomeEnv *string + WhatsTheTime *WhatsTheTime + WithEnv1 *SendEmail + WithEnv2 *SendEmail } -var DefaultContainer = &Container{} +var DefaultContainer = NewContainer() +func NewContainer() *Container { + return &Container{DependsOnTime: func(ParsedTime time.Time) time.Time { + service := ParsedTime + return service + }, Now: func() time.Time { + service := time.Now() + return service + }, ParsedTime: func(value string) time.Time { + service, err := time.Parse(time.RFC822, value) + if err != nil { + return time.Now() + } + return service + }, Signer: func(req *http.Request) *Signer { + service := NewSigner(req) + return service + }} +} func (container *Container) GetAFunc() func(int, int) (bool, bool) { if container.AFunc == nil { service := func(a, b int) (c, d bool) { @@ -34,6 +61,13 @@ func (container *Container) GetAFunc() func(int, int) (bool, bool) { } return container.AFunc } +func (container *Container) GetClock() clockwork.Clock { + if container.Clock == nil { + service := clockwork.NewRealClock() + container.Clock = service + } + return container.Clock +} func (container *Container) GetCustomerWelcome() *CustomerWelcome { if container.CustomerWelcome == nil { service := NewCustomerWelcome(container.GetSendEmail()) @@ -41,9 +75,19 @@ func (container *Container) GetCustomerWelcome() *CustomerWelcome { } return container.CustomerWelcome } +func (container *Container) GetDependsOnTime() time.Time { + return container.DependsOnTime(container.GetParsedTime("13 Jan 06 15:04 MST")) +} +func (container *Container) GetHTTPSignerClient() *HTTPSignerClient { + if container.HTTPSignerClient == nil { + service := &HTTPSignerClient{} + service.CreateSigner = container.Signer + container.HTTPSignerClient = service + } + return container.HTTPSignerClient +} func (container *Container) GetNow() time.Time { - service := time.Now() - return service + return container.Now() } func (container *Container) GetOtherPkg() *go_sub_pkg.Person { if container.OtherPkg == nil { @@ -66,6 +110,9 @@ func (container *Container) GetOtherPkg3() go_sub_pkg.Person { } return *container.OtherPkg3 } +func (container *Container) GetParsedTime(value string) time.Time { + return container.ParsedTime(value) +} func (container *Container) GetSendEmail() EmailSender { if container.SendEmail == nil { service := &SendEmail{} @@ -84,6 +131,9 @@ func (container *Container) GetSendEmailError() *SendEmail { } return container.SendEmailError } +func (container *Container) GetSigner(req *http.Request) *Signer { + return container.Signer(req) +} func (container *Container) GetSomeEnv() string { if container.SomeEnv == nil { service := os.Getenv("ShouldBeSet") @@ -91,6 +141,14 @@ func (container *Container) GetSomeEnv() string { } return *container.SomeEnv } +func (container *Container) GetWhatsTheTime() *WhatsTheTime { + if container.WhatsTheTime == nil { + service := &WhatsTheTime{} + service.clock = container.GetClock() + container.WhatsTheTime = service + } + return container.WhatsTheTime +} func (container *Container) GetWithEnv1() SendEmail { if container.WithEnv1 == nil { service := SendEmail{} diff --git a/dingotest/dingo.yml b/dingotest/dingo.yml index 7df4645..b2aea7b 100644 --- a/dingotest/dingo.yml +++ b/dingotest/dingo.yml @@ -28,6 +28,19 @@ services: returns: time.Now() scope: prototype + ParsedTime: + type: time.Time + scope: prototype + arguments: + value: string + returns: time.Parse(time.RFC822, value) + error: return time.Now() + + DependsOnTime: + type: time.Time + scope: prototype + returns: '@{ParsedTime("13 Jan 06 15:04 MST")}' + OtherPkg: type: '*github.com/elliotchance/dingo/dingotest/go-sub-pkg.Person' @@ -53,3 +66,26 @@ services: return } + + Clock: + interface: github.com/jonboulle/clockwork.Clock + returns: clockwork.NewRealClock() + + WhatsTheTime: + type: '*WhatsTheTime' + properties: + clock: '@{Clock}' + + Signer: + type: '*Signer' + scope: prototype # Create a new Signer each time + arguments: # Define the dependencies at runtime. + req: '*http.Request' + import: + - net/http + returns: NewSigner(req) # Setup code can reference the runtime dependencies. + + HTTPSignerClient: + type: '*HTTPSignerClient' + properties: + CreateSigner: '@{Signer}' diff --git a/dingotest/dingo_test.go b/dingotest/dingo_test.go index 0472ca7..0af8afb 100644 --- a/dingotest/dingo_test.go +++ b/dingotest/dingo_test.go @@ -2,10 +2,12 @@ package dingotest_test import ( "github.com/elliotchance/dingo/dingotest" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "os" "testing" + "time" ) type FakeEmailSender struct { @@ -27,7 +29,7 @@ func TestCustomerWelcome_Welcome(t *testing.T) { emailer.On("Send", "bob@smith.com", "Welcome", "Hi, Bob!").Return(nil) - container := &dingotest.Container{} + container := dingotest.NewContainer() container.SendEmail = emailer welcomer := container.GetCustomerWelcome() @@ -43,7 +45,7 @@ func TestDefaultContainer(t *testing.T) { } func TestContainer_GetSendEmail(t *testing.T) { - container := &dingotest.Container{} + container := dingotest.NewContainer() assert.Nil(t, container.SendEmail) @@ -63,7 +65,7 @@ func TestContainer_GetSendEmail(t *testing.T) { } func TestContainer_GetCustomerWelcome(t *testing.T) { - container := &dingotest.Container{} + container := dingotest.NewContainer() assert.Nil(t, container.SendEmail) assert.Nil(t, container.CustomerWelcome) @@ -81,30 +83,71 @@ func TestContainer_GetCustomerWelcome(t *testing.T) { } func TestContainer_GetWithEnv1(t *testing.T) { - container := &dingotest.Container{} + container := dingotest.NewContainer() service := container.GetWithEnv1() assert.Equal(t, "qux", service.From) } func TestContainer_GetWithEnv2(t *testing.T) { - container := &dingotest.Container{} + container := dingotest.NewContainer() service := container.GetWithEnv2() assert.Equal(t, "foo-qux-bar", service.From) } func TestContainer_GetSomeEnv(t *testing.T) { - container := &dingotest.Container{} + container := dingotest.NewContainer() service := container.GetSomeEnv() assert.Equal(t, "qux", service) } -func TestContainer_Now(t *testing.T) { - container := &dingotest.Container{} +func TestContainer_GetNow(t *testing.T) { + container := dingotest.NewContainer() service1 := container.GetNow() service2 := container.GetNow() assert.NotEqual(t, service1, service2) } + +func TestContainer_GetParsedTime(t *testing.T) { + container := dingotest.NewContainer() + + t.Run("Success", func(t *testing.T) { + tm := container.GetParsedTime("02 Jan 06 15:04 MST") + assert.Equal(t, "2006-01-02 15:04:00 +0000 MST", tm.String()) + }) + + t.Run("Error", func(t *testing.T) { + tm := container.GetParsedTime("bad format") + assert.WithinDuration(t, tm, time.Now(), time.Second) + }) +} + +func TestContainer_GetDependsOnTime(t *testing.T) { + t.Run("Success", func(t *testing.T) { + container := dingotest.NewContainer() + + tm := container.GetDependsOnTime() + assert.Equal(t, "2006-01-13 15:04:00 +0000 MST", tm.String()) + }) + + t.Run("Override", func(t *testing.T) { + container := dingotest.NewContainer() + container.ParsedTime = func(value string) time.Time { + return time.Now() + } + + tm := container.GetDependsOnTime() + assert.WithinDuration(t, tm, time.Now(), time.Second) + }) +} + +func TestContainer_GetWhatsTheTime(t *testing.T) { + container := dingotest.NewContainer() + container.Clock = clockwork.NewFakeClock() + + actual := container.GetWhatsTheTime().InRFC1123() + assert.Equal(t, "Wed, 04 Apr 1984 00:00:00 UTC", actual) +} diff --git a/dingotest/http.go b/dingotest/http.go new file mode 100644 index 0000000..15c20c3 --- /dev/null +++ b/dingotest/http.go @@ -0,0 +1,30 @@ +package dingotest + +import ( + "net/http" + "time" +) + +type HTTPSignerClient struct { + CreateSigner func(req *http.Request) *Signer +} + +func (c *HTTPSignerClient) Do(req *http.Request) (*http.Response, error) { + signer := c.CreateSigner(req) + req.Header.Set("Authorization", signer.Auth()) + + return http.DefaultClient.Do(req) +} + +type Signer struct { + req *http.Request +} + +func NewSigner(req *http.Request) *Signer { + return &Signer{req: req} +} + +// Produces something like "Mon Jan 2 15:04:05 2006 POST" +func (signer *Signer) Auth() string { + return time.Now().Format(time.ANSIC) + " " + signer.req.Method +} diff --git a/dingotest/whats_the_time.go b/dingotest/whats_the_time.go new file mode 100644 index 0000000..5b2700c --- /dev/null +++ b/dingotest/whats_the_time.go @@ -0,0 +1,14 @@ +package dingotest + +import ( + "github.com/jonboulle/clockwork" + "time" +) + +type WhatsTheTime struct { + clock clockwork.Clock +} + +func (t *WhatsTheTime) InRFC1123() string { + return t.clock.Now().Format(time.RFC1123) +} diff --git a/expression.go b/expression.go new file mode 100644 index 0000000..12586fc --- /dev/null +++ b/expression.go @@ -0,0 +1,61 @@ +package main + +import ( + "fmt" + "github.com/elliotchance/pie/pie" + "go/ast" + "golang.org/x/tools/go/ast/astutil" + "regexp" + "strings" +) + +type Expression string + +func (e Expression) DependencyNames() (deps []string) { + for _, v := range regexp.MustCompile(`@{(.*?)}`).FindAllStringSubmatch(string(e), -1) { + parts := strings.Split(v[1], "(") + deps = append(deps, parts[0]) + } + + return pie.Strings(deps).Unique() +} + +func (e Expression) Dependencies() (deps []string) { + for _, v := range regexp.MustCompile(`@{(.*?)}`).FindAllStringSubmatch(string(e), -1) { + deps = append(deps, v[1]) + } + + return pie.Strings(deps).Unique() +} + +func (e Expression) performSubstitutions(services Services, fromArgs bool) string { + stmt := string(e) + + // Replace environment variables. + stmt = replaceAllStringSubmatchFunc( + regexp.MustCompile(`\${(.*?)}`), stmt, func(i []string) string { + astutil.AddImport(fset, file, "os") + + return fmt.Sprintf("os.Getenv(\"%s\")", i[1]) + }) + + // Replace service names. + stmt = replaceAllStringSubmatchFunc( + regexp.MustCompile(`@{(.*?)}`), stmt, func(i []string) string { + if fromArgs { + return strings.Split(i[1], "(")[0] + } + + if strings.Contains(i[1], "(") { + return fmt.Sprintf("container.Get%s", i[1]) + } + + if _, ok := services[i[1]].ContainerFieldType(services).(*ast.FuncType); ok { + return fmt.Sprintf("container.%s", i[1]) + } + + return fmt.Sprintf("container.Get%s()", i[1]) + }) + + return stmt +} diff --git a/file.go b/file.go new file mode 100644 index 0000000..7a76eb8 --- /dev/null +++ b/file.go @@ -0,0 +1,5 @@ +package main + +type File struct { + Services Services +} diff --git a/go.mod b/go.mod index 11785a9..6061f61 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,13 @@ module github.com/elliotchance/dingo go 1.12 require ( + github.com/elliotchance/pie v1.34.0 + github.com/elliotchance/testify-stats v1.0.0 + github.com/go-yaml/yaml v2.1.0+incompatible + github.com/jonboulle/clockwork v0.1.0 + github.com/kr/pretty v0.1.0 // indirect github.com/stretchr/testify v1.3.0 golang.org/x/tools v0.0.0-20190530001615-b97706b7f64d - gopkg.in/yaml.v2 v2.2.2 + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/yaml.v2 v2.2.2 // indirect ) diff --git a/go.sum b/go.sum index 782b7b2..e196d7d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,20 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/elliotchance/pie v1.34.0 h1:BZEckuK+QlmOwVs2fWMaX6O5nRKkQOAJ2lTO0BH78pQ= +github.com/elliotchance/pie v1.34.0/go.mod h1:W/nLuTGZ1dLKzRS0Z2g2N2evWzMenuDnBhk0s6Y9k54= +github.com/elliotchance/testify-stats v1.0.0 h1:CMcRBfQIB0WwT1+aY38MM4ShFqhPyP6jkHRytSvXLzI= +github.com/elliotchance/testify-stats v1.0.0/go.mod h1:Mc25k7L4E65uf6CfW+s/pY04XcoiqQBrfIRsWQcgweA= +github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o= +github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= +github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo= +github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= @@ -14,5 +29,7 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190530001615-b97706b7f64d h1:uVrwmEsn22e4befeQ6fUruML6nom3W7Z/KjHzNEJmAw= golang.org/x/tools v0.0.0-20190530001615-b97706b7f64d/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/main.go b/main.go index d37272b..b956268 100644 --- a/main.go +++ b/main.go @@ -2,28 +2,23 @@ package main import ( "fmt" + "github.com/go-yaml/yaml" "go/ast" "go/parser" "go/printer" "go/token" "golang.org/x/tools/go/ast/astutil" - "gopkg.in/yaml.v2" "io/ioutil" "log" "os" "path/filepath" "regexp" - "sort" "strings" ) var fset *token.FileSet var file *ast.File -type File struct { - Services map[string]Service -} - func replaceAllStringSubmatchFunc(re *regexp.Regexp, str string, repl func([]string) string) string { result := "" lastIndex := 0 @@ -41,24 +36,6 @@ func replaceAllStringSubmatchFunc(re *regexp.Regexp, str string, repl func([]str return result + str[lastIndex:] } -func resolveStatement(stmt string) string { - // Replace environment variables. - stmt = replaceAllStringSubmatchFunc( - regexp.MustCompile(`\${(.*?)}`), stmt, func(i []string) string { - astutil.AddImport(fset, file, "os") - - return fmt.Sprintf("os.Getenv(\"%s\")", i[1]) - }) - - // Replace service names. - stmt = replaceAllStringSubmatchFunc( - regexp.MustCompile(`@{(.*?)}`), stmt, func(i []string) string { - return fmt.Sprintf("container.Get%s()", i[1]) - }) - - return stmt -} - func main() { dingoYMLPath := "dingo.yml" outputFile := "dingo.go" @@ -81,65 +58,12 @@ func main() { log.Fatalln("parser:", err) } - // Sort services to the output file is neat and deterministic. - var serviceNames []string - for name := range all.Services { - serviceNames = append(serviceNames, name) - } - - sort.Strings(serviceNames) - - // type Container struct - var containerFields []*ast.Field - for _, serviceName := range serviceNames { - definition := all.Services[serviceName] - - switch definition.Scope { - case ScopeNotSet, ScopeContainer: - containerFields = append(containerFields, &ast.Field{ - Names: []*ast.Ident{ - {Name: serviceName}, - }, - Type: &ast.Ident{ - Name: definition.InterfaceOrLocalEntityPointerType(), - }, - }) - - case ScopePrototype: - // Do not create a property for this because it has to be created - // every time. - } - } - - file.Decls = append(file.Decls, &ast.GenDecl{ - Tok: token.TYPE, - Specs: []ast.Spec{ - &ast.TypeSpec{ - Name: &ast.Ident{Name: "Container"}, - Type: &ast.StructType{ - Fields: &ast.FieldList{ - List: containerFields, - }, - }, - }, - }, - }) - - file.Decls = append(file.Decls, &ast.GenDecl{ - Tok: token.VAR, - Specs: []ast.Spec{ - &ast.ValueSpec{ - Names: []*ast.Ident{ - {Name: "DefaultContainer"}, - }, - Values: []ast.Expr{ - &ast.Ident{Name: "&Container{}"}, - }, - }, - }, - }) + file.Decls = append(file.Decls, + all.Services.astContainerStruct(), + all.Services.astDefaultContainer(), + all.Services.astNewContainerFunc()) - for _, serviceName := range serviceNames { + for _, serviceName := range all.Services.ServiceNames() { definition := all.Services[serviceName] // Add imports for type, interface and explicit imports. @@ -147,144 +71,23 @@ func main() { astutil.AddNamedImport(fset, file, shortName, packageName) } - returnTypeParts := strings.Split( - regexp.MustCompile(`/v\d+\.`).ReplaceAllString(string(definition.Type), "."), "/") - returnType := returnTypeParts[len(returnTypeParts)-1] - if strings.HasPrefix(string(definition.Type), "*") && !strings.HasPrefix(returnType, "*") { - returnType = "*" + returnType - } - - var stmts, instantiation []ast.Stmt - serviceVariable := "container." + serviceName - serviceTempVariable := "service" - - // Instantiation - if definition.Returns == "" { - instantiation = []ast.Stmt{ - &ast.AssignStmt{ - Tok: token.DEFINE, - Lhs: []ast.Expr{&ast.Ident{Name: serviceTempVariable}}, - Rhs: []ast.Expr{ - &ast.CompositeLit{ - Type: &ast.Ident{Name: definition.Type.CreateLocalEntityType()}, - }, - }, - }, - } - } else { - lhs := []ast.Expr{&ast.Ident{Name: serviceTempVariable}} - - if definition.Error != "" { - lhs = append(lhs, &ast.Ident{Name: "err"}) - } - - instantiation = []ast.Stmt{ - &ast.AssignStmt{ - Tok: token.DEFINE, - Lhs: lhs, - Rhs: []ast.Expr{&ast.Ident{Name: resolveStatement(definition.Returns)}}, - }, - } - - if definition.Error != "" { - instantiation = append(instantiation, &ast.IfStmt{ - Cond: &ast.Ident{Name: "err != nil"}, - Body: &ast.BlockStmt{ - List: []ast.Stmt{ - &ast.ExprStmt{ - X: &ast.Ident{Name: definition.Error}, - }, - }, - }, - }) - } - } - - // Properties - for _, property := range definition.SortedProperties() { - instantiation = append(instantiation, &ast.AssignStmt{ - Tok: token.ASSIGN, - Lhs: []ast.Expr{&ast.Ident{Name: serviceTempVariable + "." + property.Name}}, - Rhs: []ast.Expr{&ast.Ident{Name: resolveStatement(property.Value)}}, - }) - } - - // Scope - switch definition.Scope { - case ScopeNotSet, ScopeContainer: - if definition.Type.IsPointer() || definition.Interface != "" { - instantiation = append(instantiation, &ast.AssignStmt{ - Tok: token.ASSIGN, - Lhs: []ast.Expr{&ast.Ident{Name: serviceVariable}}, - Rhs: []ast.Expr{&ast.Ident{Name: serviceTempVariable}}, - }) - } else { - instantiation = append(instantiation, &ast.AssignStmt{ - Tok: token.ASSIGN, - Lhs: []ast.Expr{&ast.Ident{Name: serviceVariable}}, - Rhs: []ast.Expr{&ast.Ident{Name: "&" + serviceTempVariable}}, - }) - } - - stmts = append(stmts, &ast.IfStmt{ - Cond: &ast.Ident{Name: serviceVariable + " == nil"}, - Body: &ast.BlockStmt{ - List: instantiation, - }, - }) - - // Returns - if definition.Type.IsPointer() || definition.Interface != "" { - stmts = append(stmts, &ast.ReturnStmt{ - Results: []ast.Expr{ - &ast.Ident{Name: serviceVariable}, - }, - }) - } else { - stmts = append(stmts, &ast.ReturnStmt{ - Results: []ast.Expr{ - &ast.Ident{Name: "*" + serviceVariable}, - }, - }) - } - - case ScopePrototype: - stmts = append(stmts, instantiation...) - - // Returns - stmts = append(stmts, &ast.ReturnStmt{ - Results: []ast.Expr{ - &ast.Ident{Name: "service"}, - }, - }) - } - file.Decls = append(file.Decls, &ast.FuncDecl{ - Name: &ast.Ident{Name: "Get" + serviceName}, + Name: newIdent("Get" + serviceName), Recv: &ast.FieldList{ List: []*ast.Field{ { Names: []*ast.Ident{ - {Name: "container"}, + newIdent("container"), }, - Type: &ast.Ident{Name: "*Container"}, + Type: newIdent("*Container"), }, }, }, Type: &ast.FuncType{ - Results: &ast.FieldList{ - List: []*ast.Field{ - { - Type: &ast.Ident{ - Name: definition.InterfaceOrLocalEntityType(), - }, - }, - }, - }, - }, - Body: &ast.BlockStmt{ - List: stmts, + Params: definition.astArguments(), + Results: newFieldList(definition.InterfaceOrLocalEntityType(all.Services, false)), }, + Body: definition.astFunctionBody(all.Services, serviceName, serviceName), }) } diff --git a/property.go b/property.go index d2d4a1d..a9a32a9 100644 --- a/property.go +++ b/property.go @@ -1,5 +1,6 @@ package main type Property struct { - Name, Value string + Name string + Value Expression } diff --git a/service.go b/service.go index 350481d..0b109b4 100644 --- a/service.go +++ b/service.go @@ -2,7 +2,10 @@ package main import ( "fmt" + "go/ast" + "go/token" "sort" + "strings" ) const ( @@ -12,21 +15,50 @@ const ( ) type Service struct { + Arguments Arguments Error string Import []string Interface Type - Properties map[string]string - Returns string + Properties map[string]Expression + Returns Expression Scope string Type Type } -func (service *Service) InterfaceOrLocalEntityType() string { +func (service *Service) ContainerFieldType(services Services) ast.Expr { + scope := service.Scope + if scope == ScopeNotSet { + scope = ScopeContainer + } + + if scope == ScopeContainer && len(service.Arguments) == 0 { + return newIdent(service.InterfaceOrLocalEntityPointerType()) + } + + return service.astFunctionPrototype(services) +} + +func (service *Service) InterfaceOrLocalEntityType(services Services, recurse bool) string { + localEntityType := service.Type.LocalEntityType() if service.Interface != "" { - return service.Interface.LocalEntityType() + localEntityType = service.Interface.LocalEntityType() } - return service.Type.LocalEntityType() + if len(service.Arguments) > 0 && recurse { + var args []string + + for _, dep := range service.Returns.Dependencies() { + ty := services[dep].InterfaceOrLocalEntityType(services, false) + args = append(args, fmt.Sprintf("%s %s", dep, ty)) + } + + args = append(args, service.Arguments.GoArguments()...) + + return fmt.Sprintf("func(%v) %s", strings.Join(args, ", "), + localEntityType) + } + + return localEntityType } func (service *Service) InterfaceOrLocalEntityPointerType() string { @@ -89,3 +121,168 @@ func (service *Service) Validate() error { return nil } + +func (service *Service) astArguments() *ast.FieldList { + funcParams := &ast.FieldList{ + List: []*ast.Field{}, + } + + for arg, ty := range service.Arguments { + funcParams.List = append(funcParams.List, &ast.Field{ + Type: &ast.Ident{ + Name: string(arg + " " + ty.String()), + }, + }) + } + + return funcParams +} + +func (service *Service) astDependencyArguments(services Services) *ast.FieldList { + funcParams := &ast.FieldList{ + List: []*ast.Field{}, + } + + for _, dep := range service.Returns.DependencyNames() { + funcParams.List = append(funcParams.List, &ast.Field{ + Type: newIdent(dep + " " + services[dep].InterfaceOrLocalEntityType(services, false)), + }) + } + + return funcParams +} + +func (service *Service) astAllArguments(services Services) *ast.FieldList { + deps := service.astDependencyArguments(services) + args := service.astArguments() + + return &ast.FieldList{ + List: append(deps.List, args.List...), + } +} + +func (service *Service) astFunctionPrototype(services Services) *ast.FuncType { + ty := Type(service.InterfaceOrLocalEntityType(services, true)) + if ty.IsFunction() { + args, returns := ty.parseFunctionType() + + return &ast.FuncType{ + Params: newFieldList(args), + Results: newFieldList(returns...), + } + } + + return &ast.FuncType{ + Params: service.astAllArguments(services), + Results: newFieldList(string(ty)), + } +} + +func (service *Service) astFunctionBody(services Services, name, serviceName string) *ast.BlockStmt { + if name != "" && service.Scope == ScopePrototype { + var arguments []string + for _, dep := range service.Returns.Dependencies() { + arguments = append(arguments, fmt.Sprintf("container.Get%s", dep)) + } + arguments = append(arguments, service.Arguments.Names()...) + + return newBlock( + newReturn(newIdent("container." + serviceName + "(" + strings.Join(arguments, ", ") + ")")), + ) + } + + var stmts, instantiation []ast.Stmt + serviceVariable := "container." + name + serviceTempVariable := "service" + + // Instantiation + if service.Returns == "" { + instantiation = []ast.Stmt{ + &ast.AssignStmt{ + Tok: token.DEFINE, + Lhs: []ast.Expr{newIdent(serviceTempVariable)}, + Rhs: []ast.Expr{ + &ast.CompositeLit{ + Type: newIdent(service.Type.CreateLocalEntityType()), + }, + }, + }, + } + } else { + lhs := []ast.Expr{newIdent(serviceTempVariable)} + + if service.Error != "" { + lhs = append(lhs, newIdent("err")) + } + + instantiation = []ast.Stmt{ + &ast.AssignStmt{ + Tok: token.DEFINE, + Lhs: lhs, + Rhs: []ast.Expr{ + newIdent(service.Returns.performSubstitutions(services, name == "")), + }, + }, + } + + if service.Error != "" { + instantiation = append(instantiation, &ast.IfStmt{ + Cond: newIdent("err != nil"), + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.ExprStmt{ + X: newIdent(service.Error), + }, + }, + }, + }) + } + } + + // Properties + for _, property := range service.SortedProperties() { + instantiation = append(instantiation, &ast.AssignStmt{ + Tok: token.ASSIGN, + Lhs: []ast.Expr{&ast.Ident{Name: serviceTempVariable + "." + property.Name}}, + Rhs: []ast.Expr{&ast.Ident{Name: property.Value.performSubstitutions(services, name == "")}}, + }) + } + + // Scope + switch service.Scope { + case ScopeNotSet, ScopeContainer: + if service.Type.IsPointer() || service.Interface != "" { + instantiation = append(instantiation, &ast.AssignStmt{ + Tok: token.ASSIGN, + Lhs: []ast.Expr{&ast.Ident{Name: serviceVariable}}, + Rhs: []ast.Expr{&ast.Ident{Name: serviceTempVariable}}, + }) + } else { + instantiation = append(instantiation, &ast.AssignStmt{ + Tok: token.ASSIGN, + Lhs: []ast.Expr{&ast.Ident{Name: serviceVariable}}, + Rhs: []ast.Expr{&ast.Ident{Name: "&" + serviceTempVariable}}, + }) + } + + stmts = append(stmts, &ast.IfStmt{ + Cond: &ast.Ident{Name: serviceVariable + " == nil"}, + Body: &ast.BlockStmt{ + List: instantiation, + }, + }) + + // Returns + if service.Type.IsPointer() || service.Interface != "" { + stmts = append(stmts, newReturn(newIdent(serviceVariable))) + } else { + stmts = append(stmts, newReturn(newIdent("*"+serviceVariable))) + } + + case ScopePrototype: + stmts = append(stmts, instantiation...) + stmts = append(stmts, newReturn(newIdent("service"))) + } + + return newBlock(stmts...) +} diff --git a/service_test.go b/service_test.go index 5ed6b3d..99dd08b 100644 --- a/service_test.go +++ b/service_test.go @@ -3,6 +3,7 @@ package main import ( "errors" "github.com/stretchr/testify/assert" + "go/ast" "testing" ) @@ -41,3 +42,229 @@ func TestService_Validate(t *testing.T) { }) } } + +func TestService_ContainerFieldType(t *testing.T) { + for testName, test := range map[string]struct { + services Services + containerFieldType ast.Expr + }{ + "StructWithoutScope": { + services: Services{ + "A": { + Scope: ScopeNotSet, + Type: "*SendEmail", + }, + }, + + // SendEmail is already a pointer, so it will be nil until + // initialised. + containerFieldType: newIdent("*SendEmail"), + }, + "StructContainer": { + services: Services{ + "A": { + Scope: ScopeContainer, + Type: "SendEmail", + }, + }, + + // SendEmail is not a pointer, so we need to make it one so we know + // when it is initialised. + containerFieldType: newIdent("*SendEmail"), + }, + "StructPrototype": { + services: Services{ + "A": { + Scope: ScopePrototype, + Type: "*foo.Bar", + }, + }, + + // It must be wrapped in a function so it is created each time. + containerFieldType: &ast.FuncType{ + Params: newFieldList(), + Results: newFieldList("*foo.Bar"), + }, + }, + "StructContainerWithArguments": { + services: Services{ + "A": { + Scope: ScopeContainer, + Type: "SendEmail", + Arguments: Arguments{ + "foo": "int", + }, + }, + }, + + // SendEmail is not a pointer, so we need to make it one so we know + // when it is initialised. + containerFieldType: &ast.FuncType{ + Params: newFieldList("foo int"), + Results: newFieldList("SendEmail"), + }, + }, + "StructPrototypeWithArguments": { + services: Services{ + "A": { + Scope: ScopePrototype, + Type: "*foo.Bar", + Arguments: Arguments{ + "foo": "int", + "bar": "float64", + }, + }, + }, + + // It must be wrapped in a function so it is created each time. + containerFieldType: &ast.FuncType{ + Params: newFieldList("bar float64, foo int"), + Results: newFieldList("*foo.Bar"), + }, + }, + "StructPrototypeWithArgumentsAndDeps1": { + services: Services{ + "A": { + Scope: ScopePrototype, + Type: "*foo.Bar", + Returns: "@{B}", + Arguments: Arguments{ + "foo": "int", + "bar": "float64", + }, + }, + "B": { + Scope: ScopeContainer, + Type: "foo.Baz", + }, + }, + containerFieldType: &ast.FuncType{ + Params: newFieldList("B foo.Baz, bar float64, foo int"), + Results: newFieldList("*foo.Bar"), + }, + }, + "StructPrototypeWithArgumentsAndDeps2": { + services: Services{ + "A": { + Scope: ScopePrototype, + Type: "*foo.Bar", + Returns: "@{B}", + Arguments: Arguments{ + "foo": "int", + "bar": "float64", + }, + }, + "B": { + Scope: ScopePrototype, + Type: "foo.Baz", + }, + }, + + containerFieldType: &ast.FuncType{ + Params: newFieldList("B foo.Baz, bar float64, foo int"), + Results: newFieldList("*foo.Bar"), + }, + }, + "StructPrototypeWithArgumentsAndDeps3": { + services: Services{ + "A": { + Scope: ScopePrototype, + Type: "*foo.Bar", + Returns: "@{B}", + Arguments: Arguments{ + "foo": "int", + "bar": "float64", + }, + }, + "B": { + Scope: ScopePrototype, + Interface: "Bazer", + Arguments: Arguments{ + "baz": "time.Time", + }, + }, + }, + + containerFieldType: &ast.FuncType{ + Params: newFieldList("B Bazer, bar float64, foo int"), + Results: newFieldList("*foo.Bar"), + }, + }, + "InterfaceWithoutScope": { + services: Services{ + "A": { + Scope: ScopeNotSet, + Interface: "Emailer", + }, + }, + + // Interfaces can be nil, so no need to turn it into a pointer. + containerFieldType: newIdent("Emailer"), + }, + "InterfaceContainer": { + services: Services{ + "A": { + Scope: ScopeContainer, + Interface: "Emailer", + }, + }, + + // Interfaces can be nil, so no need to turn it into a pointer. + containerFieldType: newIdent("Emailer"), + }, + "InterfacePrototype": { + services: Services{ + "A": { + Scope: ScopePrototype, + Interface: "Emailer", + }, + }, + + // A func that returns the interface. + containerFieldType: &ast.FuncType{ + Params: newFieldList(), + Results: newFieldList("Emailer"), + }, + }, + "InterfaceContainerWithArguments": { + services: Services{ + "A": { + Scope: ScopeContainer, + Interface: "Emailer", + Arguments: Arguments{ + "foo": "int", + }, + }, + }, + + // Interfaces can be nil, so no need to turn it into a pointer. + containerFieldType: &ast.FuncType{ + Params: newFieldList("foo int"), + Results: newFieldList("Emailer"), + }, + }, + "InterfacePrototypeWithArguments": { + services: Services{ + "A": { + Scope: ScopePrototype, + Interface: "Emailer", + Arguments: Arguments{ + "foo": "int", + "bar": "float64", + }, + }, + }, + + // A func that returns the interface. + containerFieldType: &ast.FuncType{ + Params: newFieldList("bar float64, foo int"), + Results: newFieldList("Emailer"), + }, + }, + } { + t.Run(testName, func(t *testing.T) { + actual := test.services["A"].ContainerFieldType(test.services) + assert.Equal(t, test.containerFieldType, actual) + }) + } +} diff --git a/services.go b/services.go new file mode 100644 index 0000000..3c5f54c --- /dev/null +++ b/services.go @@ -0,0 +1,94 @@ +package main + +import ( + "go/ast" + "go/token" + "sort" +) + +type Services map[string]*Service + +func (services Services) ServiceNames() []string { + // Sort services to the output file is neat and deterministic. + var serviceNames []string + for name := range services { + serviceNames = append(serviceNames, name) + } + + sort.Strings(serviceNames) + + return serviceNames +} + +func (services Services) ServicesWithScope(scope string) Services { + ss := make(Services) + + for serviceName, service := range services { + if service.Scope == scope { + ss[serviceName] = service + } + } + + return ss +} + +// astContainer creates the Container struct. +func (services Services) astContainerStruct() *ast.GenDecl { + var containerFields []*ast.Field + for _, serviceName := range services.ServiceNames() { + service := services[serviceName] + + containerFields = append(containerFields, &ast.Field{ + Names: []*ast.Ident{ + {Name: serviceName}, + }, + Type: service.ContainerFieldType(services), + }) + } + + return &ast.GenDecl{ + Tok: token.TYPE, + Specs: []ast.Spec{ + &ast.TypeSpec{ + Name: newIdent("Container"), + Type: &ast.StructType{ + Fields: &ast.FieldList{ + List: containerFields, + }, + }, + }, + }, + } +} + +func (services Services) astNewContainerFunc() *ast.FuncDecl { + fields := make(map[string]ast.Expr) + + for _, serviceName := range services.ServicesWithScope(ScopePrototype).ServiceNames() { + service := services[serviceName] + fields[serviceName] = &ast.FuncLit{ + Type: service.astFunctionPrototype(services), + Body: service.astFunctionBody(services, "", serviceName), + } + } + + return newFunc("NewContainer", nil, []string{"*Container"}, newBlock( + newReturn(newCompositeLit("&Container", fields)), + )) +} + +func (services Services) astDefaultContainer() *ast.GenDecl { + return &ast.GenDecl{ + Tok: token.VAR, + Specs: []ast.Spec{ + &ast.ValueSpec{ + Names: []*ast.Ident{ + {Name: "DefaultContainer"}, + }, + Values: []ast.Expr{ + &ast.Ident{Name: "NewContainer()"}, + }, + }, + }, + } +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..8d73d78 --- /dev/null +++ b/util.go @@ -0,0 +1,75 @@ +package main + +import ( + "go/ast" +) + +func newIdent(name string) *ast.Ident { + return &ast.Ident{ + Name: name, + } +} + +func newFieldList(values ...string) *ast.FieldList { + fields := []*ast.Field{} + + for _, value := range values { + fields = append(fields, &ast.Field{ + Type: newIdent(value), + }) + } + + return &ast.FieldList{ + List: fields, + } +} + +func newFunc(name string, params []string, returns []string, body *ast.BlockStmt) *ast.FuncDecl { + return &ast.FuncDecl{ + Name: newIdent(name), + Type: &ast.FuncType{ + Params: newFieldList(params...), + Results: newFieldList(returns...), + }, + Body: body, + } +} + +func newReturn(expressions ...ast.Expr) *ast.ReturnStmt { + var results []ast.Expr + + for _, expr := range expressions { + results = append(results, expr) + } + + return &ast.ReturnStmt{ + Results: results, + } +} + +func newBlock(stmts ...ast.Stmt) *ast.BlockStmt { + return &ast.BlockStmt{ + List: stmts, + } +} + +func newCompositeLit(ty string, m map[string]ast.Expr) *ast.CompositeLit { + var exprs []ast.Expr + var keys []string + + for k := range m { + keys = append(keys, k) + } + + for _, k := range keys { + exprs = append(exprs, &ast.KeyValueExpr{ + Key: newIdent(k), + Value: m[k], + }) + } + + return &ast.CompositeLit{ + Type: newIdent(ty), + Elts: exprs, + } +}