diff --git a/flyteidl/protos/flyteidl/plugins/ray.proto b/flyteidl/protos/flyteidl/plugins/ray.proto
index 749444ee04..a9dcb2f148 100644
--- a/flyteidl/protos/flyteidl/plugins/ray.proto
+++ b/flyteidl/protos/flyteidl/plugins/ray.proto
@@ -20,6 +20,15 @@ message RayJob {
     // RuntimeEnvYAML represents the runtime environment configuration
     // provided as a multi-line YAML string.
     string runtime_env_yaml = 5;
+    // address specifies the Ray head address to connect to for an existing cluster.
+    // When set, the RayJob submits to an existing RayCluster instead of creating a new one.
+    // The address is parsed to derive a cluster selector label.
+    // Supported formats:
+    //   - Cluster name: "my-cluster" -> clusterSelector: {"ray.io/cluster": "my-cluster"}
+    //   - Head service DNS: "my-cluster-head-svc" -> clusterSelector: {"ray.io/cluster": "my-cluster"}
+    //   - Ray client URL: "ray://my-cluster-head-svc:10001" -> clusterSelector: {"ray.io/cluster": "my-cluster"}
+    // This enables long-lived clusters for faster iteration during development.
+    string address = 6;
 }
 
 // Define Ray cluster defines the desired state of RayCluster
diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
index defd7c1e85..d19e29f332 100644
--- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
+++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
@@ -124,10 +124,106 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
     return rayjob, err
 }
 
+// parseAddressToClusterName extracts a cluster name from a Ray address string.
+// Supported formats:
+//   - Cluster name: "my-cluster" -> "my-cluster"
+//   - Head service DNS: "my-cluster-head-svc" -> "my-cluster"
+//   - Ray client URL: "ray://my-cluster-head-svc:10001" -> "my-cluster"
+//   - Full k8s DNS: "my-cluster-head-svc.namespace.svc.cluster.local:10001" -> "my-cluster"
+//
+// Author: Devin AI (claude-sonnet-4-20250514)
+func parseAddressToClusterName(address string) (string, error) {
+    if address == "" {
+        return "", fmt.Errorf("address is empty")
+    }
+
+    // Handle URL format (ray://host:port or http://host:port)
+    host := address
+    if strings.Contains(address, "://") {
+        parts := strings.SplitN(address, "://", 2)
+        if len(parts) == 2 {
+            host = parts[1]
+        }
+    }
+
+    // Remove port if present
+    if colonIdx := strings.LastIndex(host, ":"); colonIdx != -1 {
+        // Make sure it's a port (not part of IPv6)
+        if !strings.Contains(host[colonIdx:], "]") {
+            host = host[:colonIdx]
+        }
+    }
+
+    // Remove path if present
+    if slashIdx := strings.Index(host, "/"); slashIdx != -1 {
+        host = host[:slashIdx]
+    }
+
+    // Take the first DNS label (before any dots)
+    if dotIdx := strings.Index(host, "."); dotIdx != -1 {
+        host = host[:dotIdx]
+    }
+
+    // Strip "-head-svc" suffix if present
+    clusterName := strings.TrimSuffix(host, "-head-svc")
+
+    if clusterName == "" {
+        return "", fmt.Errorf("could not parse cluster name from address: %s", address)
+    }
+
+    return clusterName, nil
+}
+
 func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob *plugins.RayJob, objectMeta *metav1.ObjectMeta, taskPodSpec v1.PodSpec, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) (*rayv1.RayJob, error) {
-    enableIngress := true
     cfg := GetConfig()
+    // Handle runtime_env conversion (needed for both modes)
+    var runtimeEnvYaml string
+    var err error
+    runtimeEnvYaml = rayJob.GetRuntimeEnvYaml()
+    // If runtime_env exists but runtime_env_yaml does not, convert runtime_env to runtime_env_yaml
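+    // This conversion is for backward compatibility; remove it once runtime_env is removed from the ray proto.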
+    if rayJob.GetRuntimeEnv() != "" && rayJob.GetRuntimeEnvYaml() == "" {
+        runtimeEnvYaml, err = convertBase64RuntimeEnvToYaml(rayJob.GetRuntimeEnv())
+        if err != nil {
+            return nil, err
+        }
+    }
+
+    submitterPodSpec := taskPodSpec.DeepCopy()
+    submitterPodTemplate := buildSubmitterPodTemplate(submitterPodSpec, objectMeta, taskCtx)
+
+    // Check if address is provided - if so, use existing cluster mode
+    if address := rayJob.GetAddress(); address != "" {
+        clusterName, err := parseAddressToClusterName(address)
+        if err != nil {
+            return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid Ray address for cluster selection: %v", err)
+        }
+
+        // Use ray.io/cluster label to select the existing cluster
+        clusterSelector := map[string]string{
+            "ray.io/cluster": clusterName,
+        }
+
+        jobSpec := rayv1.RayJobSpec{
+            ClusterSelector:      clusterSelector,
+            Entrypoint:           strings.Join(primaryContainer.Args, " "),
+            RuntimeEnvYAML:       runtimeEnvYaml,
+            SubmitterPodTemplate: &submitterPodTemplate,
+        }
+
+        return &rayv1.RayJob{
+            TypeMeta: metav1.TypeMeta{
+                Kind:       KindRayJob,
+                APIVersion: rayv1.SchemeGroupVersion.String(),
+            },
+            Spec:       jobSpec,
+            ObjectMeta: *objectMeta,
+        }, nil
+    }
+
+    // Default mode: create a new RayCluster
+    enableIngress := true
+
     headPodSpec := taskPodSpec.DeepCopy()
     headPodTemplate, err := buildHeadPodTemplate(
         &headPodSpec.Containers[primaryContainerIdx],
@@ -217,20 +313,6 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob *plugins.R
         ttlSecondsAfterFinished = &rayJob.TtlSecondsAfterFinished
     }
 
-    submitterPodSpec := taskPodSpec.DeepCopy()
-    submitterPodTemplate := buildSubmitterPodTemplate(submitterPodSpec, objectMeta, taskCtx)
-
-    // TODO: This is for backward compatibility. Remove this block once runtime_env is removed from ray proto.
-    var runtimeEnvYaml string
-    runtimeEnvYaml = rayJob.GetRuntimeEnvYaml()
-    // If runtime_env exists but runtime_env_yaml does not, convert runtime_env to runtime_env_yaml
-    if rayJob.GetRuntimeEnv() != "" && rayJob.GetRuntimeEnvYaml() == "" {
-        runtimeEnvYaml, err = convertBase64RuntimeEnvToYaml(rayJob.GetRuntimeEnv())
-        if err != nil {
-            return nil, err
-        }
-    }
-
     jobSpec := rayv1.RayJobSpec{
         RayClusterSpec: &rayClusterSpec,
         Entrypoint:     strings.Join(primaryContainer.Args, " "),
diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
index 39030a9f3d..902aad775e 100644
--- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
+++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
@@ -1429,6 +1429,118 @@ func TestGetPropertiesRay(t *testing.T) {
     assert.Equal(t, expected, rayJobResourceHandler.GetProperties())
 }
 
+// TestParseAddressToClusterName tests the parseAddressToClusterName function
+// which extracts a cluster name from various Ray address formats.
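+// For example, "my-cluster-head-svc" and "ray://my-cluster-head-svc:10001" should both map to "my-cluster".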
+// Author: Devin AI (claude-sonnet-4-20250514)
+func TestParseAddressToClusterName(t *testing.T) {
+    testCases := []struct {
+        name        string
+        address     string
+        expected    string
+        expectError bool
+    }{
+        {
+            name:     "simple cluster name",
+            address:  "my-cluster",
+            expected: "my-cluster",
+        },
+        {
+            name:     "head service DNS name",
+            address:  "my-cluster-head-svc",
+            expected: "my-cluster",
+        },
+        {
+            name:     "ray client URL with port",
+            address:  "ray://my-cluster-head-svc:10001",
+            expected: "my-cluster",
+        },
+        {
+            name:     "full k8s DNS with namespace",
+            address:  "my-cluster-head-svc.namespace.svc.cluster.local:10001",
+            expected: "my-cluster",
+        },
+        {
+            name:     "ray URL without port",
+            address:  "ray://my-cluster-head-svc",
+            expected: "my-cluster",
+        },
+        {
+            name:     "http URL",
+            address:  "http://my-cluster-head-svc:8265",
+            expected: "my-cluster",
+        },
+        {
+            name:     "cluster name with dashes",
+            address:  "my-long-cluster-name",
+            expected: "my-long-cluster-name",
+        },
+        {
+            name:     "cluster name with head-svc suffix and dashes",
+            address:  "my-long-cluster-name-head-svc",
+            expected: "my-long-cluster-name",
+        },
+        {
+            name:        "empty address",
+            address:     "",
+            expectError: true,
+        },
+    }
+
+    for _, tc := range testCases {
+        t.Run(tc.name, func(t *testing.T) {
+            result, err := parseAddressToClusterName(tc.address)
+            if tc.expectError {
+                assert.Error(t, err)
+            } else {
+                assert.NoError(t, err)
+                assert.Equal(t, tc.expected, result)
+            }
+        })
+    }
+}
+
+// TestBuildResourceRayWithAddress tests that when address is provided,
+// the RayJob uses ClusterSelector instead of creating a new RayCluster.
+// Author: Devin AI (claude-sonnet-4-20250514)
+func TestBuildResourceRayWithAddress(t *testing.T) {
+    rayJobResourceHandler := rayJobResourceHandler{}
+
+    rayJob := &plugins.RayJob{
+        RayCluster: &plugins.RayCluster{
+            HeadGroupSpec:   &plugins.HeadGroupSpec{RayStartParams: map[string]string{"num-cpus": "1"}},
+            WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3}},
+        },
+        Address:        "ray://my-existing-cluster-head-svc:10001",
+        RuntimeEnvYaml: "pip:\n - numpy",
+    }
+
+    taskTemplate := dummyRayTaskTemplate("ray-id", rayJob)
+
+    err := config.SetK8sPluginConfig(&config.K8sPluginConfig{})
+    assert.Nil(t, err)
+
+    rayCtx := dummyRayTaskContext(taskTemplate, resourceRequirements, nil, "", serviceAccount)
+    RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), rayCtx)
+    assert.Nil(t, err)
+    assert.NotNil(t, RayResource)
+
+    ray, ok := RayResource.(*rayv1.RayJob)
+    assert.True(t, ok)
+
+    // When address is provided, ClusterSelector should be set
+    assert.NotNil(t, ray.Spec.ClusterSelector, "ClusterSelector should be set when address is provided")
+    assert.Equal(t, map[string]string{"ray.io/cluster": "my-existing-cluster"}, ray.Spec.ClusterSelector)
+
+    // RayClusterSpec should be nil (no new cluster created)
+    assert.Nil(t, ray.Spec.RayClusterSpec, "RayClusterSpec should be nil when using address")
+
+    // RuntimeEnvYAML should still be set
+    assert.Equal(t, "pip:\n - numpy", ray.Spec.RuntimeEnvYAML)
+
+    // SubmitterPodTemplate should still be set
+    assert.NotNil(t, ray.Spec.SubmitterPodTemplate, "SubmitterPodTemplate should still be set")
+}
+
 func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct {
     data, err := json.Marshal(obj)
     assert.Nil(t, err)