Skip to content

Commit 7f61c50

Browse files
tenzen-y authored and saileshd1402 committed
KEP-2170: Implement TrainJob Reconciler to manage objects (kubeflow#2295)
* KEP-2170: Implement TrainJob Reconciler to manage objects Signed-off-by: Yuki Iwai <[email protected]> * Move dep-crds to manifests/external-crds Signed-off-by: Yuki Iwai <[email protected]> * Rename run with runtime Signed-off-by: Yuki Iwai <[email protected]> --------- Signed-off-by: Yuki Iwai <[email protected]> Signed-off-by: sailesh duddupudi <[email protected]>
1 parent 7793706 commit 7f61c50

File tree

13 files changed

+410
-75
lines changed

13 files changed

+410
-75
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ bin/
44
/tf-operator
55
vendor/
66
testbin/*
7+
manifests/external-crds/
78
cover.out
89

910
# IDEs

Makefile

+19-3
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,16 @@ HAS_SETUP_ENVTEST := $(shell command -v setup-envtest;)
7676
testall: manifests generate fmt vet golangci-lint test ## Run tests.
7777

7878
test: envtest
79-
KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" go test ./... -coverprofile cover.out
79+
KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" \
80+
go test ./pkg/apis/kubeflow.org/v1/... ./pkg/cert/... ./pkg/common/... ./pkg/config/... ./pkg/controller.v1/... ./pkg/core/... ./pkg/util/... ./pkg/webhooks/... -coverprofile cover.out
8081

8182
.PHONY: test-integrationv2
82-
test-integrationv2: envtest
83+
test-integrationv2: envtest jobset-operator-crd scheduler-plugins-crd
8384
KUBEBUILDER_ASSETS="$(shell setup-envtest use $(ENVTEST_K8S_VERSION) -p path)" go test ./test/... -coverprofile cover.out
8485

8586
.PHONY: testv2
8687
testv2:
87-
go test ./pkg/controller.v2/... ./pkg/runtime.v2/... ./pkg/webhook.v2/... ./pkg/util.v2/... -coverprofile cover.out
88+
go test ./pkg/apis/kubeflow.org/v2alpha1/... ./pkg/controller.v2/... ./pkg/runtime.v2/... ./pkg/webhook.v2/... ./pkg/util.v2/... -coverprofile cover.out
8889

8990
envtest:
9091
ifndef HAS_SETUP_ENVTEST
@@ -129,3 +130,18 @@ controller-gen: ## Download controller-gen locally if necessary.
129130
KUSTOMIZE = $(shell pwd)/bin/kustomize
130131
kustomize: ## Download kustomize locally if necessary.
131132
GOBIN=$(PROJECT_DIR)/bin go install sigs.k8s.io/kustomize/kustomize/[email protected]
133+
134+
## Download external CRDs for the integration tests.
135+
EXTERNAL_CRDS_DIR ?= $(PROJECT_DIR)/manifests/external-crds
136+
137+
JOBSET_ROOT = $(shell go list -m -mod=readonly -f "{{.Dir}}" sigs.k8s.io/jobset)
138+
.PHONY: jobset-operator-crd
139+
jobset-operator-crd: ## Copy the CRDs from the jobset-operator to the manifests/external-crds directory.
140+
mkdir -p $(EXTERNAL_CRDS_DIR)/jobset-operator/
141+
cp -f $(JOBSET_ROOT)/config/components/crd/bases/* $(EXTERNAL_CRDS_DIR)/jobset-operator/
142+
143+
SCHEDULER_PLUGINS_ROOT = $(shell go list -m -f "{{.Dir}}" sigs.k8s.io/scheduler-plugins)
144+
.PHONY: scheduler-plugins-crd
145+
scheduler-plugins-crd:
146+
mkdir -p $(EXTERNAL_CRDS_DIR)/scheduler-plugins/
147+
cp -f $(SCHEDULER_PLUGINS_ROOT)/manifests/coscheduling/* $(EXTERNAL_CRDS_DIR)/scheduler-plugins

pkg/controller.v2/setup.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ func SetupControllers(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) (st
2626
if err := NewTrainJobReconciler(
2727
mgr.GetClient(),
2828
mgr.GetEventRecorderFor("training-operator-trainjob-controller"),
29-
).SetupWithManager(mgr, runtimes); err != nil {
29+
runtimes,
30+
).SetupWithManager(mgr); err != nil {
3031
return "TrainJob", err
3132
}
3233
return "", nil

pkg/controller.v2/trainjob_controller.go

+73-6
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,37 @@ package controllerv2
1818

1919
import (
2020
"context"
21+
"errors"
22+
"fmt"
2123

2224
"github.com/go-logr/logr"
25+
"k8s.io/apimachinery/pkg/runtime/schema"
2326
"k8s.io/client-go/tools/record"
2427
"k8s.io/klog/v2"
28+
"k8s.io/utils/ptr"
2529
ctrl "sigs.k8s.io/controller-runtime"
2630
"sigs.k8s.io/controller-runtime/pkg/client"
31+
"sigs.k8s.io/controller-runtime/pkg/client/apiutil"
2732

2833
kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
29-
runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
34+
jobruntimes "github.com/kubeflow/training-operator/pkg/runtime.v2"
3035
)
3136

37+
var errorUnsupportedRuntime = errors.New("the specified runtime is not supported")
38+
3239
type TrainJobReconciler struct {
3340
log logr.Logger
3441
client client.Client
3542
recorder record.EventRecorder
43+
runtimes map[string]jobruntimes.Runtime
3644
}
3745

38-
func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder) *TrainJobReconciler {
46+
func NewTrainJobReconciler(client client.Client, recorder record.EventRecorder, runtimes map[string]jobruntimes.Runtime) *TrainJobReconciler {
3947
return &TrainJobReconciler{
4048
log: ctrl.Log.WithName("trainjob-controller"),
4149
client: client,
4250
recorder: recorder,
51+
runtimes: runtimes,
4352
}
4453
}
4554

@@ -52,16 +61,74 @@ func (r *TrainJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (c
5261
return ctrl.Result{}, client.IgnoreNotFound(err)
5362
}
5463
log := ctrl.LoggerFrom(ctx).WithValues("trainJob", klog.KObj(&trainJob))
55-
ctrl.LoggerInto(ctx, log)
64+
ctx = ctrl.LoggerInto(ctx, log)
5665
log.V(2).Info("Reconciling TrainJob")
66+
if err := r.createOrUpdateObjs(ctx, &trainJob); err != nil {
67+
return ctrl.Result{}, err
68+
}
69+
// TODO (tenzen-y): Update the status.
5770
return ctrl.Result{}, nil
5871
}
5972

60-
func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager, runtimes map[string]runtime.Runtime) error {
73+
func (r *TrainJobReconciler) createOrUpdateObjs(ctx context.Context, trainJob *kubeflowv2.TrainJob) error {
74+
log := ctrl.LoggerFrom(ctx)
75+
76+
runtimeRefGK := runtimeRefToGroupKind(trainJob.Spec.RuntimeRef).String()
77+
runtime, ok := r.runtimes[runtimeRefGK]
78+
if !ok {
79+
return fmt.Errorf("%w: %s", errorUnsupportedRuntime, runtimeRefGK)
80+
}
81+
objs, err := runtime.NewObjects(ctx, trainJob)
82+
if err != nil {
83+
return err
84+
}
85+
for _, obj := range objs {
86+
var gvk schema.GroupVersionKind
87+
if gvk, err = apiutil.GVKForObject(obj.DeepCopyObject(), r.client.Scheme()); err != nil {
88+
return err
89+
}
90+
logKeysAndValues := []any{
91+
"groupVersionKind", gvk.String(),
92+
"namespace", obj.GetNamespace(),
93+
"name", obj.GetName(),
94+
}
95+
// TODO (tenzen-y): Ideally, we should use the SSA instead of checking existence.
96+
// Non-empty resourceVersion indicates UPDATE operation.
97+
var creationErr error
98+
var created bool
99+
if obj.GetResourceVersion() == "" {
100+
creationErr = r.client.Create(ctx, obj)
101+
created = creationErr == nil
102+
}
103+
switch {
104+
case created:
105+
log.V(5).Info("Succeeded to create object", logKeysAndValues)
106+
continue
107+
case client.IgnoreAlreadyExists(creationErr) != nil:
108+
return creationErr
109+
default:
110+
// This indicates CREATE operation has not been performed or the object has already existed in the cluster.
111+
if err = r.client.Update(ctx, obj); err != nil {
112+
return err
113+
}
114+
log.V(5).Info("Succeeded to update object", logKeysAndValues)
115+
}
116+
}
117+
return nil
118+
}
119+
120+
func runtimeRefToGroupKind(runtimeRef kubeflowv2.RuntimeRef) schema.GroupKind {
121+
return schema.GroupKind{
122+
Group: ptr.Deref(runtimeRef.APIGroup, ""),
123+
Kind: ptr.Deref(runtimeRef.Kind, ""),
124+
}
125+
}
126+
127+
func (r *TrainJobReconciler) SetupWithManager(mgr ctrl.Manager) error {
61128
b := ctrl.NewControllerManagedBy(mgr).
62129
For(&kubeflowv2.TrainJob{})
63-
for _, run := range runtimes {
64-
for _, registrar := range run.EventHandlerRegistrars() {
130+
for _, runtime := range r.runtimes {
131+
for _, registrar := range runtime.EventHandlerRegistrars() {
65132
if registrar != nil {
66133
b = registrar(b, mgr.GetClient())
67134
}

pkg/runtime.v2/core/clustertrainingruntime_test.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
4646
}{
4747
"succeeded to build JobSet and PodGroup": {
4848
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
49+
Suspend(true).
4950
UID("uid").
5051
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
5152
Trainer(
@@ -57,7 +58,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
5758
clusterTrainingRuntime: baseRuntime.RuntimeSpec(
5859
testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec).
5960
ContainerImage("test:runtime").
60-
PodGroupPolicySchedulingTimeout(120).
61+
PodGroupPolicyCoschedulingSchedulingTimeout(120).
6162
MLPolicyNumNodes(20).
6263
ResourceRequests(0, corev1.ResourceList{
6364
corev1.ResourceCPU: resource.MustParse("1"),
@@ -69,6 +70,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
6970
).Obj(),
7071
wantObjs: []client.Object{
7172
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
73+
Suspend(true).
7274
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
7375
ContainerImage(ptr.To("test:trainjob")).
7476
JobCompletionMode(batchv1.IndexedCompletion).

pkg/runtime.v2/core/trainingruntime_test.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
4646
}{
4747
"succeeded to build JobSet and PodGroup": {
4848
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
49+
Suspend(true).
4950
UID("uid").
5051
RuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime").
5152
SpecLabel("conflictLabel", "override").
@@ -62,7 +63,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
6263
RuntimeSpec(
6364
testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec).
6465
ContainerImage("test:runtime").
65-
PodGroupPolicySchedulingTimeout(120).
66+
PodGroupPolicyCoschedulingSchedulingTimeout(120).
6667
MLPolicyNumNodes(20).
6768
ResourceRequests(0, corev1.ResourceList{
6869
corev1.ResourceCPU: resource.MustParse("1"),
@@ -74,6 +75,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
7475
).Obj(),
7576
wantObjs: []client.Object{
7677
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
78+
Suspend(true).
7779
Label("conflictLabel", "override").
7880
Annotation("conflictAnnotation", "override").
7981
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").

pkg/runtime.v2/framework/core/framework_test.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -334,13 +334,12 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
334334
ResourceRequests(1, corev1.ResourceList{
335335
corev1.ResourceCPU: resource.MustParse("1"),
336336
corev1.ResourceMemory: resource.MustParse("2Gi"),
337-
}).
338-
Clone()
337+
})
339338
jobSetWithPropagatedTrainJobParams := jobSetBase.
339+
Clone().
340340
JobCompletionMode(batchv1.IndexedCompletion).
341341
ContainerImage(ptr.To("foo:bar")).
342-
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
343-
Clone()
342+
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid")
344343

345344
cases := map[string]struct {
346345
runtimeInfo *runtime.Info
@@ -361,6 +360,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
361360
Obj(),
362361
runtimeInfo: &runtime.Info{
363362
Obj: jobSetBase.
363+
Clone().
364364
Obj(),
365365
Policy: runtime.Policy{
366366
MLPolicy: &kubeflowv2.MLPolicy{
@@ -403,10 +403,12 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
403403
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
404404
Obj(),
405405
jobSetWithPropagatedTrainJobParams.
406+
Clone().
406407
Obj(),
407408
},
408409
wantRuntimeInfo: &runtime.Info{
409410
Obj: jobSetWithPropagatedTrainJobParams.
411+
Clone().
410412
Obj(),
411413
Policy: runtime.Policy{
412414
MLPolicy: &kubeflowv2.MLPolicy{

pkg/runtime.v2/framework/plugins/jobset/builder.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ import (
2828
)
2929

3030
type Builder struct {
31-
*jobsetv1alpha2.JobSet
31+
jobsetv1alpha2.JobSet
3232
}
3333

3434
func NewBuilder(objectKey client.ObjectKey, jobSetTemplateSpec kubeflowv2.JobSetTemplateSpec) *Builder {
3535
return &Builder{
36-
JobSet: &jobsetv1alpha2.JobSet{
36+
JobSet: jobsetv1alpha2.JobSet{
3737
TypeMeta: metav1.TypeMeta{
3838
APIVersion: jobsetv1alpha2.SchemeGroupVersion.String(),
3939
Kind: "JobSet",
@@ -76,8 +76,13 @@ func (b *Builder) PodLabels(labels map[string]string) *Builder {
7676
return b
7777
}
7878

79+
// Suspend stores the given pointer in the JobSet's Spec.Suspend field and
// returns the same builder so that calls can be chained.
func (b *Builder) Suspend(suspend *bool) *Builder {
80+
b.Spec.Suspend = suspend
81+
return b
82+
}
83+
7984
// TODO: Need to support all TrainJob fields.
8085

8186
func (b *Builder) Build() *jobsetv1alpha2.JobSet {
82-
return b.JobSet
87+
return &b.JobSet
8388
}

pkg/runtime.v2/framework/plugins/jobset/jobset.go

+29-16
Original file line numberDiff line numberDiff line change
@@ -76,29 +76,37 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *kubefl
7676
if !ok {
7777
return nil, nil
7878
}
79-
jobSetBuilder := NewBuilder(client.ObjectKeyFromObject(trainJob), kubeflowv2.JobSetTemplateSpec{
80-
ObjectMeta: metav1.ObjectMeta{
81-
Labels: info.Labels,
82-
Annotations: info.Annotations,
83-
},
84-
Spec: raw.Spec,
85-
})
79+
80+
var jobSetBuilder *Builder
81+
oldJobSet := &jobsetv1alpha2.JobSet{}
82+
if err := j.client.Get(ctx, client.ObjectKeyFromObject(trainJob), oldJobSet); err != nil {
83+
if !apierrors.IsNotFound(err) {
84+
return nil, err
85+
}
86+
jobSetBuilder = NewBuilder(client.ObjectKeyFromObject(trainJob), kubeflowv2.JobSetTemplateSpec{
87+
ObjectMeta: metav1.ObjectMeta{
88+
Labels: info.Labels,
89+
Annotations: info.Annotations,
90+
},
91+
Spec: raw.Spec,
92+
})
93+
oldJobSet = nil
94+
} else {
95+
jobSetBuilder = &Builder{
96+
JobSet: *oldJobSet.DeepCopy(),
97+
}
98+
}
99+
86100
// TODO (tenzen-y): We should support all field propagation in builder.
87101
jobSet := jobSetBuilder.
102+
Suspend(trainJob.Spec.Suspend).
88103
ContainerImage(trainJob.Spec.Trainer.Image).
89104
JobCompletionMode(batchv1.IndexedCompletion).
90105
PodLabels(info.PodLabels).
91106
Build()
92107
if err := ctrlutil.SetControllerReference(trainJob, jobSet, j.scheme); err != nil {
93108
return nil, err
94109
}
95-
oldJobSet := &jobsetv1alpha2.JobSet{}
96-
if err := j.client.Get(ctx, client.ObjectKeyFromObject(jobSet), oldJobSet); err != nil {
97-
if !apierrors.IsNotFound(err) {
98-
return nil, err
99-
}
100-
oldJobSet = nil
101-
}
102110
if err := info.Update(jobSet); err != nil {
103111
return nil, err
104112
}
@@ -108,9 +116,14 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *kubefl
108116
return nil, nil
109117
}
110118

111-
func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, suspended bool) bool {
119+
func needsCreateOrUpdate(old, new *jobsetv1alpha2.JobSet, trainJobIsSuspended bool) bool {
112120
return old == nil ||
113-
suspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations))
121+
(!trainJobIsSuspended && jobSetIsSuspended(old) && !jobSetIsSuspended(new)) ||
122+
(trainJobIsSuspended && (!equality.Semantic.DeepEqual(old.Spec, new.Spec) || !maps.Equal(old.Labels, new.Labels) || !maps.Equal(old.Annotations, new.Annotations)))
123+
}
124+
125+
func jobSetIsSuspended(jobSet *jobsetv1alpha2.JobSet) bool {
126+
return ptr.Deref(jobSet.Spec.Suspend, false)
114127
}
115128

116129
func (j *JobSet) ReconcilerBuilders() []runtime.ReconcilerBuilder {

0 commit comments

Comments
 (0)