Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions flyteidl/protos/flyteidl/plugins/ray.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ message RayJob {
// RuntimeEnvYAML represents the runtime environment configuration
// provided as a multi-line YAML string.
string runtime_env_yaml = 5;
// address specifies the Ray head address to connect to for an existing cluster.
// When set, the RayJob submits to an existing RayCluster instead of creating a new one.
// The address is parsed to derive a cluster selector label.
// Supported formats:
// - Cluster name: "my-cluster" -> clusterSelector: {"ray.io/cluster": "my-cluster"}
// - Head service DNS: "my-cluster-head-svc" -> clusterSelector: {"ray.io/cluster": "my-cluster"}
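// - Ray client URL: "ray://my-cluster-head-svc:10001" -> clusterSelector: {"ray.io/cluster": "my-cluster"}
// - Full k8s DNS: "my-cluster-head-svc.namespace.svc.cluster.local:10001" -> clusterSelector: {"ray.io/cluster": "my-cluster"}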
// This enables long-lived clusters for faster iteration during development.
string address = 6;
}

// RayCluster defines the desired state of a Ray cluster.
Expand Down
112 changes: 97 additions & 15 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,106 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
return rayjob, err
}

// parseAddressToClusterName extracts a cluster name from a Ray address string.
//
// The address is reduced to a host by dropping, in this order: surrounding
// whitespace, a leading "scheme://", a trailing "/path", and a trailing
// ":port". The first DNS label of that host, with a trailing "-head-svc"
// removed, is returned as the cluster name.
//
// Supported formats:
//   - Cluster name: "my-cluster" -> "my-cluster"
//   - Head service DNS: "my-cluster-head-svc" -> "my-cluster"
//   - Ray client URL: "ray://my-cluster-head-svc:10001" -> "my-cluster"
//   - Full k8s DNS: "my-cluster-head-svc.namespace.svc.cluster.local:10001" -> "my-cluster"
//   - URL with a path: "http://my-cluster-head-svc:8265/api/jobs" -> "my-cluster"
//
// An error is returned when the address is empty (or only whitespace) or when
// nothing is left to use as a cluster name after parsing.
//
// Author: Devin AI (claude-sonnet-4-20250514)
func parseAddressToClusterName(address string) (string, error) {
	host := strings.TrimSpace(address)
	if host == "" {
		return "", fmt.Errorf("address is empty")
	}

	// Drop the scheme of URL-style addresses (ray://host:port, http://host:port).
	if idx := strings.Index(host, "://"); idx != -1 {
		host = host[idx+len("://"):]
	}

	// Drop the path BEFORE looking for the port: a ':' inside the path would
	// otherwise be mistaken for the port separator and the real port would be
	// left attached to the host.
	if idx := strings.Index(host, "/"); idx != -1 {
		host = host[:idx]
	}

	// Drop the port. A ']' after the last ':' means that colon belongs to a
	// bracketed IPv6 literal rather than to a port, so leave it alone.
	if idx := strings.LastIndex(host, ":"); idx != -1 && !strings.Contains(host[idx:], "]") {
		host = host[:idx]
	}

	// Keep only the first DNS label (service name without namespace/domain).
	if idx := strings.Index(host, "."); idx != -1 {
		host = host[:idx]
	}

	// Strip the "-head-svc" suffix, if present, to get the cluster name.
	clusterName := strings.TrimSuffix(host, "-head-svc")
	if clusterName == "" {
		return "", fmt.Errorf("could not parse cluster name from address: %s", address)
	}

	return clusterName, nil
}

func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob *plugins.RayJob, objectMeta *metav1.ObjectMeta, taskPodSpec v1.PodSpec, headNodeRayStartParams map[string]string, primaryContainerIdx int, primaryContainer v1.Container) (*rayv1.RayJob, error) {
enableIngress := true
cfg := GetConfig()

// Handle runtime_env conversion (needed for both modes)
var runtimeEnvYaml string
var err error
runtimeEnvYaml = rayJob.GetRuntimeEnvYaml()
// If runtime_env exists but runtime_env_yaml does not, convert runtime_env to runtime_env_yaml
if rayJob.GetRuntimeEnv() != "" && rayJob.GetRuntimeEnvYaml() == "" {
runtimeEnvYaml, err = convertBase64RuntimeEnvToYaml(rayJob.GetRuntimeEnv())
if err != nil {
return nil, err
}
}

submitterPodSpec := taskPodSpec.DeepCopy()
submitterPodTemplate := buildSubmitterPodTemplate(submitterPodSpec, objectMeta, taskCtx)

// Check if address is provided - if so, use existing cluster mode
if address := rayJob.GetAddress(); address != "" {
clusterName, err := parseAddressToClusterName(address)
if err != nil {
return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid Ray address for cluster selection: %v", err)
}

// Use ray.io/cluster label to select the existing cluster
clusterSelector := map[string]string{
"ray.io/cluster": clusterName,
}

jobSpec := rayv1.RayJobSpec{
ClusterSelector: clusterSelector,
Entrypoint: strings.Join(primaryContainer.Args, " "),
RuntimeEnvYAML: runtimeEnvYaml,
SubmitterPodTemplate: &submitterPodTemplate,
}

return &rayv1.RayJob{
TypeMeta: metav1.TypeMeta{
Kind: KindRayJob,
APIVersion: rayv1.SchemeGroupVersion.String(),
},
Spec: jobSpec,
ObjectMeta: *objectMeta,
}, nil
}

// Default mode: create a new RayCluster
enableIngress := true

headPodSpec := taskPodSpec.DeepCopy()
headPodTemplate, err := buildHeadPodTemplate(
&headPodSpec.Containers[primaryContainerIdx],
Expand Down Expand Up @@ -217,20 +313,6 @@ func constructRayJob(taskCtx pluginsCore.TaskExecutionContext, rayJob *plugins.R
ttlSecondsAfterFinished = &rayJob.TtlSecondsAfterFinished
}

submitterPodSpec := taskPodSpec.DeepCopy()
submitterPodTemplate := buildSubmitterPodTemplate(submitterPodSpec, objectMeta, taskCtx)

// TODO: This is for backward compatibility. Remove this block once runtime_env is removed from ray proto.
var runtimeEnvYaml string
runtimeEnvYaml = rayJob.GetRuntimeEnvYaml()
// If runtime_env exists but runtime_env_yaml does not, convert runtime_env to runtime_env_yaml
if rayJob.GetRuntimeEnv() != "" && rayJob.GetRuntimeEnvYaml() == "" {
runtimeEnvYaml, err = convertBase64RuntimeEnvToYaml(rayJob.GetRuntimeEnv())
if err != nil {
return nil, err
}
}

jobSpec := rayv1.RayJobSpec{
RayClusterSpec: &rayClusterSpec,
Entrypoint: strings.Join(primaryContainer.Args, " "),
Expand Down
112 changes: 112 additions & 0 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,118 @@ func TestGetPropertiesRay(t *testing.T) {
assert.Equal(t, expected, rayJobResourceHandler.GetProperties())
}

// TestParseAddressToClusterName is a table-driven test for
// parseAddressToClusterName. Each row feeds one address format to the parser
// and checks either the derived cluster name or that an error is reported.
// Author: Devin AI (claude-sonnet-4-20250514)
func TestParseAddressToClusterName(t *testing.T) {
	type scenario struct {
		name        string
		address     string
		expected    string
		expectError bool
	}

	scenarios := []scenario{
		{name: "simple cluster name", address: "my-cluster", expected: "my-cluster"},
		{name: "head service DNS name", address: "my-cluster-head-svc", expected: "my-cluster"},
		{name: "ray client URL with port", address: "ray://my-cluster-head-svc:10001", expected: "my-cluster"},
		{name: "full k8s DNS with namespace", address: "my-cluster-head-svc.namespace.svc.cluster.local:10001", expected: "my-cluster"},
		{name: "ray URL without port", address: "ray://my-cluster-head-svc", expected: "my-cluster"},
		{name: "http URL", address: "http://my-cluster-head-svc:8265", expected: "my-cluster"},
		{name: "cluster name with dashes", address: "my-long-cluster-name", expected: "my-long-cluster-name"},
		{name: "cluster name with head-svc suffix and dashes", address: "my-long-cluster-name-head-svc", expected: "my-long-cluster-name"},
		{name: "empty address", address: "", expectError: true},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			got, err := parseAddressToClusterName(sc.address)

			// Failure rows only assert that an error came back.
			if sc.expectError {
				assert.Error(t, err)
				return
			}

			assert.NoError(t, err)
			assert.Equal(t, sc.expected, got)
		})
	}
}

// TestBuildResourceRayWithAddress builds a RayJob resource from a task whose
// plugin spec carries an address and verifies that the result selects an
// existing cluster (ClusterSelector set, RayClusterSpec nil) while the runtime
// environment and submitter pod template are still populated.
// Author: Devin AI (claude-sonnet-4-20250514)
func TestBuildResourceRayWithAddress(t *testing.T) {
	const runtimeEnv = "pip:\n - numpy"
	wantSelector := map[string]string{"ray.io/cluster": "my-existing-cluster"}

	handler := rayJobResourceHandler{}

	// A cluster spec is supplied on purpose: the address must take precedence.
	spec := &plugins.RayJob{
		RayCluster: &plugins.RayCluster{
			HeadGroupSpec:   &plugins.HeadGroupSpec{RayStartParams: map[string]string{"num-cpus": "1"}},
			WorkerGroupSpec: []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3}},
		},
		Address:        "ray://my-existing-cluster-head-svc:10001",
		RuntimeEnvYaml: runtimeEnv,
	}
	tmpl := dummyRayTaskTemplate("ray-id", spec)

	err := config.SetK8sPluginConfig(&config.K8sPluginConfig{})
	assert.Nil(t, err)

	taskCtx := dummyRayTaskContext(tmpl, resourceRequirements, nil, "", serviceAccount)
	resource, err := handler.BuildResource(context.TODO(), taskCtx)
	assert.Nil(t, err)
	assert.NotNil(t, resource)

	job, isRayJob := resource.(*rayv1.RayJob)
	assert.True(t, isRayJob)

	// Existing-cluster mode: the selector points at the parsed cluster name.
	assert.NotNil(t, job.Spec.ClusterSelector, "ClusterSelector should be set when address is provided")
	assert.Equal(t, wantSelector, job.Spec.ClusterSelector)

	// No new cluster is described in the job spec.
	assert.Nil(t, job.Spec.RayClusterSpec, "RayClusterSpec should be nil when using address")

	// The runtime environment is passed through unchanged.
	assert.Equal(t, runtimeEnv, job.Spec.RuntimeEnvYAML)

	// The submitter pod template is still generated.
	assert.NotNil(t, job.Spec.SubmitterPodTemplate, "SubmitterPodTemplate should still be set")
}

func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct {
data, err := json.Marshal(obj)
assert.Nil(t, err)
Expand Down
Loading