diff --git a/.gitignore b/.gitignore index 9dd3b40..b4b5879 100644 --- a/.gitignore +++ b/.gitignore @@ -25,3 +25,6 @@ go.work # goreleaser dist/ + +# temporary files +tmp/ diff --git a/alias.go b/alias.go index 9642b5c..91ce7c3 100644 --- a/alias.go +++ b/alias.go @@ -24,10 +24,10 @@ import ( // Engine initializes external store client and template. type Engine = engine.Engine -var ( - // NewConfigFromFile returns a new Config from file. - NewConfigFromFile = config.NewFromFile +// Config is the configuration for this library. +type Config = config.Config +var ( // NewEngine returns a new Engine. NewEngine = engine.New ) diff --git a/go.mod b/go.mod index db9e4bc..fddb09b 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.29.1 github.com/aws/aws-sdk-go-v2/service/ssm v1.50.4 github.com/spf13/cobra v1.8.0 + golang.org/x/crypto v0.3.0 sigs.k8s.io/yaml v1.4.0 ) @@ -36,6 +37,5 @@ require ( github.com/shopspring/decimal v1.2.0 // indirect github.com/spf13/cast v1.3.1 // indirect github.com/spf13/pflag v1.0.5 // indirect - golang.org/x/crypto v0.3.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/internal/client/aws.go b/internal/client/aws.go index baedb1d..fed09a7 100644 --- a/internal/client/aws.go +++ b/internal/client/aws.go @@ -25,6 +25,7 @@ import ( kmsTypes "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types" "github.com/aws/aws-sdk-go-v2/service/ssm" ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "github.com/dwango/yashiro/internal/values" "github.com/dwango/yashiro/pkg/config" ) @@ -56,8 +57,8 @@ func newAwsClient(cfg *config.AwsConfig) (Client, error) { }, nil } -func (c awsClient) GetValues(ctx context.Context, ignoreNotFound bool) (Values, error) { - values := make(Values, len(c.parameterStoreValue)+len(c.secretsManagerValue)) +func (c awsClient) GetValues(ctx context.Context, ignoreNotFound bool) (values.Values, error) { + values := make(values.Values, len(c.parameterStoreValue)+len(c.secretsManagerValue)) for _, v := range c.parameterStoreValue { output, err := c.ssmClient.GetParameter(ctx, &ssm.GetParameterInput{ diff --git a/internal/client/aws_test.go b/internal/client/aws_test.go index 57a6c8b..6a101d6 100644 --- a/internal/client/aws_test.go +++ b/internal/client/aws_test.go @@ -25,6 +25,7 @@ import ( kmsTypes "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types" "github.com/aws/aws-sdk-go-v2/service/ssm" ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" + "github.com/dwango/yashiro/internal/values" "github.com/dwango/yashiro/pkg/config" ) @@ -109,7 +110,7 @@ func Test_awsClient_GetValues(t *testing.T) { name string fields fields args args - want Values + want values.Values wantErr bool }{ { @@ -128,7 +129,7 @@ func Test_awsClient_GetValues(t *testing.T) { ctx: context.Background(), ignoreNotFound: false, }, - want: Values{"ssmKey": "test", "kmsKey": "test"}, + want: values.Values{"ssmKey": "test", "kmsKey": "test"}, }, { name: "ok: json", @@ -156,7 +157,7 @@ func Test_awsClient_GetValues(t *testing.T) { ctx: context.Background(), ignoreNotFound: false, }, - want: Values{"ssmKey": map[string]any{"key": "value"}, "kmsKey": map[string]any{"key": "value"}}, + want: values.Values{"ssmKey": map[string]any{"key": "value"}, "kmsKey": map[string]any{"key": "value"}}, }, { name: "ok: ignore not found error", @@ -174,7 +175,7 @@ func Test_awsClient_GetValues(t *testing.T) { ctx: context.Background(), ignoreNotFound: true, }, - want: Values{}, + want: values.Values{}, }, { name: "error: return not found from ssm", diff --git a/internal/client/cache.go b/internal/client/cache.go new file mode 100644 index 0000000..b15cc07 --- /dev/null +++ b/internal/client/cache.go @@ -0,0 +1,60 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package client + +import ( + "context" + + "github.com/dwango/yashiro/internal/client/cache" + "github.com/dwango/yashiro/internal/values" +) + +type clientWithCache struct { + client Client + cache cache.Cache +} + +func newClientWithCache(client Client, cache cache.Cache) Client { + return &clientWithCache{ + client: client, + cache: cache, + } +} + +// GetValues implements Client. +func (c *clientWithCache) GetValues(ctx context.Context, ignoreNotFound bool) (values.Values, error) { + val, expired, err := c.cache.Load(ctx) + if err != nil { + return nil, err + } + + // if cache is empty, get values from external store. + if len(val) == 0 || expired { + val, err = c.client.GetValues(ctx, ignoreNotFound) + if err != nil { + return nil, err + } + } + + // save values to cache + if expired { + if err := c.cache.Save(ctx, val); err != nil { + return nil, err + } + } + + return val, nil +} diff --git a/internal/client/cache/cache.go b/internal/client/cache/cache.go new file mode 100644 index 0000000..3dfc2b3 --- /dev/null +++ b/internal/client/cache/cache.go @@ -0,0 +1,49 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cache + +import ( + "context" + "errors" + "fmt" + + "github.com/dwango/yashiro/internal/values" + "github.com/dwango/yashiro/pkg/config" +) + +var ( + ErrInvalidCacheType = errors.New("invalid cache type") +) + +type Cache interface { + // Load returns values from cache and whether or not cache is expired. If cache is empty, + // returned values is empty and expired=true. + Load(ctx context.Context) (values.Values, bool, error) + + // Save saves values to cache. + Save(ctx context.Context, val values.Values) error +} + +func New(cfg config.CacheConfig) (Cache, error) { + switch cfg.Type { + case config.CacheTypeUnspecified, config.CacheTypeMemory: + return newMemoryCache() + case config.CacheTypeFile: + return newFileCache(cfg.File) + default: + return nil, fmt.Errorf("%w: %s", ErrInvalidCacheType, cfg.Type) + } +} diff --git a/internal/client/cache/file.go b/internal/client/cache/file.go new file mode 100644 index 0000000..61cb315 --- /dev/null +++ b/internal/client/cache/file.go @@ -0,0 +1,237 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cache + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/json" + "errors" + "io" + "os" + "time" + + "github.com/dwango/yashiro/internal/values" + "github.com/dwango/yashiro/pkg/config" + "golang.org/x/crypto/bcrypt" +) + +const ( + cacheFileName = "values" + keyFileName = "key" + keyHashFileName = ".key_hash" +) + +var defaultCacheBasePath string + +type fileCache struct { + cacheBasePath string + cipherBlock cipher.Block + expired bool +} + +func newFileCache(cfg config.FileCacheConfig) (Cache, error) { + fc := &fileCache{ + cacheBasePath: defaultCacheBasePath, + } + + if len(cfg.CachePath) != 0 { + fc.cacheBasePath = cfg.CachePath + } + // create cache directory + if err := os.MkdirAll(fc.cacheBasePath, 0777); err != nil { + return nil, err + } + + // read or create key + key, err := fc.readOrCreateKey() + if err != nil { + return nil, err + } + + fc.cipherBlock, err = aes.NewCipher(key) + if err != nil { + return nil, err + } + + return fc, nil +} + +const ( + // 30 days + expiredDuration time.Duration = 30 * 24 * time.Hour +) + +// Load implements Cache. +func (f *fileCache) Load(_ context.Context) (values.Values, bool, error) { + fInfo, err := f.getFileStat(cacheFileName) + if err != nil { + f.expired = true + return nil, f.expired, nil + } + // check if cache is expired + f.expired = time.Since(fInfo.ModTime().Local()) > expiredDuration + + cacheCipherText, err := f.readFile(cacheFileName) + if err != nil { + return nil, false, err + } + + val, err := f.decryptCache(cacheCipherText) + if err != nil { + return nil, false, err + } + + return val, f.expired, nil +} + +// Save implements Cache. +func (f *fileCache) Save(_ context.Context, val values.Values) error { + if !f.expired { + return nil + } + + encryptedCache, err := f.encryptCache(val) + if err != nil { + return err + } + + if err := f.writeToFile(cacheFileName, encryptedCache); err != nil { + return err + } + + return nil +} + +func (f *fileCache) readOrCreateKey() ([]byte, error) { + var key []byte + // check key file exists + if _, err := f.getFileStat(keyFileName); err != nil { + key = make([]byte, 32) + + // create key file + if _, err := rand.Read(key); err != nil { + return nil, err + } + if err := f.writeToFile(keyFileName, key); err != nil { + return nil, err + } + + // hashing key + keyHash, err := bcrypt.GenerateFromPassword(key, 5) + if err != nil { + return nil, err + } + if err := f.writeToFile(keyHashFileName, keyHash); err != nil { + return nil, err + } + + return key, nil + } + + var err error + // read key file + key, err = f.readFile(keyFileName) + if err != nil { + return nil, err + } + + // read key hash file + keyHash, err := f.readFile(keyHashFileName) + if err != nil { + return nil, err + } + + // check key is not tampered + if err := bcrypt.CompareHashAndPassword(keyHash, key); err != nil { + return nil, err + } + + return key, nil +} + +func (f *fileCache) decryptCache(cacheCipherText []byte) (values.Values, error) { + if len(cacheCipherText) < aes.BlockSize { + return nil, errors.New("ciphertext too short") + } + + iv := cacheCipherText[:aes.BlockSize] + cacheCipherText = cacheCipherText[aes.BlockSize:] + + cachePlainText := make([]byte, len(cacheCipherText)) + stream := cipher.NewOFB(f.cipherBlock, iv) + stream.XORKeyStream(cachePlainText, cacheCipherText) + + values := make(values.Values) + if err := json.Unmarshal(cachePlainText, &values); err != nil { + return nil, err + } + + return values, nil +} + +func (f *fileCache) encryptCache(values values.Values) ([]byte, error) { + cacheJSON, err := json.Marshal(values) + if err != nil { + return nil, err + } + + cacheCipherText := make([]byte, aes.BlockSize+len(cacheJSON)) + iv := cacheCipherText[:aes.BlockSize] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + + stream := cipher.NewOFB(f.cipherBlock, iv) + stream.XORKeyStream(cacheCipherText[aes.BlockSize:], cacheJSON) + + return cacheCipherText, nil +} + +func (f fileCache) getFileStat(filename string) (os.FileInfo, error) { + return os.Stat(f.cacheBasePath + "/" + filename) +} + +func (f fileCache) readFile(filename string) ([]byte, error) { + return os.ReadFile(f.cacheBasePath + "/" + filename) +} + +func (f fileCache) writeToFile(filename string, data []byte) error { + file, err := os.Create(f.cacheBasePath + "/" + filename) + if err != nil { + return err + } + defer file.Close() + + if _, err := file.Write(data); err != nil { + return err + } + + return nil +} + +func init() { + const cachePath = "/yashiro" + + cacheDir, err := os.UserCacheDir() + if err != nil { + defaultCacheBasePath = "/tmp" + cachePath + "/cache" + return + } + defaultCacheBasePath = cacheDir + cachePath +} diff --git a/internal/client/cache/file_test.go b/internal/client/cache/file_test.go new file mode 100644 index 0000000..1e6e745 --- /dev/null +++ b/internal/client/cache/file_test.go @@ -0,0 +1,209 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cache + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "reflect" + "testing" + + "github.com/dwango/yashiro/internal/values" + "github.com/dwango/yashiro/pkg/config" +) + +func Test_newFileCache(t *testing.T) { + type args struct { + cfg config.FileCacheConfig + } + tests := []struct { + name string + args args + want Cache + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newFileCache(tt.args.cfg) + if (err != nil) != tt.wantErr { + t.Errorf("newFileCache() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("newFileCache() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_fileCache_SaveAndLoad(t *testing.T) { + block, _ := aes.NewCipher([]byte("0123456789abcdef0123456789abcdef")) + + type fields struct { + cacheBasePath string + cipherBlock cipher.Block + expired bool + } + type args struct { + in0 context.Context + val values.Values + } + tests := []struct { + name string + fields fields + args args + want values.Values + wantErr bool + }{ + { + name: "ok: save and load", + fields: fields{ + cacheBasePath: "testdata/save-and-load", + cipherBlock: block, + expired: true, + }, + args: args{ + in0: context.Background(), + val: values.Values{ + "key": "value", + }, + }, + want: values.Values{ + "key": "value", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &fileCache{ + cacheBasePath: tt.fields.cacheBasePath, + cipherBlock: tt.fields.cipherBlock, + expired: tt.fields.expired, + } + if err := f.Save(tt.args.in0, tt.args.val); (err != nil) != tt.wantErr { + t.Errorf("fileCache.Save() error = %v, wantErr %v", err, tt.wantErr) + } + got, _, err := f.Load(tt.args.in0) + if (err != nil) != tt.wantErr { + t.Errorf("fileCache.Load() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("fileCache.Load() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_fileCache_readOrCreateKey(t *testing.T) { + type fields struct { + cacheBasePath string + cipherBlock cipher.Block + expired bool + } + tests := []struct { + name string + fields fields + wantErr bool + }{ + { + // This test case is executed only once. Therefore, if you want to retest, delete the file + // before executing it again. + name: "ok: create key", + fields: fields{ + cacheBasePath: "testdata/read-or-create-key", + }, + wantErr: false, + }, + { + name: "ok: read key", + fields: fields{ + cacheBasePath: "testdata/read-or-create-key", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &fileCache{ + cacheBasePath: tt.fields.cacheBasePath, + cipherBlock: tt.fields.cipherBlock, + expired: tt.fields.expired, + } + if _, err := f.readOrCreateKey(); (err != nil) != tt.wantErr { + t.Errorf("fileCache.readOrCreateKey() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +func Test_fileCache_encryptAndDecryptCache(t *testing.T) { + block, _ := aes.NewCipher([]byte("0123456789abcdef0123456789abcdef")) + + type fields struct { + cacheBasePath string + cipherBlock cipher.Block + expired bool + } + type args struct { + values values.Values + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "ok: encrypt values", + fields: fields{ + cipherBlock: block, + }, + args: args{ + values: values.Values{ + "key": "value", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &fileCache{ + cacheBasePath: tt.fields.cacheBasePath, + cipherBlock: tt.fields.cipherBlock, + expired: tt.fields.expired, + } + gotEncrypt, err := f.encryptCache(tt.args.values) + if (err != nil) != tt.wantErr { + t.Errorf("fileCache.encryptCache() error = %v, wantErr %v", err, tt.wantErr) + return + } + gotDecrypt, err := f.decryptCache(gotEncrypt) + if (err != nil) != tt.wantErr { + t.Errorf("fileCache.decryptCache() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotDecrypt, tt.args.values) { + t.Errorf("fileCache.decryptCache() = %v, want %v", gotDecrypt, tt.args.values) + } + }) + } +} diff --git a/internal/client/cache/memory.go b/internal/client/cache/memory.go new file mode 100644 index 0000000..ff6b3a9 --- /dev/null +++ b/internal/client/cache/memory.go @@ -0,0 +1,50 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cache + +import ( + "context" + "maps" + + "github.com/dwango/yashiro/internal/values" +) + +type memoryCache struct { + values values.Values +} + +func newMemoryCache() (Cache, error) { + return &memoryCache{}, nil +} + +// Load implements Cache. +func (m memoryCache) Load(_ context.Context) (values.Values, bool, error) { + expired := false + if len(m.values) == 0 { + expired = true + } + + return m.values, expired, nil +} + +// Save implements Cache. +func (m *memoryCache) Save(_ context.Context, val values.Values) error { + newVal := make(values.Values, len(val)) + maps.Copy(newVal, val) + m.values = newVal + + return nil +} diff --git a/internal/client/cache/memory_test.go b/internal/client/cache/memory_test.go new file mode 100644 index 0000000..95fdd10 --- /dev/null +++ b/internal/client/cache/memory_test.go @@ -0,0 +1,149 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cache + +import ( + "context" + "reflect" + "testing" + + "github.com/dwango/yashiro/internal/values" +) + +func Test_newMemoryCache(t *testing.T) { + tests := []struct { + name string + want Cache + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := newMemoryCache() + if (err != nil) != tt.wantErr { + t.Errorf("newMemoryCache() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("newMemoryCache() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_memoryCache_Load(t *testing.T) { + type fields struct { + values values.Values + } + type args struct { + in0 context.Context + } + tests := []struct { + name string + fields fields + args args + want values.Values + want1 bool + wantErr bool + }{ + { + name: "ok: get values", + fields: fields{ + values: values.Values{ + "key": "value", + }, + }, + args: args{ + in0: context.Background(), + }, + want: values.Values{ + "key": "value", + }, + }, + { + name: "ok: no values(return expired=true)", + fields: fields{ + values: nil, + }, + args: args{ + in0: context.Background(), + }, + want: nil, + want1: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := memoryCache{ + values: tt.fields.values, + } + got, got1, err := m.Load(tt.args.in0) + if (err != nil) != tt.wantErr { + t.Errorf("memoryCache.Load() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("memoryCache.Load() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("memoryCache.Load() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func Test_memoryCache_Save(t *testing.T) { + type fields struct { + values values.Values + } + type args struct { + in0 context.Context + val values.Values + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "ok: save values", + fields: fields{ + values: values.Values{}, + }, + args: args{ + in0: context.Background(), + val: values.Values{ + "key": "value", + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &memoryCache{ + values: tt.fields.values, + } + if err := m.Save(tt.args.in0, tt.args.val); (err != nil) != tt.wantErr { + t.Errorf("memoryCache.Save() error = %v, wantErr %v", err, tt.wantErr) + } + if !reflect.DeepEqual(m.values, tt.args.val) { + t.Errorf("memoryCache.Save() got = %v, want %v", m.values, tt.args.val) + } + }) + } +} diff --git a/internal/client/cache/testdata/read-or-create-key/.key_hash b/internal/client/cache/testdata/read-or-create-key/.key_hash new file mode 100644 index 0000000..a23fbc4 --- /dev/null +++ b/internal/client/cache/testdata/read-or-create-key/.key_hash @@ -0,0 +1 @@ +$2a$05$mddC5bB9QF83.WYeYbbMLu2mkywrUIBs9Bi.FXkfCAcNLH8YN8UTW \ No newline at end of file diff --git a/internal/client/cache/testdata/read-or-create-key/key b/internal/client/cache/testdata/read-or-create-key/key new file mode 100644 index 0000000..59941e6 --- /dev/null +++ b/internal/client/cache/testdata/read-or-create-key/key @@ -0,0 +1 @@ +a-Céæ§ym¡Ëøgyl«ÇBÚ¸|I{—·è)žÇ \ No newline at end of file diff --git a/internal/client/cache/testdata/save-and-load/values b/internal/client/cache/testdata/save-and-load/values new file mode 100644 index 0000000..dfc9155 --- /dev/null +++ b/internal/client/cache/testdata/save-and-load/values @@ -0,0 +1 @@ +S!ª´? ´(‰1z¢Á IôDâÞ]Q*Z±y·HÏ \ No newline at end of file diff --git a/internal/client/cache_test.go b/internal/client/cache_test.go new file mode 100644 index 0000000..d30fc45 --- /dev/null +++ b/internal/client/cache_test.go @@ -0,0 +1,138 @@ +/** + * Copyright 2024 DWANGO Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package client + +import ( + "context" + "reflect" + "testing" + + "github.com/dwango/yashiro/internal/client/cache" + "github.com/dwango/yashiro/internal/values" +) + +func Test_newClientWithCache(t *testing.T) { + type args struct { + client Client + cache cache.Cache + } + tests := []struct { + name string + args args + want Client + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := newClientWithCache(tt.args.client, tt.args.cache); !reflect.DeepEqual(got, tt.want) { + t.Errorf("newClientWithCache() = %v, want %v", got, tt.want) + } + }) + } +} + +type mockClient func(ctx context.Context, ignoreNotFound bool) (values.Values, error) + +func (m mockClient) GetValues(ctx context.Context, ignoreNotFound bool) (values.Values, error) { + return m(ctx, ignoreNotFound) +} + +type mockCache func(ctx context.Context) (values.Values, bool, error) + +func (m mockCache) Load(ctx context.Context) (values.Values, bool, error) { + return m(ctx) +} +func (m mockCache) Save(ctx context.Context, val values.Values) error { + return nil +} + +func Test_clientWithCache_GetValues(t *testing.T) { + type fields struct { + client Client + cache cache.Cache + } + type args struct { + ctx context.Context + ignoreNotFound bool + } + tests := []struct { + name string + fields fields + args args + want values.Values + wantErr bool + }{ + { + name: "ok: get values from cache", + fields: fields{ + client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { + return values.Values{ + "key-client": "value-client", + }, nil + }), + cache: mockCache(func(ctx context.Context) (values.Values, bool, error) { + return values.Values{ + "key-cache": "value-cache", + }, false, nil + }), + }, + args: args{ + ctx: context.Background(), + ignoreNotFound: false, + }, + want: values.Values{ + "key-cache": "value-cache", + }, + }, + { + name: "ok: get values from client(no cache)", + fields: fields{ + client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { + return values.Values{ + "key-client": "value-client", + }, nil + }), + cache: mockCache(func(ctx context.Context) (values.Values, bool, error) { + return values.Values{}, true, nil + }), + }, + args: args{ + ctx: context.Background(), + ignoreNotFound: false, + }, + want: values.Values{ + "key-client": "value-client", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &clientWithCache{ + client: tt.fields.client, + cache: tt.fields.cache, + } + got, err := c.GetValues(tt.args.ctx, tt.args.ignoreNotFound) + if (err != nil) != tt.wantErr { + t.Errorf("clientWithCache.GetValues() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("clientWithCache.GetValues() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/client/client.go b/internal/client/client.go index 848cc93..652a837 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -20,26 +20,44 @@ import ( "context" "errors" + "github.com/dwango/yashiro/internal/client/cache" + "github.com/dwango/yashiro/internal/values" "github.com/dwango/yashiro/pkg/config" ) -// Defines errors +// Define errors var ( ErrNotfoundValueConfig = errors.New("not found value config") - ErrValueIsEmpty = errors.New("value is empty") - ErrInvalidJSON = errors.New("invalid json string") ) // Client is the external stores client. type Client interface { - GetValues(ctx context.Context, ignoreNotFound bool) (Values, error) + GetValues(ctx context.Context, ignoreNotFound bool) (values.Values, error) } // New returns a new Client. func New(cfg *config.Config) (Client, error) { + var client Client + var err error + if cfg.Aws != nil { - return newAwsClient(cfg.Aws) + client, err = newAwsClient(cfg.Aws) + } + if err != nil { + return nil, err + } + + if cfg.Global.EnableCache { + cache, err := cache.New(cfg.Global.Cache) + if err != nil { + return nil, err + } + client = newClientWithCache(client, cache) + } + + if client == nil { + return nil, ErrNotfoundValueConfig } - return nil, ErrNotfoundValueConfig + return client, nil } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 0e66feb..b03d56b 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -20,9 +20,15 @@ import ( "fmt" "strings" + "github.com/dwango/yashiro/pkg/config" "github.com/spf13/cobra" ) +var ( + configFile string + globalConfig = &config.Config{} +) + // New returns a new cobra.Command. func New() *cobra.Command { cmd := &cobra.Command{ @@ -32,6 +38,9 @@ func New() *cobra.Command { SilenceErrors: true, } + f := cmd.PersistentFlags() + f.StringVarP(&configFile, "config", "c", config.DefaultConfigFilename, "specify config file.") + cmd.AddCommand(newTemplateCommand()) cmd.AddCommand(newVersionCommand()) @@ -49,3 +58,14 @@ func checkArgsLength(argsReceived int, requiredArgs ...string) error { } return nil } + +// preLoadConfig is PreRunE function for cobra.Command. This function preloads config file. +func preLoadConfig(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + if err := globalConfig.LoadFromFile(ctx, configFile); err != nil { + return err + } + + return nil +} diff --git a/internal/cmd/template.go b/internal/cmd/template.go index c66fbaa..537cf71 100644 --- a/internal/cmd/template.go +++ b/internal/cmd/template.go @@ -38,7 +38,7 @@ const example = ` # specify single file. ` var textTypeValues = []string{ - string(engine.TextTypePlane), + string(engine.TextTypePlain), string(engine.TextTypeJSON), string(engine.TextTypeJSONArray), string(engine.TextTypeYAML), @@ -47,7 +47,6 @@ var textTypeValues = []string{ } func newTemplateCommand() *cobra.Command { - var configFile string var ignoreNotFound bool var textType string @@ -58,15 +57,12 @@ func newTemplateCommand() *cobra.Command { Args: func(_ *cobra.Command, args []string) error { return checkArgsLength(len(args), "template file") }, + PreRunE: preLoadConfig, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - cfg, err := config.NewFromFile(ctx, configFile) - if err != nil { - return err - } - - eng, err := engine.New(cfg, + globalConfig.Global.Cache.Type = config.CacheTypeFile + eng, err := engine.New(globalConfig, engine.IgnoreNotFound(ignoreNotFound), engine.TextType(engine.TextTypeOpt(textType)), ) if err != nil { @@ -83,9 +79,10 @@ func newTemplateCommand() *cobra.Command { } f := cmd.Flags() - f.StringVarP(&configFile, "config", "c", config.DefaultConfigFilename, "specify config file.") - f.StringVar(&textType, "text-type", string(engine.TextTypePlane), - fmt.Sprintf("specify text type after rendering. available values: %s", strings.Join(textTypeValues, ", ")), + f.StringVar(&globalConfig.Global.Cache.File.CachePath, "cache-dir", "", "specify the directory to save the cache files.") + f.BoolVar(&globalConfig.Global.EnableCache, "enable-cache", false, "enable the file base cache.") + f.StringVar(&textType, "text-type", string(engine.TextTypePlain), + fmt.Sprintf("specify the text type after rendering. available values: %s", strings.Join(textTypeValues, ", ")), ) f.BoolVar(&ignoreNotFound, "ignore-not-found", false, "ignore values are not found in the external store.") diff --git a/internal/client/values.go b/internal/values/values.go similarity index 89% rename from internal/client/values.go rename to internal/values/values.go index db1b988..4265adf 100644 --- a/internal/client/values.go +++ b/internal/values/values.go @@ -13,15 +13,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package client +package values import ( "encoding/json" + "errors" "fmt" "github.com/dwango/yashiro/pkg/config" ) +// Define errors +var ( + ErrValueIsEmpty = errors.New("value is empty") + ErrInvalidJSON = errors.New("invalid json string") +) + // Values are stored values from external stores. type Values map[string]any diff --git a/internal/client/values_test.go b/internal/values/values_test.go similarity index 99% rename from internal/client/values_test.go rename to internal/values/values_test.go index 815778f..861b78e 100644 --- a/internal/client/values_test.go +++ b/internal/values/values_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package client +package values import ( "errors" diff --git a/pkg/config/config.go b/pkg/config/config.go index 4d027bc..97d442a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -30,7 +30,30 @@ const DefaultConfigFilename = "./yashiro.yaml" // Config is Yashiro configuration. type Config struct { - Aws *AwsConfig `json:"aws,omitempty"` + Global GlobalConfig `json:"global,omitempty"` + Aws *AwsConfig `json:"aws,omitempty"` +} + +type GlobalConfig struct { + EnableCache bool `json:"enable_cache"` + Cache CacheConfig `json:"cache,omitempty"` +} + +type CacheType string + +const ( + CacheTypeUnspecified CacheType = "" + CacheTypeMemory CacheType = "memory" // default + CacheTypeFile CacheType = "file" +) + +type CacheConfig struct { + Type CacheType `json:"type,omitempty"` + File FileCacheConfig `json:"file,omitempty"` +} + +type FileCacheConfig struct { + CachePath string `json:"cache_path,omitempty"` } // AwsConfig is AWS service configuration. @@ -54,28 +77,27 @@ type AwsParameterStoreValueConfig struct { Decryption *bool `json:"decryption,omitempty"` } -// NewFromFile returns a new Config according to a file. The configuration file is assumed to +// LoadFromFile sets Config values according to a file. The configuration file is assumed to // be in YAML format. -func NewFromFile(ctx context.Context, filename string) (*Config, error) { +func (c *Config) LoadFromFile(ctx context.Context, filename string) error { b, err := getConfigFile(filename) if err != nil { - return nil, err + return err } - cfg := &Config{} - if err := yaml.Unmarshal(b, cfg); err != nil { - return nil, err + if err := yaml.Unmarshal(b, &c); err != nil { + return err } - if cfg.Aws != nil { + if c.Aws != nil { awsCfg, err := awsconfig.LoadDefaultConfig(ctx) if err != nil { - return nil, err + return err } - cfg.Aws.SdkConfig = &awsCfg + c.Aws.SdkConfig = &awsCfg } - return cfg, nil + return nil } // Value is interface of external store value. diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 6d9fe1a..2497a34 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -51,7 +51,7 @@ func New(cfg *config.Config, option ...Option) (Engine, error) { } var encAndDec encoding.EncodeAndDecoder - if opts.TextType == TextTypePlane { + if opts.TextType == TextTypePlain { encAndDec = &noOpEncodeAndDecoder{} } else { encAndDec, err = encoding.NewEncodeAndDecoder(opts.TextType) diff --git a/pkg/engine/engine_test.go b/pkg/engine/engine_test.go index b38c0d9..b0136e5 100644 --- a/pkg/engine/engine_test.go +++ b/pkg/engine/engine_test.go @@ -24,6 +24,7 @@ import ( "text/template" "github.com/dwango/yashiro/internal/client" + "github.com/dwango/yashiro/internal/values" "github.com/dwango/yashiro/pkg/config" "github.com/dwango/yashiro/pkg/engine/encoding" ) @@ -55,9 +56,9 @@ func TestNew(t *testing.T) { } } -type mockClient func(ctx context.Context, ignoreNotFound bool) (client.Values, error) +type mockClient func(ctx context.Context, ignoreNotFound bool) (values.Values, error) -func (m mockClient) GetValues(ctx context.Context, ignoreNotFound bool) (client.Values, error) { +func (m mockClient) GetValues(ctx context.Context, ignoreNotFound bool) (values.Values, error) { return m(ctx, ignoreNotFound) } @@ -88,7 +89,7 @@ func Test_engine_Render(t *testing.T) { { name: "ok: render", fields: fields{ - client: mockClient(func(ctx context.Context, ignoreNotFound bool) (client.Values, error) { + client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { return map[string]any{"key": "value"}, nil }), encodeAndDecoder: &noOpEncodeAndDecoder{}, @@ -104,7 +105,7 @@ func Test_engine_Render(t *testing.T) { { name: "ok: deep render", fields: fields{ - client: mockClient(func(ctx context.Context, ignoreNotFound bool) (client.Values, error) { + client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { return map[string]any{"Values": map[string]any{"key": "value"}}, nil }), encodeAndDecoder: &noOpEncodeAndDecoder{}, @@ -120,7 +121,7 @@ func Test_engine_Render(t *testing.T) { { name: "ok: render with function", fields: fields{ - client: mockClient(func(ctx context.Context, ignoreNotFound bool) (client.Values, error) { + client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { return map[string]any{"key": "value"}, nil }), encodeAndDecoder: &noOpEncodeAndDecoder{}, @@ -136,7 +137,7 @@ func Test_engine_Render(t *testing.T) { { name: "ok: encode and decode as yaml-docs after rendering", fields: fields{ - client: mockClient(func(ctx context.Context, ignoreNotFound bool) (client.Values, error) { + client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { return map[string]any{"key": "value"}, nil }), encodeAndDecoder: createEncodeAndDecoder(encoding.TextTypeYAMLDocs), @@ -152,8 +153,8 @@ func Test_engine_Render(t *testing.T) { { name: "error: failed to get values", fields: fields{ - client: mockClient(func(ctx context.Context, ignoreNotFound bool) (client.Values, error) { - return nil, client.ErrValueIsEmpty + client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { + return nil, values.ErrValueIsEmpty }), encodeAndDecoder: &noOpEncodeAndDecoder{}, template: template.New("test"), @@ -168,7 +169,7 @@ func Test_engine_Render(t *testing.T) { { name: "error: failed to parse template", fields: fields{ - client: mockClient(func(ctx context.Context, ignoreNotFound bool) (client.Values, error) { + client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { return map[string]any{"key": "value"}, nil }), encodeAndDecoder: &noOpEncodeAndDecoder{}, @@ -184,7 +185,7 @@ func Test_engine_Render(t *testing.T) { { name: "error: failed to execute template", fields: fields{ - client: mockClient(func(ctx context.Context, ignoreNotFound bool) (client.Values, error) { + client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { return map[string]any{"key": "value"}, nil }), encodeAndDecoder: &noOpEncodeAndDecoder{}, @@ -200,7 +201,7 @@ func Test_engine_Render(t *testing.T) { { name: "error: failed to encode and decode", fields: fields{ - client: mockClient(func(ctx context.Context, ignoreNotFound bool) (client.Values, error) { + client: mockClient(func(ctx context.Context, ignoreNotFound bool) (values.Values, error) { return map[string]any{"key": "value"}, nil }), encodeAndDecoder: createEncodeAndDecoder(encoding.TextTypeJSON), diff --git a/pkg/engine/options.go b/pkg/engine/options.go index 90c207c..e84e105 100644 --- a/pkg/engine/options.go +++ b/pkg/engine/options.go @@ -24,7 +24,7 @@ type Option func(*opts) type TextTypeOpt = encoding.TextType const ( - TextTypePlane TextTypeOpt = "plane" + TextTypePlain TextTypeOpt = "plain" TextTypeJSON TextTypeOpt = encoding.TextTypeJSON TextTypeJSONArray TextTypeOpt = encoding.TextTypeJSONArray TextTypeYAML TextTypeOpt = encoding.TextTypeYAML @@ -53,5 +53,5 @@ type opts struct { var defaultOpts = &opts{ IgnoreNotFound: false, - TextType: TextTypePlane, + TextType: TextTypePlain, }