-
Notifications
You must be signed in to change notification settings - Fork 54
/
main.go
211 lines (183 loc) · 7.97 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package main
import (
"context"
"flag"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault-csi-provider/internal/clientcache"
"github.com/hashicorp/vault-csi-provider/internal/config"
"github.com/hashicorp/vault-csi-provider/internal/hmac"
providerserver "github.com/hashicorp/vault-csi-provider/internal/server"
"github.com/hashicorp/vault-csi-provider/internal/version"
"google.golang.org/grpc"
"google.golang.org/grpc/status"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/utils/pointer"
pb "sigs.k8s.io/secrets-store-csi-driver/provider/v1alpha1"
)
// namespaceFile is the standard in-cluster service account path holding the
// pod's namespace. realMain reads it to decide which namespace the HMAC key
// secret is created in.
const namespaceFile = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
// main runs the provider via realMain. If realMain reports an error, main
// logs the cause and terminates the process with exit status 1.
func main() {
	rootLogger := hclog.Default()
	if runErr := realMain(rootLogger); runErr != nil {
		rootLogger.Error("Error running provider", "err", runErr)
		os.Exit(1)
	}
}
// setupLogger sets the level of the default hclog logger from the parsed
// command-line flags and returns that logger.
//
// The level comes from flags.LogLevel, parsed with hclog.LevelFromString.
// An empty or unrecognized value falls back to info; an unrecognized value
// also produces a warning.
//
// The deprecated flags.Debug option raises verbosity to at least debug.
// Previously it was consulted only when flags.LogLevel was empty. Because the
// -log-level flag defaults to "info", that meant it was effectively ignored.
func setupLogger(flags config.FlagsConfig) hclog.Logger {
	logger := hclog.Default()

	level := hclog.Info
	unrecognized := false
	if flags.LogLevel != "" {
		// LevelFromString reports an unknown name as NoLevel. Never hand
		// NoLevel to SetLevel; keep the info default instead.
		if parsed := hclog.LevelFromString(flags.LogLevel); parsed != hclog.NoLevel {
			level = parsed
		} else {
			unrecognized = true
		}
	}

	// hclog levels increase from Trace (most verbose) towards Off, so
	// "at least debug" means capping the level at Debug.
	if flags.Debug && level > hclog.Debug {
		level = hclog.Debug
	}
	logger.SetLevel(level)

	if unrecognized {
		logger.Warn("Unrecognized log level, ignoring", "log-level", flags.LogLevel, "using", level.String())
	}
	return logger
}
// realMain parses command-line flags, wires together the provider's
// dependencies, and serves the Secrets Store CSI driver gRPC API on a unix
// socket until the server stops.
//
// It returns an error if any startup step fails, or if the gRPC server exits
// for a reason other than being stopped. SIGTERM and SIGINT trigger a
// graceful stop of the gRPC server. When Serve returns, the deferred cleanups
// close the listener and shut down the HTTP health endpoint.
func realMain(logger hclog.Logger) error {
	flags := config.FlagsConfig{}
	flag.StringVar(&flags.Endpoint, "endpoint", "/tmp/vault.sock", "Path to socket on which to listen for driver gRPC calls.")
	flag.BoolVar(&flags.Debug, "debug", false, "Sets log to debug level. This has been deprecated, please use -log-level=debug instead.")
	flag.StringVar(&flags.LogLevel, "log-level", "info", "Sets log level. Options are info, debug, trace, warn, error, and off.")
	flag.BoolVar(&flags.Version, "version", false, "Prints the version information.")
	flag.StringVar(&flags.HealthAddr, "health-addr", ":8080", "Configure http listener for reporting health.")
	flag.StringVar(&flags.HMACSecretName, "hmac-secret-name", "vault-csi-provider-hmac-key", "Configure the Kubernetes secret name that the provider creates to store an HMAC key for generating secret version hashes")
	flag.IntVar(&flags.CacheSize, "cache-size", 1000, "Set the maximum number of Vault tokens that will be cached in-memory. One Vault token will be stored for each pod on the same node that mounts secrets.")
	flag.StringVar(&flags.VaultAddr, "vault-addr", "", "Default address for connecting to Vault. Can also be specified via the VAULT_ADDR environment variable.")
	flag.StringVar(&flags.VaultMount, "vault-mount", "kubernetes", "Default Vault mount path for authentication. Can refer to a Kubernetes or JWT auth mount.")
	flag.StringVar(&flags.VaultNamespace, "vault-namespace", "", "Default Vault namespace for Vault requests. Can also be specified via the VAULT_NAMESPACE environment variable.")
	flag.StringVar(&flags.TLSCACertPath, "vault-tls-ca-cert", "", "Path on disk to a single PEM-encoded CA certificate to trust for Vault. Takes precendence over -vault-tls-ca-directory. Can also be specified via the VAULT_CACERT environment variable.")
	flag.StringVar(&flags.TLSCADirectory, "vault-tls-ca-directory", "", "Path on disk to a directory of PEM-encoded CA certificates to trust for Vault. Can also be specified via the VAULT_CAPATH environment variable.")
	flag.StringVar(&flags.TLSServerName, "vault-tls-server-name", "", "Name to use as the SNI host when connecting to Vault via TLS. Can also be specified via the VAULT_TLS_SERVER_NAME environment variable.")
	flag.StringVar(&flags.TLSClientCert, "vault-tls-client-cert", "", "Path on disk to a PEM-encoded client certificate for mTLS communication with Vault. If set, also requires -vault-tls-client-key. Can also be specified via the VAULT_CLIENT_CERT environment variable.")
	flag.StringVar(&flags.TLSClientKey, "vault-tls-client-key", "", "Path on disk to a PEM-encoded client key for mTLS communication with Vault. If set, also requires -vault-tls-client-cert. Can also be specified via the VAULT_CLIENT_KEY environment variable.")
	flag.BoolVar(&flags.TLSSkipVerify, "vault-tls-skip-verify", false, "Disable verification of TLS certificates. Can also be specified via the VAULT_SKIP_VERIFY environment variable.")
	flag.Parse()

	// set log level
	logger = setupLogger(flags)

	if flags.Version {
		v, err := version.GetVersion()
		if err != nil {
			return fmt.Errorf("failed to print version, err: %w", err)
		}
		// print the version and exit
		_, err = fmt.Println(v)
		return err
	}

	logger.Info("Creating new gRPC server")
	serverLogger := logger.Named("server")
	server := grpc.NewServer(
		// Log the start of every unary call and, on completion, its duration,
		// status code and error.
		grpc.UnaryInterceptor(func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
			startTime := time.Now()
			serverLogger.Info("Processing unary gRPC call", "grpc.method", info.FullMethod)
			resp, err := handler(ctx, req)
			serverLogger.Info("Finished unary gRPC call", "grpc.method", info.FullMethod, "grpc.time", time.Since(startTime), "grpc.code", status.Code(err), "err", err)
			return resp, err
		}),
	)

	// Stop the gRPC server gracefully on SIGTERM/SIGINT. That makes
	// server.Serve below return, which in turn runs the deferred cleanups.
	c := make(chan os.Signal, 1)
	signal.Notify(c, syscall.SIGTERM, syscall.SIGINT)
	go func() {
		sig := <-c
		logger.Info("Caught signal, shutting down", "signal", sig.String())
		server.GracefulStop()
	}()

	listener, err := listen(logger, flags.Endpoint)
	if err != nil {
		return err
	}
	defer listener.Close()

	cfg, err := rest.InClusterConfig()
	if err != nil {
		return err
	}
	clientset, err := kubernetes.NewForConfig(cfg)
	if err != nil {
		return err
	}

	namespace, err := os.ReadFile(namespaceFile)
	if err != nil {
		return fmt.Errorf("failed to read namespace from file: %w", err)
	}
	// Fail fast with a clear message rather than building a secret spec with
	// no namespace.
	if len(namespace) == 0 {
		return fmt.Errorf("namespace file %s is empty", namespaceFile)
	}
	hmacSecretSpec := &corev1.Secret{
		ObjectMeta: metav1.ObjectMeta{
			Name:      flags.HMACSecretName,
			Namespace: string(namespace),
			// TODO: Configurable labels and annotations?
		},
		Immutable: pointer.Bool(true),
	}
	hmacGenerator := hmac.NewHMACGenerator(clientset, hmacSecretSpec)
	clientCache, err := clientcache.NewClientCache(serverLogger.Named("vaultclient"), flags.CacheSize)
	if err != nil {
		return fmt.Errorf("failed to initialize the cache: %w", err)
	}
	srv := providerserver.NewServer(serverLogger, flags, clientset, hmacGenerator, clientCache)
	pb.RegisterCSIDriverProviderServer(server, srv)

	// Create health handler
	mux := http.NewServeMux()
	ms := http.Server{
		Addr:    flags.HealthAddr,
		Handler: mux,
		// Without a header deadline, a client that never finishes sending
		// its request headers can hold a connection open indefinitely.
		ReadHeaderTimeout: 10 * time.Second,
	}
	defer func() {
		// Bound shutdown so a connection that never becomes idle cannot
		// block process exit.
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := ms.Shutdown(ctx); err != nil {
			logger.Error("Error shutting down health handler", "err", err)
		}
	}()
	mux.HandleFunc("/health/ready", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// Start health handler
	go func() {
		logger.Info("Starting health handler", "addr", flags.HealthAddr)
		if err := ms.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			logger.Error("Error with health handler", "error", err)
		}
	}()

	logger.Info("Starting gRPC server")
	// If a shutdown signal triggered GracefulStop before serving began, Serve
	// returns grpc.ErrServerStopped (unwrapped). That is a clean shutdown,
	// not a failure.
	if err := server.Serve(listener); err != nil && err != grpc.ErrServerStopped {
		return fmt.Errorf("error running gRPC server: %w", err)
	}
	return nil
}
// listen opens a unix-domain socket listener at endpoint.
//
// Because the unix socket is created in a host volume (i.e. persistent
// storage), it can persist from previous runs if the pod was not terminated
// cleanly, and binding to the path would then fail. Any entry already present
// at endpoint is therefore removed before the listener is created.
func listen(logger hclog.Logger, endpoint string) (net.Listener, error) {
	// Lstat rather than Stat: the directory entry itself is what blocks bind.
	// This means a dangling symlink left at endpoint is detected and removed too.
	_, err := os.Lstat(endpoint)
	switch {
	case err == nil:
		logger.Info("Cleaning up pre-existing file at unix socket location", "endpoint", endpoint)
		// Tolerate the entry disappearing between Lstat and Remove.
		if err := os.Remove(endpoint); err != nil && !os.IsNotExist(err) {
			return nil, fmt.Errorf("failed to clean up pre-existing file at unix socket location: %w", err)
		}
	case !os.IsNotExist(err):
		return nil, fmt.Errorf("failed to check for existence of unix socket: %w", err)
	}

	logger.Info("Opening unix socket", "endpoint", endpoint)
	listener, err := net.Listen("unix", endpoint)
	if err != nil {
		// %w (not %v) keeps the underlying network error inspectable by
		// callers, consistent with the other wrapped errors above.
		return nil, fmt.Errorf("failed to listen on unix socket at %s: %w", endpoint, err)
	}
	return listener, nil
}