1 change: 1 addition & 0 deletions go.mod
@@ -34,6 +34,7 @@ require (
github.com/perimeterx/marshmallow v1.1.5 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/thanhpk/randstr v1.0.6 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/sync v0.7.0 // indirect
2 changes: 2 additions & 0 deletions go.sum
@@ -80,6 +80,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/thanhpk/randstr v1.0.6 h1:psAOktJFD4vV9NEVb3qkhRSMvYh4ORRaj1+w/hn4B+o=
github.com/thanhpk/randstr v1.0.6/go.mod h1:M/H2P1eNLZzlDwAzpkkkUvoyNNMbzRGhESZuEQk3r0U=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/vincent-petithory/dataurl v1.0.0 h1:cXw+kPto8NLuJtlMsI152irrVw9fRDX8AbShPRpg2CI=
11 changes: 7 additions & 4 deletions worker/container.go
@@ -17,8 +17,9 @@ const (

type RunnerContainer struct {
RunnerContainerConfig

Client *ClientWithResponses
Name string
Capacity int
Client *ClientWithResponses
}

type RunnerEndpoint struct {
@@ -39,7 +40,7 @@ type RunnerContainerConfig struct {
containerTimeout time.Duration
}

func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig) (*RunnerContainer, error) {
func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig, name string) (*RunnerContainer, error) {
// Ensure that timeout is set to a non-zero value.
timeout := cfg.containerTimeout
if timeout == 0 {
@@ -61,7 +62,7 @@ func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig) (*Runner
return nil, err
}

cctx, cancel := context.WithTimeout(ctx, cfg.containerTimeout)
cctx, cancel := context.WithTimeout(context.Background(), cfg.containerTimeout)
if err := runnerWaitUntilReady(cctx, client, pollingInterval); err != nil {
cancel()
return nil, err
@@ -70,6 +71,8 @@ func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig) (*Runner

return &RunnerContainer{
RunnerContainerConfig: cfg,
Name: name,
Capacity: 1,
Client: client,
}, nil
}
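
NewRunnerContainer now takes the container name as an explicit argument, and every new container starts with a capacity of one. A minimal caller sketch, assuming a populated RunnerContainerConfig named cfg (the name argument shown is illustrative; in practice it comes from dockerContainerName):

rc, err := NewRunnerContainer(ctx, cfg, "text-to-image_stabilityai_sd-turbo_x7Kd92bQz1")
if err != nil {
	return nil, err
}
// rc.Capacity == 1 until a request borrows the container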
49 changes: 25 additions & 24 deletions worker/docker.go
@@ -15,6 +15,7 @@ import (
"github.com/docker/docker/api/types/mount"
"github.com/docker/docker/client"
"github.com/docker/go-connections/nat"
"github.com/thanhpk/randstr"
)

const containerModelDir = "/models"
@@ -31,9 +32,10 @@ const containerCreator = "ai-worker"
// using the GPU we stop it so we don't have to worry about having enough ports
var containerHostPorts = map[string]string{
"text-to-image": "8000",
"image-to-image": "8001",
"image-to-video": "8002",
"upscale": "8003",
"image-to-image": "8100",
"image-to-video": "8200",
"upscale": "8300",
"audio-to-text": "8400",
}

type DockerManager struct {
@@ -107,41 +109,40 @@ func (m *DockerManager) Stop(ctx context.Context) error {
func (m *DockerManager) Borrow(ctx context.Context, pipeline, modelID string) (*RunnerContainer, error) {
m.mu.Lock()
defer m.mu.Unlock()

containerName := dockerContainerName(pipeline, modelID)
rc, ok := m.containers[containerName]
if !ok {
// The container does not exist so try to create it
var err error
// TODO: Optimization flags for dynamically loaded (borrowed) containers are not currently supported due to startup delays.
rc, err = m.createContainer(ctx, pipeline, modelID, false, map[string]EnvValue{})
if err != nil {
return nil, err
for _, runner := range m.containers {
if runner.Pipeline == pipeline && runner.ModelID == modelID {
delete(m.containers, runner.Name)
return runner, nil
}
}

// Remove container so it is unavailable until Return() is called
delete(m.containers, containerName)
// The container does not exist so try to create it
var err error
// TODO: Optimization flags for dynamically loaded (borrowed) containers are not currently supported due to startup delays.
rc, err := m.createContainer(ctx, pipeline, modelID, false, map[string]EnvValue{})
if err != nil {
return nil, err
}
return rc, nil
}
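
Since container names now carry a random suffix, Borrow scans for a matching pipeline/model pair instead of looking up a fixed key, and deletes the hit from the map so it stays unavailable until returned. A borrow/return cycle sketch (hypothetical caller):

rc, err := m.Borrow(ctx, "text-to-image", "stabilityai/sd-turbo")
if err != nil {
	return nil, err
}
defer m.Return(rc) // re-registers the container under rc.Name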

func (m *DockerManager) Return(rc *RunnerContainer) {
m.mu.Lock()
defer m.mu.Unlock()
m.containers[dockerContainerName(rc.Pipeline, rc.ModelID)] = rc
m.containers[rc.Name] = rc
}

// HasCapacity checks if an unused managed container exists or if a GPU is available for a new container.
func (m *DockerManager) HasCapacity(ctx context.Context, pipeline, modelID string) bool {
containerName := dockerContainerName(pipeline, modelID)

m.mu.Lock()
defer m.mu.Unlock()

// Check if unused managed container exists for the requested model.
_, ok := m.containers[containerName]
if ok {
return true
for _, rc := range m.containers {
if rc.Pipeline == pipeline && rc.ModelID == modelID {
return true
}
}

// Check for available GPU to allocate for a new container for the requested model.
Expand Down Expand Up @@ -185,7 +186,7 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
gpuOpts := opts.GpuOpts{}
gpuOpts.Set("device=" + gpu)

containerHostPort := containerHostPorts[pipeline]
containerHostPort := containerHostPorts[pipeline][:3] + gpu
hostConfig := &container.HostConfig{
Resources: container.Resources{
DeviceRequests: gpuOpts.Value(),
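
The base ports above are spaced 100 apart, so taking the first three digits and appending the GPU index yields a distinct host port per pipeline/GPU pair. A worked example, assuming gpu is a single-digit device index string:

// pipeline = "image-to-video" -> containerHostPorts[pipeline] = "8200"
// gpu      = "3"
// "8200"[:3] + "3"            -> host port "8203"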
@@ -248,7 +249,7 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
containerTimeout: runnerContainerTimeout,
}

rc, err := NewRunnerContainer(ctx, cfg)
rc, err := NewRunnerContainer(ctx, cfg, containerName)
if err != nil {
dockerRemoveContainer(m.dockerClient, resp.ID)
return nil, err
@@ -311,7 +312,7 @@ func removeExistingContainers(ctx context.Context, client *client.Client) error
func dockerContainerName(pipeline string, modelID string) string {
// text-to-image, stabilityai/sd-turbo -> text-to-image_stabilityai_sd-turbo
// image-to-video, stabilityai/stable-video-diffusion-img2vid-xt -> image-to-video_stabilityai_stable-video-diffusion-img2vid-xt
return strings.ReplaceAll(pipeline+"_"+modelID, "/", "_")
return strings.ReplaceAll(pipeline+"_"+modelID+"_"+randstr.String(10), "/", "_")
}
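
The random suffix makes each name unique per container instance, so multiple runners for the same pipeline/model pair can coexist. For example (the suffix is illustrative and varies per call):

// dockerContainerName("text-to-image", "stabilityai/sd-turbo")
//   -> "text-to-image_stabilityai_sd-turbo_x7Kd92bQz1"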

func dockerRemoveContainer(client *client.Client, containerID string) error {
52 changes: 38 additions & 14 deletions worker/worker.go
@@ -265,13 +265,19 @@ func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endp
Endpoint: endpoint,
containerTimeout: externalContainerTimeout,
}
rc, err := NewRunnerContainer(ctx, cfg)

name := dockerContainerName(pipeline, modelID)
if endpoint.URL != "" {
name = cfg.Endpoint.URL
slog.Info("name of container: ", slog.String("url", cfg.Endpoint.URL))
}

rc, err := NewRunnerContainer(ctx, cfg, name)
if err != nil {
return err
}

name := dockerContainerName(pipeline, modelID)
slog.Info("Starting external container", slog.String("name", name), slog.String("modelID", modelID))
slog.Info("Starting external container", slog.String("name", name), slog.String("pipeline", pipeline), slog.String("modelID", modelID))
w.externalContainers[name] = rc

return nil
@@ -300,24 +306,33 @@ func (w *Worker) HasCapacity(pipeline, modelID string) bool {
}

// Check if we have capacity for external containers.
name := dockerContainerName(pipeline, modelID)
w.mu.Lock()
defer w.mu.Unlock()
_, ok := w.externalContainers[name]
for _, rc := range w.externalContainers {
if rc.Pipeline == pipeline && rc.ModelID == modelID {
if rc.Capacity > 0 {
return true
}
}
}

return ok
// No managed or external containers have capacity.
return false
}

func (w *Worker) borrowContainer(ctx context.Context, pipeline, modelID string) (*RunnerContainer, error) {
w.mu.Lock()

name := dockerContainerName(pipeline, modelID)
rc, ok := w.externalContainers[name]
if ok {
w.mu.Unlock()
// We allow concurrent in-flight requests for external containers and assume that it knows
// how to handle them
return rc, nil
for key, rc := range w.externalContainers {
if rc.Pipeline == pipeline && rc.ModelID == modelID {
// The current ai-runner containers do not have a request queue, so send only one request at a time to each container.
if rc.Capacity > 0 {
slog.Info("selecting container to run request", slog.Int("type", int(rc.Type)), slog.Int("capacity", rc.Capacity), slog.String("url", rc.Endpoint.URL))
w.externalContainers[key].Capacity -= 1
w.mu.Unlock()
return rc, nil
}
}
}

w.mu.Unlock()
Expand All @@ -326,10 +341,19 @@ func (w *Worker) borrowContainer(ctx context.Context, pipeline, modelID string)
}

func (w *Worker) returnContainer(rc *RunnerContainer) {
slog.Info("returning container to be available", slog.Int("type", int(rc.Type)), slog.Int("capacity", rc.Capacity), slog.String("url", rc.Endpoint.URL))

switch rc.Type {
case Managed:
w.manager.Return(rc)
case External:
// Noop because we allow concurrent in-flight requests for external containers
w.mu.Lock()
defer w.mu.Unlock()
// Free the external container for the next request.
for key := range w.externalContainers {
if w.externalContainers[key].Endpoint.URL == rc.Endpoint.URL {
w.externalContainers[key].Capacity += 1
}
}
}
}
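
Capacity now acts as a one-slot semaphore for external containers: borrowContainer decrements it under the lock, and returnContainer increments it again, matching containers by endpoint URL. A request-path sketch (the model ID is illustrative):

rc, err := w.borrowContainer(ctx, "audio-to-text", "openai/whisper-large-v3")
if err != nil {
	return nil, err // no external slot was free and no managed container could be borrowed
}
defer w.returnContainer(rc) // external: Capacity++; managed: returned to the pool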