Skip to content

Commit

Permalink
Implement tests for YC KMS provider and credentials forward
Browse files Browse the repository at this point in the history
gRPC calls `Encrypt` and `Decrypt` are mocked with dummy responses using base64 instead of actual encryption.

Since YC KMS responce with bynary data we are storing it as base64, together with mocked server where we encode it one more time instead of actual encryption we are using double encoding and double decoding in tests cases.
  • Loading branch information
astreter authored and kuzaxak committed Dec 26, 2024
1 parent 63bd7a9 commit 9904363
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 19 deletions.
78 changes: 59 additions & 19 deletions yckms/keysource.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
ycsdk "github.com/yandex-cloud/go-sdk"
"github.com/yandex-cloud/go-sdk/gen/kmscrypto"
"github.com/yandex-cloud/go-sdk/iamkey"
"google.golang.org/grpc"
"os"
"strings"
"time"
Expand Down Expand Up @@ -40,17 +41,24 @@ type MasterKey struct {
// CreationDate is the creation timestamp of the MasterKey. Used
// for NeedsRotation.
CreationDate time.Time

credentials ycsdk.Credentials

// grpcConn can be used to inject a custom YC KMS client connection.
// Mostly useful for testing at present, to wire the client to a mock
// server.
grpcConn *grpc.ClientConn
}

func (key *MasterKey) TypeToIdentifier() string {
return KeyTypeIdentifier
}

func NewMasterKeyFromKeyID(keyID string) *MasterKey {
k := &MasterKey{}
k.KeyID = keyID
k.CreationDate = time.Now().UTC()
return k
return &MasterKey{
KeyID: keyID,
CreationDate: time.Now().UTC(),
}
}

func NewMasterKeyFromKeyIDString(keyID string) []*MasterKey {
Expand All @@ -59,18 +67,34 @@ func NewMasterKeyFromKeyIDString(keyID string) []*MasterKey {
return keys
}
for _, s := range strings.Split(keyID, ",") {
keys = append(keys, NewMasterKeyFromKeyID(s))
keys = append(keys, NewMasterKeyFromKeyID(strings.TrimSpace(s)))
}
return keys
}

func (key *MasterKey) Encrypt(dataKey []byte) error {
client, ctx, err := key.newKMSClient()
// YCCredentials is a ycsdk.Credentials used for authenticating towards YC KMS
type YCCredentials struct {
credentials ycsdk.Credentials
}

// NewYCCredentials creates a new YCCredentials with the provided ycsdk.Credentials.
func NewYCCredentials(credentials ycsdk.Credentials) *YCCredentials {
return &YCCredentials{credentials: credentials}
}

// ApplyToMasterKey configures the TokenCredential on the provided key.
func (c YCCredentials) ApplyToMasterKey(key *MasterKey) {
key.credentials = c.credentials
}

func (key *MasterKey) Encrypt(dataKey []byte) (err error) {
client, err := key.newKMSClient()
if err != nil {
log.WithError(err).WithField("keyID", key.KeyID).Error("Encryption failed")
return fmt.Errorf("cannot create YC KMS service: %w", err)
}
ciphertextResponse, err := client.Encrypt(ctx, &yckms.SymmetricEncryptRequest{

ciphertextResponse, err := client.Encrypt(context.Background(), &yckms.SymmetricEncryptRequest{
KeyId: key.KeyID,
Plaintext: dataKey,
})
Expand Down Expand Up @@ -105,13 +129,14 @@ func (key *MasterKey) EncryptIfNeeded(dataKey []byte) error {
// Decrypt decrypts the EncryptedKey field with YC KMS and returns
// the result.
func (key *MasterKey) Decrypt() ([]byte, error) {
client, ctx, err := key.newKMSClient()
client, err := key.newKMSClient()
if err != nil {
log.WithError(err).WithField("keyID", key.KeyID).Error("Decryption failed")
return nil, fmt.Errorf("cannot create YC KMS service: %w", err)
}

decodedCipher, err := base64.StdEncoding.DecodeString(string(key.EncryptedDataKey()))
plaintextResponse, err := client.Decrypt(ctx, &yckms.SymmetricDecryptRequest{
plaintextResponse, err := client.Decrypt(context.Background(), &yckms.SymmetricDecryptRequest{
KeyId: key.KeyID,
Ciphertext: decodedCipher,
})
Expand All @@ -134,7 +159,7 @@ func (key *MasterKey) ToString() string {
}

// ToMap converts the MasterKey to a map for serialization purposes.
func (key MasterKey) ToMap() map[string]interface{} {
func (key *MasterKey) ToMap() map[string]interface{} {
out := make(map[string]interface{})
out["key_id"] = key.KeyID
out["created_at"] = key.CreationDate.UTC().Format(time.RFC3339)
Expand All @@ -146,21 +171,36 @@ func (key MasterKey) ToMap() map[string]interface{} {
// and/or grpcConn, falling back to environmental defaults.
// It returns an error if the ResourceID is invalid, or if the setup of the
// client fails.
func (key *MasterKey) newKMSClient() (*kms.SymmetricCryptoServiceClient, context.Context, error) {
ctx := context.Background()
cred, err := getYandexCloudCredentials()
if err != nil {
return nil, nil, err
func (key *MasterKey) newKMSClient() (*kms.SymmetricCryptoServiceClient, error) {
var (
cred ycsdk.Credentials
err error
)

switch {
case key.credentials != nil:
cred = key.credentials
default:
cred, err = getYandexCloudCredentials()
if err != nil {
return nil, err
}
}

client, err := ycsdk.Build(ctx, ycsdk.Config{
client, err := ycsdk.Build(context.Background(), ycsdk.Config{
Credentials: cred,
})
if err != nil {
return nil, nil, err
return nil, err
}

if key.grpcConn != nil {
return kms.NewKMSCrypto(func(ctx context.Context) (*grpc.ClientConn, error) {
return key.grpcConn, nil
}).SymmetricCrypto(), nil
}

return client.KMSCrypto().SymmetricCrypto(), ctx, nil
return client.KMSCrypto().SymmetricCrypto(), nil
}

// getYandexCloudCredentials trying to locate credentials in the following order
Expand Down
177 changes: 177 additions & 0 deletions yckms/keysource_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package yckms

import (
"context"
"encoding/base64"
yckms "github.com/yandex-cloud/go-genproto/yandex/cloud/kms/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

const (
dummyKey = "920aff2e-c5f1-4040-943a-047fa387b27e"
anotherDummyKey = "920aff2e-c5f1-4040-943a-047fa587b27e"
dummyKeys = dummyKey + ", " + anotherDummyKey
decodedKey = "I want to be a DJ"
)

const bufSize = 1024 * 1024

var lis *bufconn.Listener

func init() {
lis = bufconn.Listen(bufSize)
s := grpc.NewServer()
yckms.RegisterSymmetricCryptoServiceServer(s, mockSymmetricCryptoServiceServer{})
go func() {
if err := s.Serve(lis); err != nil {
log.Fatalf("Server exited with error: %v", err)
}
}()
}

func bufDialer(context.Context, string) (net.Conn, error) {
return lis.Dial()
}

type mockSymmetricCryptoServiceServer struct {
}

func (mockSymmetricCryptoServiceServer) Encrypt(ctx context.Context, req *yckms.SymmetricEncryptRequest) (*yckms.SymmetricEncryptResponse, error) {
return &yckms.SymmetricEncryptResponse{
Ciphertext: []byte(base64.StdEncoding.EncodeToString(req.Plaintext)),
}, nil
}

func (mockSymmetricCryptoServiceServer) Decrypt(ctx context.Context, req *yckms.SymmetricDecryptRequest) (*yckms.SymmetricDecryptResponse, error) {
plain, err := base64.StdEncoding.DecodeString(string(req.Ciphertext))
if err != nil {
return nil, err
}
return &yckms.SymmetricDecryptResponse{
Plaintext: plain,
}, nil
}
func (mockSymmetricCryptoServiceServer) ReEncrypt(context.Context, *yckms.SymmetricReEncryptRequest) (*yckms.SymmetricReEncryptResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ReEncrypt not implemented")
}
func (mockSymmetricCryptoServiceServer) GenerateDataKey(context.Context, *yckms.GenerateDataKeyRequest) (*yckms.GenerateDataKeyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GenerateDataKey not implemented")
}

func TestNewMasterKeyFromKeyID(t *testing.T) {
key := NewMasterKeyFromKeyID(dummyKey)
assert.Equal(t, dummyKey, key.KeyID)
assert.NotNil(t, key.CreationDate)
}

func TestNewMasterKeyFromKeyIDString(t *testing.T) {
keys := NewMasterKeyFromKeyIDString(dummyKeys)
assert.Len(t, keys, 2)

k1 := keys[0]
k2 := keys[1]

assert.Equal(t, dummyKey, k1.KeyID)
assert.Equal(t, anotherDummyKey, k2.KeyID)
}

func TestMasterKey_Encrypt(t *testing.T) {
t.Run("encrypt", func(t *testing.T) {
grpcConn, err := createMockGRPCClient()
assert.NoError(t, err)

key := &MasterKey{
grpcConn: grpcConn,
}

dataKey := []byte(decodedKey)
err = key.Encrypt(dataKey)
assert.NoError(t, err)

// Double base64 is used because encrypted data stored as base64
// and our mock uses base64 instead of actual encryption
assert.EqualValues(t, base64.StdEncoding.EncodeToString([]byte(base64.StdEncoding.EncodeToString([]byte(decodedKey)))), key.EncryptedDataKey())
})
}

func TestMasterKey_Decrypt(t *testing.T) {
t.Run("decrypt", func(t *testing.T) {
grpcConn, err := createMockGRPCClient()
assert.NoError(t, err)

// Double base64 is used because encrypted data stored as base64
// and our mock uses base64 instead of actual encryption
key := &MasterKey{
EncryptedKey: base64.StdEncoding.EncodeToString([]byte(base64.StdEncoding.EncodeToString([]byte(decodedKey)))),
grpcConn: grpcConn,
}

got, err := key.Decrypt()
assert.NoError(t, err)
assert.Equal(t, []byte(decodedKey), got)
})
}

func TestMasterKey_EncryptedDataKey(t *testing.T) {
key := &MasterKey{EncryptedKey: "some key"}
assert.EqualValues(t, key.EncryptedKey, key.EncryptedDataKey())
}

func TestMasterKey_SetEncryptedDataKey(t *testing.T) {
key := &MasterKey{}
data := []byte("some data")
key.SetEncryptedDataKey(data)
assert.EqualValues(t, data, key.EncryptedKey)
}

func TestMasterKey_NeedsRotation(t *testing.T) {
t.Run("false", func(t *testing.T) {
k := &MasterKey{}
k.CreationDate = time.Now().UTC()

assert.False(t, k.NeedsRotation())
})

t.Run("true", func(t *testing.T) {
k := &MasterKey{}
k.CreationDate = time.Now().UTC().Add(-kmsTTL - 1)

assert.True(t, k.NeedsRotation())
})
}

func TestMasterKey_ToString(t *testing.T) {
key := NewMasterKeyFromKeyID(dummyKey)

assert.Equal(t, dummyKey, key.ToString())
}

func TestMasterKey_ToMap(t *testing.T) {
key := NewMasterKeyFromKeyID(dummyKey)

data := []byte("some data")
key.SetEncryptedDataKey(data)

res := key.ToMap()
assert.Equal(t, dummyKey, res["key_id"])
assert.Equal(t, key.CreationDate.UTC().Format(time.RFC3339), res["created_at"])
assert.Equal(t, "some data", res["enc"])
}

func createMockGRPCClient() (*grpc.ClientConn, error) {
return grpc.DialContext(
context.Background(),
"bufnet",
grpc.WithContextDialer(bufDialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
}

0 comments on commit 9904363

Please sign in to comment.