From 4be29bce3f95859d60f56109b08efb564293ea3b Mon Sep 17 00:00:00 2001 From: "Shiva Krishna, Merla" Date: Thu, 8 Aug 2024 06:11:46 -0700 Subject: [PATCH 1/2] Add support to cache "all" profiles Signed-off-by: Shiva Krishna, Merla --- internal/controller/nimcache_controller.go | 14 +++++--- .../controller/nimcache_controller_test.go | 22 ++++++++++++ internal/utils/utils.go | 10 ++++++ internal/utils/utils_test.go | 34 +++++++++++++++++++ 4 files changed, 76 insertions(+), 4 deletions(-) diff --git a/internal/controller/nimcache_controller.go b/internal/controller/nimcache_controller.go index 7f71e18c..41a1eef1 100644 --- a/internal/controller/nimcache_controller.go +++ b/internal/controller/nimcache_controller.go @@ -32,6 +32,7 @@ import ( "github.com/NVIDIA/k8s-nim-operator/internal/nimparser" "github.com/NVIDIA/k8s-nim-operator/internal/render" "github.com/NVIDIA/k8s-nim-operator/internal/shared" + "github.com/NVIDIA/k8s-nim-operator/internal/utils" "github.com/go-logr/logr" "gopkg.in/yaml.v2" batchv1 "k8s.io/api/batch/v1" @@ -463,7 +464,7 @@ func (r *NIMCacheReconciler) reconcileJobStatus(ctx context.Context, nimCache *a return fmt.Errorf("failed to get selected profiles: %w", err) } - if len(selectedProfiles) > 0 { + if len(selectedProfiles) > 0 && !utils.ContainsElement(selectedProfiles, "all") { nimManifest, err := r.extractNIMManifest(ctx, getManifestConfigName(nimCache), nimCache.GetNamespace()) if err != nil { return fmt.Errorf("failed to get model manifest config file: %w", err) @@ -819,9 +820,14 @@ func constructJob(nimCache *appsv1alpha1.NIMCache) (*batchv1.Job, error) { if err != nil { return nil, err } - if selectedProfiles != nil { - job.Spec.Template.Spec.Containers[0].Args = []string{"--profiles"} - job.Spec.Template.Spec.Containers[0].Args = append(job.Spec.Template.Spec.Containers[0].Args, selectedProfiles...) + + if len(selectedProfiles) > 0 { + if utils.ContainsElement(selectedProfiles, "all") { + job.Spec.Template.Spec.Containers[0].Args = []string{"--all"} + } else { + job.Spec.Template.Spec.Containers[0].Args = []string{"--profiles"} + job.Spec.Template.Spec.Containers[0].Args = append(job.Spec.Template.Spec.Containers[0].Args, selectedProfiles...) + } } } return job, nil diff --git a/internal/controller/nimcache_controller_test.go b/internal/controller/nimcache_controller_test.go index 5a5b9367..46d32cba 100644 --- a/internal/controller/nimcache_controller_test.go +++ b/internal/controller/nimcache_controller_test.go @@ -305,6 +305,28 @@ var _ = Describe("NIMCache Controller", func() { Expect(job.Spec.Template.Spec.Volumes[0].VolumeSource.PersistentVolumeClaim.ClaimName).To(Equal(getPvcName(nimCache, nimCache.Spec.Storage.PVC))) }) + It("should construct a job set to download all profiles", func() { + profiles := []string{"all"} + nimCache := &appsv1alpha1.NIMCache{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-nimcache", + Namespace: "default", + }, + Spec: appsv1alpha1.NIMCacheSpec{ + Source: appsv1alpha1.NIMSource{NGC: &appsv1alpha1.NGCSource{ModelPuller: "nvcr.io/nim:test", PullSecret: "my-secret", Model: appsv1alpha1.ModelSpec{Profiles: profiles, AutoDetect: ptr.To[bool](false)}}}, + }, + } + + job, err := constructJob(nimCache) + Expect(err).ToNot(HaveOccurred()) + + Expect(job.Name).To(Equal(getJobName(nimCache))) + Expect(job.Spec.Template.Spec.Containers[0].Image).To(Equal("nvcr.io/nim:test")) + Expect(job.Spec.Template.Spec.ImagePullSecrets[0].Name).To(Equal("my-secret")) + Expect(job.Spec.Template.Spec.Containers[0].Command).To(ContainElements("download-to-cache")) + Expect(job.Spec.Template.Spec.Containers[0].Args).To(ContainElements("--all")) + }) + It("should create a job with the correct specifications", func() { ctx := context.TODO() nimCache := &appsv1alpha1.NIMCache{ diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 5b0456fe..e0b18b37 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -217,3 +217,13 @@ func IsSpecChanged(current client.Object, desired client.Object) bool { return false } + +// ContainsElement checks if an element exists in a slice +func ContainsElement[T comparable](slice []T, element T) bool { + for _, value := range slice { + if value == element { + return true + } + } + return false +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index 5e578e3c..603cae1f 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -448,3 +448,37 @@ func TestIsSpecChanged(t *testing.T) { }) } } + +// TestContainsElement tests the ContainsElement function +func TestContainsElement(t *testing.T) { + // Test cases + tests := []struct { + name string + slice interface{} + element interface{} + expected bool + }{ + {"IntExists", []int{1, 2, 3, 4, 5}, 3, true}, + {"IntDoesNotExist", []int{1, 2, 3, 4, 5}, 6, false}, + {"StringExists", []string{"llama", "mistral", "gemini"}, "llama", true}, + {"StringDoesNotExist", []string{"llama", "mistral", "gemini"}, "arctic", false}, + {"EmptySlice", []int{}, 1, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + switch slice := tt.slice.(type) { + case []int: + result := ContainsElement(slice, tt.element.(int)) + if result != tt.expected { + t.Errorf("Contains(%v, %v) = %v; expected %v", slice, tt.element, result, tt.expected) + } + case []string: + result := ContainsElement(slice, tt.element.(string)) + if result != tt.expected { + t.Errorf("Contains(%v, %v) = %v; expected %v", slice, tt.element, result, tt.expected) + } + } + }) + } +} From 4e5a8d639d5d2f12258a49b554588c22fc47c4e6 Mon Sep 17 00:00:00 2001 From: "Shiva Krishna, Merla" Date: Thu, 15 Aug 2024 07:45:32 -0700 Subject: [PATCH 2/2] Add const for "all" profiles and update tests Signed-off-by: Shiva Krishna, Merla --- internal/controller/nimcache_controller.go | 7 +++- .../controller/nimcache_controller_test.go | 39 +++++++++++++++++-- .../controller/testdata/manifest_trtllm.yaml | 14 +++++++ 3 files changed, 54 insertions(+), 6 deletions(-) diff --git a/internal/controller/nimcache_controller.go b/internal/controller/nimcache_controller.go index 41a1eef1..f0917951 100644 --- a/internal/controller/nimcache_controller.go +++ b/internal/controller/nimcache_controller.go @@ -57,6 +57,9 @@ const ( // NIMCacheFinalizer is the finalizer annotation NIMCacheFinalizer = "finalizer.nimcache.apps.nvidia.com" + + // AllProfiles represents all profiles in the NIM manifest + AllProfiles = "all" ) // NIMCacheReconciler reconciles a NIMCache object @@ -464,7 +467,7 @@ func (r *NIMCacheReconciler) reconcileJobStatus(ctx context.Context, nimCache *a return fmt.Errorf("failed to get selected profiles: %w", err) } - if len(selectedProfiles) > 0 && !utils.ContainsElement(selectedProfiles, "all") { + if len(selectedProfiles) > 0 && !utils.ContainsElement(selectedProfiles, AllProfiles) { nimManifest, err := r.extractNIMManifest(ctx, getManifestConfigName(nimCache), nimCache.GetNamespace()) if err != nil { return fmt.Errorf("failed to get model manifest config file: %w", err) @@ -822,7 +825,7 @@ func constructJob(nimCache *appsv1alpha1.NIMCache) (*batchv1.Job, error) { } if len(selectedProfiles) > 0 { - if utils.ContainsElement(selectedProfiles, "all") { + if utils.ContainsElement(selectedProfiles, AllProfiles) { job.Spec.Template.Spec.Containers[0].Args = []string{"--all"} } else { job.Spec.Template.Spec.Containers[0].Args = []string{"--profiles"} diff --git a/internal/controller/nimcache_controller_test.go b/internal/controller/nimcache_controller_test.go index 46d32cba..217b3100 100644 --- a/internal/controller/nimcache_controller_test.go +++ b/internal/controller/nimcache_controller_test.go @@ -274,7 +274,7 @@ var _ = Describe("NIMCache Controller", func() { }, time.Second*10).Should(Succeed()) }) - It("should construct a job with right specifications", func() { + It("should construct a job with a specific profile", func() { profiles := []string{"36fc1fa4fc35c1d54da115a39323080b08d7937dceb8ba47be44f4da0ec720ff"} profilesJSON, err := json.Marshal(profiles) Expect(err).ToNot(HaveOccurred()) @@ -305,8 +305,39 @@ var _ = Describe("NIMCache Controller", func() { Expect(job.Spec.Template.Spec.Volumes[0].VolumeSource.PersistentVolumeClaim.ClaimName).To(Equal(getPvcName(nimCache, nimCache.Spec.Storage.PVC))) }) + It("should construct a job with multiple profiles", func() { + profiles := []string{"36fc1fa4fc35c1d54da115a39323080b08d7937dceb8ba47be44f4da0ec720ff", "04fdb4d11f01be10c31b00e7c0540e2835e89a0079b483ad2dd3c25c8cc12345"} + profilesJSON, err := json.Marshal(profiles) + Expect(err).ToNot(HaveOccurred()) + + nimCache := &appsv1alpha1.NIMCache{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-nimcache", + Namespace: "default", + Annotations: map[string]string{SelectedNIMProfilesAnnotationKey: string(profilesJSON)}, + }, + Spec: appsv1alpha1.NIMCacheSpec{ + Source: appsv1alpha1.NIMSource{NGC: &appsv1alpha1.NGCSource{ModelPuller: "nvcr.io/nim:test", PullSecret: "my-secret", Model: appsv1alpha1.ModelSpec{AutoDetect: ptr.To[bool](true)}}}, + }, + } + + job, err := constructJob(nimCache) + Expect(err).ToNot(HaveOccurred()) + + Expect(job.Name).To(Equal(getJobName(nimCache))) + Expect(job.Spec.Template.Spec.Containers[0].Image).To(Equal("nvcr.io/nim:test")) + Expect(job.Spec.Template.Spec.ImagePullSecrets[0].Name).To(Equal("my-secret")) + Expect(job.Spec.Template.Spec.Containers[0].Command).To(ContainElements("download-to-cache")) + Expect(job.Spec.Template.Spec.Containers[0].Args).To(ContainElements("--profiles", "36fc1fa4fc35c1d54da115a39323080b08d7937dceb8ba47be44f4da0ec720ff", "04fdb4d11f01be10c31b00e7c0540e2835e89a0079b483ad2dd3c25c8cc12345")) + Expect(*job.Spec.Template.Spec.SecurityContext.RunAsUser).To(Equal(int64(1000))) + Expect(*job.Spec.Template.Spec.SecurityContext.FSGroup).To(Equal(int64(2000))) + Expect(*job.Spec.Template.Spec.SecurityContext.RunAsNonRoot).To(Equal(true)) + Expect(job.Spec.Template.Spec.Volumes[0].Name).To(Equal("nim-cache-volume")) + Expect(job.Spec.Template.Spec.Volumes[0].VolumeSource.PersistentVolumeClaim.ClaimName).To(Equal(getPvcName(nimCache, nimCache.Spec.Storage.PVC))) + }) + It("should construct a job set to download all profiles", func() { - profiles := []string{"all"} + profiles := []string{AllProfiles} nimCache := &appsv1alpha1.NIMCache{ ObjectMeta: metav1.ObjectMeta{ Name: "test-nimcache", @@ -367,7 +398,7 @@ var _ = Describe("NIMCache Controller", func() { filePath := filepath.Join("testdata", "manifest_trtllm.yaml") manifestData, err := nimparser.ParseModelManifest(filePath) Expect(err).NotTo(HaveOccurred()) - Expect(*manifestData).To(HaveLen(1)) + Expect(*manifestData).To(HaveLen(2)) err = reconciler.createManifestConfigMap(ctx, nimCache, manifestData) Expect(err).NotTo(HaveOccurred()) @@ -382,7 +413,7 @@ var _ = Describe("NIMCache Controller", func() { extractedManifest, err := reconciler.extractNIMManifest(ctx, createdConfigMap.Name, createdConfigMap.Namespace) Expect(err).NotTo(HaveOccurred()) Expect(extractedManifest).NotTo(BeNil()) - Expect(*extractedManifest).To(HaveLen(1)) + Expect(*extractedManifest).To(HaveLen(2)) profile, exists := (*extractedManifest)["03fdb4d11f01be10c31b00e7c0540e2835e89a0079b483ad2dd3c25c8cc29b61"] Expect(exists).To(BeTrue()) Expect(profile.Model).To(Equal("meta/llama3-70b-instruct")) diff --git a/internal/controller/testdata/manifest_trtllm.yaml b/internal/controller/testdata/manifest_trtllm.yaml index ac9aa3c8..06b77fdb 100644 --- a/internal/controller/testdata/manifest_trtllm.yaml +++ b/internal/controller/testdata/manifest_trtllm.yaml @@ -12,3 +12,17 @@ profile: throughput tp: '8' container_url: nvcr.io/nim/meta/llama3-70b-instruct:1.0.0 +04fdb4d11f01be10c31b00e7c0540e2835e89a0079b483ad2dd3c25c8cc12345: + model: meta/llama3-70b-instruct + release: '1.0.0' + tags: + feat_lora: 'false' + feat_lora_max_rank: '32' + gpu: A100 + gpu_device: 26b5:10de + llm_engine: tensorrt_llm + pp: '1' + precision: fp16 + profile: throughput + tp: '8' + container_url: nvcr.io/nim/meta/llama3-70b-instruct:1.0.0