Skip to content

Commit

Permalink
Support custom certificates in metrics server
Browse files Browse the repository at this point in the history
Signed-off-by: Botond Szirtes <[email protected]>
  • Loading branch information
bszirtes committed Dec 11, 2024
1 parent 1e611de commit b21d0d3
Showing 1 changed file with 217 additions and 29 deletions.
246 changes: 217 additions & 29 deletions pkg/tools/prometheus/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,14 @@ package prometheus
import (
"context"
"crypto/tls"
"crypto/x509"
"net/http"
"os"
"path/filepath"
"sync"
"time"

"github.com/fsnotify/fsnotify"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
Expand All @@ -31,52 +36,97 @@ import (
"github.com/networkservicemesh/sdk/pkg/tools/log"
)

// ListenAndServe gathers the certificate and initiates the server to begin handling incoming requests
func ListenAndServe(ctx context.Context, listenOn string, headerTimeout time.Duration, cancel context.CancelFunc) {
metricsServer := server{
ListenOn: listenOn,
HeaderTimeout: headerTimeout,
// Server exposes Prometheus metrics over HTTPS.
//
// A Server is configured through NewServer and its Option values and is
// started with ListenAndServe.
type Server struct {
	// TLS state. tlsConfig is built by createTLSConfig; certHandler is only
	// set when both a custom certificate file and key file are configured.
	tlsConfig   *tls.Config
	certHandler *certHandler

	// listenOn is the address handed to the underlying http.Server.
	listenOn string

	// Custom certificate configuration. Empty strings mean "not configured".
	certFile    string
	keyFile     string
	caFile      string
	monitorCert bool

	// headerTimeout is used as the http.Server ReadHeaderTimeout.
	headerTimeout time.Duration
}

// Option is a functional option that mutates a Server while it is being
// constructed by NewServer. Options are applied in the order they are passed.
type Option func(s *Server)

// WithCustomCert returns an Option that records the paths of the certificate
// and private key files the metrics server should use for TLS.
func WithCustomCert(certFile, keyFile string) Option {
	apply := func(s *Server) {
		s.certFile, s.keyFile = certFile, keyFile
	}
	return apply
}

tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
// WithCustomCA sets the CA file to use for mTLS
func WithCustomCA(caFile string) Option {
return func(s *Server) {
s.caFile = caFile
}
source, err := workloadapi.NewX509Source(ctx)
if err != nil {
log.FromContext(ctx).Fatalf("error getting x509 source: %v", err.Error())
}

// WithHeaderTimeout sets the header timeout for the prometheus server
func WithHeaderTimeout(headerTimeout time.Duration) Option {
return func(s *Server) {
s.headerTimeout = headerTimeout
}
tlsConfig.GetCertificate = tlsconfig.GetCertificate(source)
metricsServer.TLSConfig = tlsConfig
}

select {
case <-ctx.Done():
err = source.Close()
log.FromContext(ctx).Errorf("unable to close x509 source: %v", err.Error())
default:
// WithCertificateMonitoring returns an Option that switches certificate
// renewal monitoring on or off. ListenAndServe only starts the monitor when
// this flag is set and both a certificate file and a key file are configured.
func WithCertificateMonitoring(monitorCert bool) Option {
	apply := func(s *Server) {
		s.monitorCert = monitorCert
	}
	return apply
}

// NewServer builds a metrics Server that will listen on listenOn.
//
// Defaults: no custom certificate, key or CA file, certificate monitoring
// disabled, and a header timeout of five seconds. Each option is then applied
// in order, so later options override earlier ones.
func NewServer(listenOn string, options ...Option) *Server {
	// Every field not listed here starts at its zero value ("" / false / nil).
	s := &Server{
		listenOn:      listenOn,
		headerTimeout: 5 * time.Second,
	}

	for i := range options {
		options[i](s)
	}

	return s
}

// ListenAndServe gathers the certificate and initiates the Server to begin handling incoming requests
func (s *Server) ListenAndServe(ctx context.Context, cancel context.CancelFunc) {
log.FromContext(ctx).Debugf("new metrics server created with parameters listenOn: '%v', certFile: '%v', keyFile: '%v', caFile: '%v', headerTimeout: '%v', monitorCert: '%v'",
s.listenOn, s.certFile, s.keyFile, s.caFile, s.headerTimeout, s.monitorCert)

s.createTLSConfig(ctx)

if s.monitorCert && s.certFile != "" && s.keyFile != "" {
err := s.monitorCertificate(ctx)
if err != nil {
log.FromContext(ctx).Error(err.Error())
}
}

go func() {
err := metricsServer.start(ctx)
err := s.start(ctx)
if err != nil {
log.FromContext(ctx).Error(err.Error())
cancel()
}
}()
}

type server struct {
TLSConfig *tls.Config
ListenOn string
HeaderTimeout time.Duration
}

func (s *server) start(ctx context.Context) error {
log.FromContext(ctx).Info("Start metrics server on ", s.ListenOn)
func (s *Server) start(ctx context.Context) error {
log.FromContext(ctx).Info("start metrics server on ", s.listenOn)

server := &http.Server{
Addr: s.ListenOn,
TLSConfig: s.TLSConfig,
ReadHeaderTimeout: s.HeaderTimeout,
Addr: s.listenOn,
TLSConfig: s.tlsConfig,
ReadHeaderTimeout: s.headerTimeout,
}

http.Handle("/metrics", promhttp.Handler())
Expand Down Expand Up @@ -107,3 +157,141 @@ func (s *server) start(ctx context.Context) error {

return nil
}

// createTLSConfig builds s.tlsConfig (TLS 1.3 minimum).
//
// When both certFile and keyFile are set, the certificate is loaded from disk
// through a certHandler; if caFile is also set, client certificates are
// required and verified against that CA. Otherwise the certificate is obtained
// from a SPIFFE workload API X509 source.
//
// Any failure to obtain a certificate is reported through the logger's Fatalf.
func (s *Server) createTLSConfig(ctx context.Context) {
	logger := log.FromContext(ctx)
	logger.Debug("create TLS config for metrics server")

	s.tlsConfig = &tls.Config{
		MinVersion: tls.VersionTLS13,
	}

	if s.certFile != "" && s.keyFile != "" {
		s.certHandler = &certHandler{}
		if err := s.certHandler.LoadCertificate(s.certFile, s.keyFile, s.caFile); err != nil {
			logger.Fatalf("error loading custom certificate and key: %v", err)
			return
		}
		s.tlsConfig.GetCertificate = s.certHandler.GetCertificate

		if s.caFile != "" {
			logger.Debug("enable client authentication for metrics server")

			// NOTE: the pool pointer is captured once here. A later CA reload
			// stores a new pool on the handler and is not reflected in
			// ClientCAs.
			s.tlsConfig.ClientCAs = s.certHandler.caCertPool
			s.tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
		}
		return
	}

	source, err := workloadapi.NewX509Source(ctx)
	if err != nil {
		logger.Fatalf("error getting x509 source: %v", err.Error())
		// Only reached if the logger's Fatalf does not terminate the process;
		// source must not be used in that case.
		return
	}
	s.tlsConfig.GetCertificate = tlsconfig.GetCertificate(source)

	// Release the source once the context is cancelled. The wait has to happen
	// in the background: a non-blocking check at this point would almost always
	// see a live context and never close the source.
	go func() {
		<-ctx.Done()
		if closeErr := source.Close(); closeErr != nil {
			log.FromContext(ctx).Errorf("unable to close x509 source: %v", closeErr.Error())
		}
	}()
}

// certHandler holds the certificate currently served by the metrics server
// and, optionally, the CA pool loaded alongside it.
type certHandler struct {
	// cert is the key pair handed out by GetCertificate.
	cert *tls.Certificate
	// caCertPool is populated by LoadCertificateAuthority.
	caCertPool *x509.CertPool
	// mu guards cert: LoadCertificate writes it under the write lock and
	// GetCertificate reads it under the read lock.
	mu sync.RWMutex
}

// GetCertificate returns the most recently loaded certificate. Its signature
// matches tls.Config.GetCertificate; the client hello is ignored and the
// returned error is always nil.
func (h *certHandler) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	h.mu.RLock()
	current := h.cert
	h.mu.RUnlock()

	return current, nil
}

// LoadCertificate reads the key pair from certFile and keyFile and makes it
// the certificate returned by GetCertificate. When caFile is non-empty the CA
// bundle is (re)loaded as well.
//
// The key pair is parsed before the lock is taken, so a broken pair leaves the
// currently served certificate untouched. Returned errors wrap the underlying
// cause, so callers can still inspect it with errors.Is / errors.As.
func (h *certHandler) LoadCertificate(certFile, keyFile, caFile string) error {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return errors.Wrap(err, "error loading custom certificate and key")
	}

	h.mu.Lock()
	defer h.mu.Unlock()
	h.cert = &cert

	if caFile == "" {
		return nil
	}
	if err = h.LoadCertificateAuthority(caFile); err != nil {
		return errors.Wrap(err, "error loading custom certificate authority")
	}
	return nil
}

// LoadCertificateAuthority reads the PEM bundle at caFile and stores the
// resulting pool on the handler.
//
// The pool is built in a local variable and only published once at least one
// certificate has been parsed, so a missing or malformed file leaves the
// previously loaded pool in place. This method does not take the handler's
// mutex itself; LoadCertificate calls it while holding the write lock.
func (h *certHandler) LoadCertificateAuthority(caFile string) error {
	caCert, err := os.ReadFile(caFile)
	if err != nil {
		return errors.Wrap(err, "failed to read CA certificate")
	}

	pool := x509.NewCertPool()
	if !pool.AppendCertsFromPEM(caCert) {
		return errors.New("failed to add CA certificate to the pool")
	}

	h.caCertPool = pool
	return nil
}

// monitorCertificate watches the directory containing s.certFile and reloads
// the certificate, key and CA whenever a "..data" entry is created in it.
//
// The watch is registered before the background goroutine is started, so a
// failure to register closes the watcher and returns the error without
// leaving a goroutine behind. The goroutine stops, and closes the watcher,
// when ctx is cancelled or when either watcher channel is closed.
func (s *Server) monitorCertificate(ctx context.Context) error {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		return errors.Wrap(err, "failed to create new watcher")
	}

	certFolder := filepath.Dir(s.certFile)
	// Presumably the symlink that Kubernetes secret/configMap volumes swap
	// atomically on update -- confirm against the deployment.
	dataLink := filepath.Join(certFolder, "..data")

	if err = watcher.Add(certFolder); err != nil {
		if closeErr := watcher.Close(); closeErr != nil {
			log.FromContext(ctx).Errorf("error closing watcher: %v", closeErr)
		}
		// Returned rather than logged here so the caller reports it once.
		return errors.Wrap(err, "failed to add certificate folder to file watcher")
	}

	go func() {
		defer func() {
			if e := watcher.Close(); e != nil {
				log.FromContext(ctx).Errorf("error closing watcher: %v", e)
			}
		}()

		for {
			select {
			case event, ok := <-watcher.Events:
				if !ok {
					log.FromContext(ctx).Error("certificate watcher event channel closed")
					return
				}
				log.FromContext(ctx).Debugf("certificate watcher event: %v", event)

				if event.Name != dataLink || event.Op&fsnotify.Create != fsnotify.Create {
					continue
				}
				log.FromContext(ctx).Debugf("certificate file '%s' was modified, reloading certificate", event.Name)
				if e := s.certHandler.LoadCertificate(s.certFile, s.keyFile, s.caFile); e != nil {
					// A failed reload keeps whatever certificate is currently served.
					log.FromContext(ctx).Errorf("failed to reload metrics server certificate: %v", e)
					continue
				}
				log.FromContext(ctx).Info("metrics server certificate reloaded successfully")

			case e, ok := <-watcher.Errors:
				if !ok {
					log.FromContext(ctx).Error("certificate watcher error channel closed")
					return
				}
				log.FromContext(ctx).Errorf("certificate watcher error: %v", e)

			case <-ctx.Done():
				log.FromContext(ctx).Info("stopping certificate watcher due to context cancellation")
				return
			}
		}
	}()

	return nil
}

0 comments on commit b21d0d3

Please sign in to comment.