Skip to content

Commit

Permalink
chore: improves security by updating all find containers with user sc…
Browse files Browse the repository at this point in the history
…ope (#3464)
  • Loading branch information
amir20 authored Dec 16, 2024
1 parent 3adf560 commit 23ba1f6
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 255 deletions.
384 changes: 202 additions & 182 deletions internal/agent/pb/rpc.pb.go

Large diffs are not rendered by default.

13 changes: 10 additions & 3 deletions internal/agent/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ func (s *server) StreamLogs(in *pb.StreamLogsRequest, out pb.AgentService_Stream
since = in.Since.AsTime()
}

reader, err := s.client.ContainerLogs(out.Context(), in.ContainerId, since, docker.StdType(in.StreamTypes))
container, err := s.store.FindContainer(in.ContainerId, docker.ContainerFilter{})
if err != nil {
return err
}

container, err := s.store.FindContainer(in.ContainerId)
reader, err := s.client.ContainerLogs(out.Context(), in.ContainerId, since, docker.StdType(in.StreamTypes))
if err != nil {
return err
}
Expand Down Expand Up @@ -172,7 +172,14 @@ func (s *server) StreamStats(in *pb.StreamStatsRequest, out pb.AgentService_Stre
}

func (s *server) FindContainer(ctx context.Context, in *pb.FindContainerRequest) (*pb.FindContainerResponse, error) {
container, err := s.store.FindContainer(in.ContainerId)
filter := make(docker.ContainerFilter)
if in.GetFilter() != nil {
for k, v := range in.GetFilter() {
filter[k] = append(filter[k], v.GetValues()...)
}
}

container, err := s.store.FindContainer(in.ContainerId, filter)
if err != nil {
return nil, status.Error(codes.NotFound, err.Error())
}
Expand Down
21 changes: 18 additions & 3 deletions internal/docker/container_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,26 @@ func (s *ContainerStore) ListContainers(filter ContainerFilter) ([]Container, er
return containers, nil
}

func (s *ContainerStore) FindContainer(id string) (Container, error) {
func (s *ContainerStore) FindContainer(id string, filter ContainerFilter) (Container, error) {
s.wg.Wait()
container, ok := s.containers.Load(id)

if ok {
if filter.Exists() {
validContainers, err := s.client.ListContainers(s.ctx, filter)
if err != nil {
return Container{}, err
}

validIDMap := lo.KeyBy(validContainers, func(item Container) string {
return item.ID
})

if _, ok := validIDMap[id]; !ok {
log.Warn().Str("id", id).Msg("user doesn't have access to container")
return Container{}, ErrContainerNotFound
}
}

if container, ok := s.containers.Load(id); ok {
return *container, nil
} else {
log.Warn().Str("id", id).Msg("container not found")
Expand Down
2 changes: 1 addition & 1 deletion internal/support/docker/agent_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func NewAgentService(client *agent.Client) ClientService {
}
}

func (a *agentService) FindContainer(ctx context.Context, id string) (docker.Container, error) {
func (a *agentService) FindContainer(ctx context.Context, id string, filter docker.ContainerFilter) (docker.Container, error) {
return a.client.FindContainer(ctx, id)
}

Expand Down
6 changes: 3 additions & 3 deletions internal/support/docker/client_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

type ClientService interface {
FindContainer(ctx context.Context, id string) (docker.Container, error)
FindContainer(ctx context.Context, id string, filter docker.ContainerFilter) (docker.Container, error)
ListContainers(ctx context.Context, filter docker.ContainerFilter) ([]docker.Container, error)
Host(ctx context.Context) (docker.Host, error)
ContainerAction(ctx context.Context, container docker.Container, action docker.ContainerAction) error
Expand Down Expand Up @@ -70,8 +70,8 @@ func (d *dockerClientService) StreamLogs(ctx context.Context, container docker.C
}
}

func (d *dockerClientService) FindContainer(ctx context.Context, id string) (docker.Container, error) {
return d.store.FindContainer(id)
func (d *dockerClientService) FindContainer(ctx context.Context, id string, filter docker.ContainerFilter) (docker.Container, error) {
return d.store.FindContainer(id, filter)
}

func (d *dockerClientService) ContainerAction(ctx context.Context, container docker.Container, action docker.ContainerAction) error {
Expand Down
4 changes: 2 additions & 2 deletions internal/support/docker/multi_host_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ func NewMultiHostService(manager ClientManager, timeout time.Duration) *MultiHos
return m
}

func (m *MultiHostService) FindContainer(host string, id string) (*containerService, error) {
func (m *MultiHostService) FindContainer(host string, id string, filter docker.ContainerFilter) (*containerService, error) {
client, ok := m.manager.Find(host)
if !ok {
return nil, fmt.Errorf("host %s not found", host)
}
ctx, cancel := context.WithTimeout(context.Background(), m.timeout)
defer cancel()
container, err := client.FindContainer(ctx, id)
container, err := client.FindContainer(ctx, id, filter)
if err != nil {
return nil, err
}
Expand Down
18 changes: 8 additions & 10 deletions internal/web/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"
"time"

"github.com/amir20/dozzle/internal/auth"
"github.com/amir20/dozzle/internal/docker"
"github.com/go-chi/chi/v5"
"github.com/rs/zerolog/log"
Expand All @@ -14,18 +15,15 @@ func (h *handler) containerActions(w http.ResponseWriter, r *http.Request) {
action := chi.URLParam(r, "action")
id := chi.URLParam(r, "id")

validIdMap, err := h.validContainerIDsForHost(r, hostKey(r))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

if _, ok := validIdMap[id]; !ok {
http.Error(w, "container not found", http.StatusUnauthorized)
return
usersFilter := h.config.Filter
if h.config.Authorization.Provider != NONE {
user := auth.UserFromContext(r.Context())
if user.ContainerFilter.Exists() {
usersFilter = user.ContainerFilter
}
}

containerService, err := h.multiHostService.FindContainer(hostKey(r), id)
containerService, err := h.multiHostService.FindContainer(hostKey(r), id, usersFilter)
if err != nil {
log.Error().Err(err).Msg("error while trying to find container")
http.Error(w, err.Error(), http.StatusNotFound)
Expand Down
2 changes: 1 addition & 1 deletion internal/web/actions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func Test_handler_containerActions_unknown_container(t *testing.T) {

rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
assert.Equal(t, 401, rr.Code)
assert.Equal(t, 404, rr.Code)
}

func Test_handler_containerActions_start(t *testing.T) {
Expand Down
63 changes: 14 additions & 49 deletions internal/web/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ import (
"github.com/docker/docker/pkg/stdcopy"
"github.com/dustin/go-humanize"
"github.com/go-chi/chi/v5"
"github.com/samber/lo"

"github.com/rs/zerolog/log"
)

func (h *handler) validContainerIDsForHost(r *http.Request, host string) (map[string]docker.Container, error) {
func (h *handler) downloadLogs(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")

usersFilter := h.config.Filter
if h.config.Authorization.Provider != NONE {
user := auth.UserFromContext(r.Context())
Expand All @@ -40,33 +41,7 @@ func (h *handler) validContainerIDsForHost(r *http.Request, host string) (map[st
}
}

validContainers, err := h.multiHostService.ListContainersForHost(host, usersFilter)
if err != nil {
return nil, err
}

validIdMap := lo.KeyBy(validContainers, func(item docker.Container) string {
return item.ID
})

return validIdMap, nil
}

func (h *handler) downloadLogs(w http.ResponseWriter, r *http.Request) {
id := chi.URLParam(r, "id")

validIdMap, err := h.validContainerIDsForHost(r, hostKey(r))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

if _, ok := validIdMap[id]; !ok {
http.Error(w, "container not found", http.StatusUnauthorized)
return
}

containerService, err := h.multiHostService.FindContainer(hostKey(r), id)
containerService, err := h.multiHostService.FindContainer(hostKey(r), id, usersFilter)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
Expand Down Expand Up @@ -137,18 +112,15 @@ func (h *handler) fetchLogsBetweenDates(w http.ResponseWriter, r *http.Request)
return
}

validIdMap, err := h.validContainerIDsForHost(r, hostKey(r))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}

if _, ok := validIdMap[id]; !ok {
http.Error(w, "container not found", http.StatusUnauthorized)
return
usersFilter := h.config.Filter
if h.config.Authorization.Provider != NONE {
user := auth.UserFromContext(r.Context())
if user.ContainerFilter.Exists() {
usersFilter = user.ContainerFilter
}
}

containerService, err := h.multiHostService.FindContainer(hostKey(r), id)
containerService, err := h.multiHostService.FindContainer(hostKey(r), id, usersFilter)
if err != nil {
http.Error(w, err.Error(), http.StatusNotFound)
return
Expand Down Expand Up @@ -376,8 +348,7 @@ func (h *handler) streamLogsForContainers(w http.ResponseWriter, r *http.Request
events := make([]*docker.LogEvent, 0)
stillRunning := false
for _, container := range existingContainers {

containerService, err := h.multiHostService.FindContainer(container.Host, container.ID)
containerService, err := h.multiHostService.FindContainer(container.Host, container.ID, usersFilter)

if err != nil {
log.Error().Err(err).Msg("error while finding container")
Expand Down Expand Up @@ -424,7 +395,7 @@ func (h *handler) streamLogsForContainers(w http.ResponseWriter, r *http.Request
}

streamLogs := func(container docker.Container) {
containerService, err := h.multiHostService.FindContainer(container.Host, container.ID)
containerService, err := h.multiHostService.FindContainer(container.Host, container.ID, usersFilter)
if err != nil {
log.Error().Err(err).Msg("error while finding container")
return
Expand Down Expand Up @@ -465,13 +436,7 @@ loop:
}
sseWriter.Message(logEvent)
case container := <-newContainers:
validIdMap, err := h.validContainerIDsForHost(r, container.Host)
if err != nil {
log.Error().Err(err).Msg("error fetching valid container IDs")
continue
}

if _, ok := validIdMap[container.ID]; ok {
if _, err := h.multiHostService.FindContainer(container.Host, container.ID, usersFilter); err == nil {
events <- &docker.ContainerEvent{ActorID: container.ID, Name: "container-started", Host: container.Host}
go streamLogs(container)
}
Expand Down
5 changes: 4 additions & 1 deletion protos/rpc.proto
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ message RepeatedString { repeated string values = 1; }

message ListContainersResponse { repeated Container containers = 1; }

message FindContainerRequest { string containerId = 1; }
message FindContainerRequest {
string containerId = 1;
map<string, RepeatedString> filter = 2;
}

message FindContainerResponse { Container container = 1; }

Expand Down

0 comments on commit 23ba1f6

Please sign in to comment.