Skip to content

Commit

Permalink
Merge pull request #41 from shivamerla/cache_all_profiles
Browse files Browse the repository at this point in the history
Add support to cache "all" profiles
  • Loading branch information
ArangoGutierrez authored Aug 16, 2024
2 parents 0672d66 + 4e5a8d6 commit 6c7b993
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 7 deletions.
17 changes: 13 additions & 4 deletions internal/controller/nimcache_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/NVIDIA/k8s-nim-operator/internal/nimparser"
"github.com/NVIDIA/k8s-nim-operator/internal/render"
"github.com/NVIDIA/k8s-nim-operator/internal/shared"
"github.com/NVIDIA/k8s-nim-operator/internal/utils"
"github.com/go-logr/logr"
"gopkg.in/yaml.v2"
batchv1 "k8s.io/api/batch/v1"
Expand All @@ -56,6 +57,9 @@ const (

// NIMCacheFinalizer is the finalizer annotation
NIMCacheFinalizer = "finalizer.nimcache.apps.nvidia.com"

// AllProfiles represents all profiles in the NIM manifest
AllProfiles = "all"
)

// NIMCacheReconciler reconciles a NIMCache object
Expand Down Expand Up @@ -463,7 +467,7 @@ func (r *NIMCacheReconciler) reconcileJobStatus(ctx context.Context, nimCache *a
return fmt.Errorf("failed to get selected profiles: %w", err)
}

if len(selectedProfiles) > 0 {
if len(selectedProfiles) > 0 && !utils.ContainsElement(selectedProfiles, AllProfiles) {
nimManifest, err := r.extractNIMManifest(ctx, getManifestConfigName(nimCache), nimCache.GetNamespace())
if err != nil {
return fmt.Errorf("failed to get model manifest config file: %w", err)
Expand Down Expand Up @@ -819,9 +823,14 @@ func constructJob(nimCache *appsv1alpha1.NIMCache) (*batchv1.Job, error) {
if err != nil {
return nil, err
}
if selectedProfiles != nil {
job.Spec.Template.Spec.Containers[0].Args = []string{"--profiles"}
job.Spec.Template.Spec.Containers[0].Args = append(job.Spec.Template.Spec.Containers[0].Args, selectedProfiles...)

if len(selectedProfiles) > 0 {
if utils.ContainsElement(selectedProfiles, AllProfiles) {
job.Spec.Template.Spec.Containers[0].Args = []string{"--all"}
} else {
job.Spec.Template.Spec.Containers[0].Args = []string{"--profiles"}
job.Spec.Template.Spec.Containers[0].Args = append(job.Spec.Template.Spec.Containers[0].Args, selectedProfiles...)
}
}
}
return job, nil
Expand Down
59 changes: 56 additions & 3 deletions internal/controller/nimcache_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ var _ = Describe("NIMCache Controller", func() {
}, time.Second*10).Should(Succeed())
})

It("should construct a job with right specifications", func() {
It("should construct a job with a specific profile", func() {
profiles := []string{"36fc1fa4fc35c1d54da115a39323080b08d7937dceb8ba47be44f4da0ec720ff"}
profilesJSON, err := json.Marshal(profiles)
Expect(err).ToNot(HaveOccurred())
Expand Down Expand Up @@ -305,6 +305,59 @@ var _ = Describe("NIMCache Controller", func() {
Expect(job.Spec.Template.Spec.Volumes[0].VolumeSource.PersistentVolumeClaim.ClaimName).To(Equal(getPvcName(nimCache, nimCache.Spec.Storage.PVC)))
})

It("should construct a job with multiple profiles", func() {
profiles := []string{"36fc1fa4fc35c1d54da115a39323080b08d7937dceb8ba47be44f4da0ec720ff", "04fdb4d11f01be10c31b00e7c0540e2835e89a0079b483ad2dd3c25c8cc12345"}
profilesJSON, err := json.Marshal(profiles)
Expect(err).ToNot(HaveOccurred())

nimCache := &appsv1alpha1.NIMCache{
ObjectMeta: metav1.ObjectMeta{
Name: "test-nimcache",
Namespace: "default",
Annotations: map[string]string{SelectedNIMProfilesAnnotationKey: string(profilesJSON)},
},
Spec: appsv1alpha1.NIMCacheSpec{
Source: appsv1alpha1.NIMSource{NGC: &appsv1alpha1.NGCSource{ModelPuller: "nvcr.io/nim:test", PullSecret: "my-secret", Model: appsv1alpha1.ModelSpec{AutoDetect: ptr.To[bool](true)}}},
},
}

job, err := constructJob(nimCache)
Expect(err).ToNot(HaveOccurred())

Expect(job.Name).To(Equal(getJobName(nimCache)))
Expect(job.Spec.Template.Spec.Containers[0].Image).To(Equal("nvcr.io/nim:test"))
Expect(job.Spec.Template.Spec.ImagePullSecrets[0].Name).To(Equal("my-secret"))
Expect(job.Spec.Template.Spec.Containers[0].Command).To(ContainElements("download-to-cache"))
Expect(job.Spec.Template.Spec.Containers[0].Args).To(ContainElements("--profiles", "36fc1fa4fc35c1d54da115a39323080b08d7937dceb8ba47be44f4da0ec720ff", "04fdb4d11f01be10c31b00e7c0540e2835e89a0079b483ad2dd3c25c8cc12345"))
Expect(*job.Spec.Template.Spec.SecurityContext.RunAsUser).To(Equal(int64(1000)))
Expect(*job.Spec.Template.Spec.SecurityContext.FSGroup).To(Equal(int64(2000)))
Expect(*job.Spec.Template.Spec.SecurityContext.RunAsNonRoot).To(Equal(true))
Expect(job.Spec.Template.Spec.Volumes[0].Name).To(Equal("nim-cache-volume"))
Expect(job.Spec.Template.Spec.Volumes[0].VolumeSource.PersistentVolumeClaim.ClaimName).To(Equal(getPvcName(nimCache, nimCache.Spec.Storage.PVC)))
})

It("should construct a job set to download all profiles", func() {
	// The AllProfiles sentinel is requested through the spec, with auto-detection off.
	modelSpec := appsv1alpha1.ModelSpec{
		Profiles:   []string{AllProfiles},
		AutoDetect: ptr.To[bool](false),
	}
	nimCache := &appsv1alpha1.NIMCache{
		ObjectMeta: metav1.ObjectMeta{Name: "test-nimcache", Namespace: "default"},
		Spec: appsv1alpha1.NIMCacheSpec{
			Source: appsv1alpha1.NIMSource{
				NGC: &appsv1alpha1.NGCSource{
					ModelPuller: "nvcr.io/nim:test",
					PullSecret:  "my-secret",
					Model:       modelSpec,
				},
			},
		},
	}

	job, err := constructJob(nimCache)
	Expect(err).ToNot(HaveOccurred())
	Expect(job.Name).To(Equal(getJobName(nimCache)))

	podSpec := job.Spec.Template.Spec
	Expect(podSpec.Containers[0].Image).To(Equal("nvcr.io/nim:test"))
	Expect(podSpec.ImagePullSecrets[0].Name).To(Equal("my-secret"))
	Expect(podSpec.Containers[0].Command).To(ContainElements("download-to-cache"))
	// The job is expected to carry the --all flag.
	Expect(podSpec.Containers[0].Args).To(ContainElements("--all"))
})

It("should create a job with the correct specifications", func() {
ctx := context.TODO()
nimCache := &appsv1alpha1.NIMCache{
Expand Down Expand Up @@ -345,7 +398,7 @@ var _ = Describe("NIMCache Controller", func() {
filePath := filepath.Join("testdata", "manifest_trtllm.yaml")
manifestData, err := nimparser.ParseModelManifest(filePath)
Expect(err).NotTo(HaveOccurred())
Expect(*manifestData).To(HaveLen(1))
Expect(*manifestData).To(HaveLen(2))

err = reconciler.createManifestConfigMap(ctx, nimCache, manifestData)
Expect(err).NotTo(HaveOccurred())
Expand All @@ -360,7 +413,7 @@ var _ = Describe("NIMCache Controller", func() {
extractedManifest, err := reconciler.extractNIMManifest(ctx, createdConfigMap.Name, createdConfigMap.Namespace)
Expect(err).NotTo(HaveOccurred())
Expect(extractedManifest).NotTo(BeNil())
Expect(*extractedManifest).To(HaveLen(1))
Expect(*extractedManifest).To(HaveLen(2))
profile, exists := (*extractedManifest)["03fdb4d11f01be10c31b00e7c0540e2835e89a0079b483ad2dd3c25c8cc29b61"]
Expect(exists).To(BeTrue())
Expect(profile.Model).To(Equal("meta/llama3-70b-instruct"))
Expand Down
14 changes: 14 additions & 0 deletions internal/controller/testdata/manifest_trtllm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,17 @@
profile: throughput
tp: '8'
container_url: nvcr.io/nim/meta/llama3-70b-instruct:1.0.0
04fdb4d11f01be10c31b00e7c0540e2835e89a0079b483ad2dd3c25c8cc12345:
model: meta/llama3-70b-instruct
release: '1.0.0'
tags:
feat_lora: 'false'
feat_lora_max_rank: '32'
gpu: A100
gpu_device: 26b5:10de
llm_engine: tensorrt_llm
pp: '1'
precision: fp16
profile: throughput
tp: '8'
container_url: nvcr.io/nim/meta/llama3-70b-instruct:1.0.0
10 changes: 10 additions & 0 deletions internal/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,13 @@ func IsSpecChanged(current client.Object, desired client.Object) bool {

return false
}

// ContainsElement reports whether element occurs at least once in slice.
// Values are compared with ==; a nil or empty slice yields false.
func ContainsElement[T comparable](slice []T, element T) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] != element {
			continue
		}
		// First match is enough; no need to scan the remainder.
		return true
	}
	return false
}
34 changes: 34 additions & 0 deletions internal/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,37 @@ func TestIsSpecChanged(t *testing.T) {
})
}
}

// TestContainsElement tests the ContainsElement function with int and string
// slices, covering present, absent, and empty-slice lookups.
func TestContainsElement(t *testing.T) {
	// slice and element are stored untyped so that one table can drive more
	// than one instantiation of the generic function.
	tests := []struct {
		name     string
		slice    any
		element  any
		expected bool
	}{
		{"IntExists", []int{1, 2, 3, 4, 5}, 3, true},
		{"IntDoesNotExist", []int{1, 2, 3, 4, 5}, 6, false},
		{"StringExists", []string{"llama", "mistral", "gemini"}, "llama", true},
		{"StringDoesNotExist", []string{"llama", "mistral", "gemini"}, "arctic", false},
		{"EmptySlice", []int{}, 1, false},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var result bool
			switch slice := tt.slice.(type) {
			case []int:
				element, ok := tt.element.(int)
				if !ok {
					// A mismatched table row should fail cleanly, not panic.
					t.Fatalf("element %v has type %T; expected int", tt.element, tt.element)
				}
				result = ContainsElement(slice, element)
			case []string:
				element, ok := tt.element.(string)
				if !ok {
					t.Fatalf("element %v has type %T; expected string", tt.element, tt.element)
				}
				result = ContainsElement(slice, element)
			default:
				// Without this branch a row with an unhandled slice type
				// would run no assertion and pass silently.
				t.Fatalf("unsupported slice type %T in test table", tt.slice)
			}

			if result != tt.expected {
				t.Errorf("ContainsElement(%v, %v) = %v; expected %v", tt.slice, tt.element, result, tt.expected)
			}
		})
	}
}

0 comments on commit 6c7b993

Please sign in to comment.