Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions cli/src/pcluster/aws/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,22 @@ def describe_image(self, ami_id):
return ImageInfo(images[0])
raise AWSClientError(function_name="describe_images", message=f"Image {ami_id} not found")

@AWSExceptionHandler.handle_client_exception
@Cache.cached
def describe_launch_template_version(self, launch_template_id, version):
    """
    Return the LaunchTemplateData of a single launch template version.

    :param launch_template_id: ID of the launch template to look up
    :param version: version to fetch; it is converted to a string before being sent to EC2
    :return: the "LaunchTemplateData" dict of the first returned version, or {} when that key is absent
    :raises AWSClientError: when EC2 returns no version for the given id/version pair
    """
    matching_versions = self._client.describe_launch_template_versions(
        LaunchTemplateId=launch_template_id, Versions=[str(version)]
    ).get("LaunchTemplateVersions", [])
    if matching_versions:
        return matching_versions[0].get("LaunchTemplateData", {})
    raise AWSClientError(
        function_name="describe_launch_template_versions",
        message=f"Launch template {launch_template_id} version {version} not found",
    )

@AWSExceptionHandler.handle_client_exception
@Cache.cached
def describe_images(self, ami_ids, filters, owners):
Expand Down
65 changes: 65 additions & 0 deletions cli/src/pcluster/config/cluster_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@
InstanceTypePlacementGroupValidator,
InstanceTypeValidator,
KeyPairValidator,
LaunchTemplateOverridesValidator,
PlacementGroupCapacityReservationValidator,
PlacementGroupCapacityTypeValidator,
PlacementGroupNamingValidator,
Expand Down Expand Up @@ -1631,6 +1632,7 @@ def __init__(
self.managed_head_node_security_group = None
self.managed_compute_security_group = None
self.instance_types_data_version = ""
self.run_instances_overrides_version = ""

def _register_validators(self, context: ValidatorContext = None): # noqa: D102 #pylint: disable=unused-argument
self._register_validator(RegionValidator, region=self.region)
Expand Down Expand Up @@ -2222,6 +2224,15 @@ def scheduler_resources(self):
return str(files(__package__).parent / "resources" / "batch")


class LaunchTemplateOverrides(Resource):
    """
    Hold the launch template reference configured as overrides for a compute resource.

    Both values are wrapped with Resource.init_param and exposed as
    launch_template_id and version.
    """

    def __init__(self, launch_template_id: str = None, version: int = None, **kwargs):
        """Store the launch template id and version; any other keyword argument goes to Resource."""
        super().__init__(**kwargs)
        # Identifier of the launch template to read the overrides from.
        self.launch_template_id = Resource.init_param(launch_template_id)
        # Version of that launch template.
        self.version = Resource.init_param(version)


class _BaseSlurmComputeResource(BaseComputeResource):
"""Represent the Slurm Compute Resource."""

Expand All @@ -2240,6 +2251,7 @@ def __init__(
tags: List[Tag] = None,
static_node_priority: int = None,
dynamic_node_priority: int = None,
launch_specification_overrides=None,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -2260,6 +2272,7 @@ def __init__(
self.tags = tags
self.static_node_priority = Resource.init_param(static_node_priority, default=1)
self.dynamic_node_priority = Resource.init_param(dynamic_node_priority, default=1000)
self.launch_specification_overrides = launch_specification_overrides

@abstractmethod
def is_flexible(self) -> bool:
Expand Down Expand Up @@ -2362,6 +2375,15 @@ def _register_validators(self, context: ValidatorContext = None):
ec2memory=min_memory,
instance_type=smallest_type,
)
if self.launch_specification_overrides:
self._register_validator(
LaunchTemplateOverridesValidator,
launch_template_id=self.launch_specification_overrides.launch_template_id,
version=self.launch_specification_overrides.version,
instance_types=self.instance_types,
max_network_cards=self.max_network_cards,
is_flexible=self.is_flexible(),
)

def is_flexible(self):
"""Return True because the ComputeResource can contain multiple instance types."""
Expand Down Expand Up @@ -2449,6 +2471,15 @@ def _register_validators(self, context: ValidatorContext = None):
ec2memory=self._instance_type_info.ec2memory_size_in_mib(),
instance_type=self.instance_type,
)
if self.launch_specification_overrides:
self._register_validator(
LaunchTemplateOverridesValidator,
launch_template_id=self.launch_specification_overrides.launch_template_id,
version=self.launch_specification_overrides.version,
instance_types=self.instance_types,
max_network_cards=self.max_network_cards,
is_flexible=self.is_flexible(),
)

@property
def architecture(self) -> str:
Expand Down Expand Up @@ -2975,6 +3006,40 @@ def get_instance_types_data(self):
result[instance_type] = instance_type_info.instance_type_data
return result

def get_run_instances_overrides(self):
    """
    Collect launch template data for every compute resource that configures overrides.

    Walks all queues and their compute resources. Compute resources without
    launch_specification_overrides are skipped; for the others the referenced launch
    template version is fetched through the EC2 client and, when the returned data is
    non-empty, stored in the result.

    :return: nested dict {queue_name: {compute_resource_name: launch_template_data}};
        empty when no compute resource yields launch template data.
    """
    collected = {}
    for queue in self.scheduling.queues:
        for resource in queue.compute_resources:
            template_ref = resource.launch_specification_overrides
            if template_ref:
                template_id = template_ref.launch_template_id
                template_version = template_ref.version

                LOGGER.info(
                    "Fetching launch template %s version %s for queue %s, compute resource %s",
                    template_id,
                    template_version,
                    queue.name,
                    resource.name,
                )
                template_data = AWSApi.instance().ec2.describe_launch_template_version(
                    template_id, template_version
                )

                # Empty launch template data is not recorded, so queues only appear
                # in the result when at least one compute resource has data.
                if template_data:
                    if queue.name not in collected:
                        collected[queue.name] = {}
                    collected[queue.name][resource.name] = template_data

    return collected

@property
def login_nodes_ami(self):
"""Get the image id of the LoginNodes."""
Expand Down
1 change: 1 addition & 0 deletions cli/src/pcluster/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@
"custom_artifacts_name": "artifacts.zip",
"scheduler_resources_name": "scheduler_resources.zip",
"change_set_name": "change-set.json",
"run_instances_overrides_name": "run_instances_overrides.json",
}

PCLUSTER_TAG_VALUE_REGEX = r"^([\w\+\-\=\.\_\:\@/]{0,256})$"
Expand Down
21 changes: 21 additions & 0 deletions cli/src/pcluster/models/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ def create(
artifact_dir_generated = True
self._upload_config()
self._upload_instance_types_data()
self._upload_run_instances_overrides()
LOGGER.info("Generation and upload completed successfully")

# Create template if not provided by the user
Expand Down Expand Up @@ -558,6 +559,25 @@ def _upload_instance_types_data(self):
e, f"Unable to upload instance types data to the S3 bucket {self.bucket.name} due to exception: {e}"
)

def _upload_run_instances_overrides(self):
    """
    Build the run instances overrides and upload them as JSON to the cluster S3 bucket.

    The "VersionId" returned by the upload is saved on the config as
    run_instances_overrides_version. Any exception is converted through
    _cluster_error_mapper and re-raised.
    """
    # NOTE(review): create() and update() call this unconditionally; confirm that every
    # cluster config class exposes get_run_instances_overrides().
    try:
        overrides_payload = self.config.get_run_instances_overrides()
        LOGGER.info("Uploading run_instances_overrides.json to S3...")
        artifact_name = PCLUSTER_S3_ARTIFACTS_DICT.get("run_instances_overrides_name")
        upload_response = self.bucket.upload_config(
            config=overrides_payload,
            config_name=artifact_name,
            format=S3FileFormat.JSON,
        )
        self.config.run_instances_overrides_version = upload_response.get("VersionId")
        LOGGER.info("run_instances_overrides.json uploaded successfully.")
    except Exception as e:
        error_message = (
            f"Unable to upload run_instances_overrides.json to the S3 bucket {self.bucket.name} "
            f"due to exception: {e}"
        )
        raise _cluster_error_mapper(e, error_message)

def _upload_change_set(self, changes=None):
"""Upload change set."""
if changes:
Expand Down Expand Up @@ -924,6 +944,7 @@ def update(
self._add_tags()
self._upload_config()
self._upload_instance_types_data()
self._upload_run_instances_overrides()
self._upload_change_set(changes)

# Create template if not provided by the user
Expand Down
23 changes: 23 additions & 0 deletions cli/src/pcluster/schemas/cluster_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
Image,
Imds,
IntelSoftware,
LaunchTemplateOverrides,
LocalStorage,
LoginNodes,
LoginNodesIam,
Expand Down Expand Up @@ -1536,6 +1537,25 @@ def make_resource(self, data, **kwargs):
return BaseTag(**data)


class LaunchTemplateOverridesSchema(BaseSchema):
    """
    Represent the schema of the LaunchTemplateOverrides section.

    Both fields are required and use the QUEUE_UPDATE_STRATEGY update policy.
    On load the validated data is turned into a LaunchTemplateOverrides resource.
    """

    launch_template_id = fields.Str(
        required=True,
        validate=validate.Regexp(r"^lt-[a-zA-Z0-9]+$"),
        metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY},
    )
    # Launch template version numbers start at 1, so zero and negative values are rejected
    # at schema level instead of surfacing later as a failed EC2 lookup.
    version = fields.Int(
        required=True,
        validate=validate.Range(min=1),
        metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY},
    )

    @post_load
    def make_resource(self, data, **kwargs):
        """Generate resource."""
        return LaunchTemplateOverrides(**data)


class SlurmComputeResourceSchema(_ComputeResourceSchema):
"""Represent the schema of the Slurm ComputeResource."""

Expand Down Expand Up @@ -1576,6 +1596,9 @@ class SlurmComputeResourceSchema(_ComputeResourceSchema):
validate=validate.Range(min=MIN_SLURM_NODE_PRIORITY, max=MAX_SLURM_NODE_PRIORITY),
metadata={"update_policy": UpdatePolicy.SUPPORTED},
)
launch_specification_overrides = fields.Nested(
LaunchTemplateOverridesSchema, metadata={"update_policy": UpdatePolicy.QUEUE_UPDATE_STRATEGY}
)

@validates_schema
def no_coexist_instance_type_flexibility(self, data, **kwargs):
Expand Down
3 changes: 3 additions & 0 deletions cli/src/pcluster/templates/cluster_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,10 +1419,13 @@ def _add_head_node(self):
),
"cluster_config_version": self.config.config_version,
"instance_types_data_version": self.config.instance_types_data_version,
"run_instances_overrides_version": self.config.run_instances_overrides_version,
"change_set_s3_key": f"{self.bucket.artifact_directory}/configs/"
f"{PCLUSTER_S3_ARTIFACTS_DICT.get('change_set_name')}",
"instance_types_data_s3_key": f"{self.bucket.artifact_directory}/configs/"
f"{PCLUSTER_S3_ARTIFACTS_DICT.get('instance_types_data_name')}",
"run_instances_overrides_s3_key": f"{self.bucket.artifact_directory}/configs/"
f"{PCLUSTER_S3_ARTIFACTS_DICT.get('run_instances_overrides_name')}",
"custom_node_package": self.config.custom_node_package or "",
"custom_awsbatchcli_package": self.config.custom_aws_batch_cli_package or "",
"head_node_imds_secured": str(self.config.head_node.imds.secured).lower(),
Expand Down
50 changes: 50 additions & 0 deletions cli/src/pcluster/validators/ec2_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,3 +870,53 @@ def _validate(self, cluster_ultraserver_capacity_block_dict):
f"The following capacity blocks have invalid block sizes: {'; '.join(invalid_capacity_blocks)}.",
FailureLevel.ERROR,
)


class LaunchTemplateOverridesValidator(Validator):
    """
    Validate the launch template overrides configuration.

    The referenced launch template version is fetched from EC2 and checked so that it only sets
    supported properties, does not configure more network interfaces than max_network_cards and,
    when it sets an instance type, that the type is one of the compute resource's instance types.
    Using the overrides on a compute resource with flexible instance types is always an error.
    """

    # LaunchTemplateData properties that an override launch template is allowed to set.
    _ALLOWED_PROPERTIES = frozenset({"InstanceType", "NetworkInterfaces"})

    def _validate(self, launch_template_id, version, instance_types, max_network_cards, is_flexible):
        """
        Run all launch template override checks, adding an ERROR failure for each violation.

        :param launch_template_id: ID of the override launch template
        :param version: version of the launch template to validate
        :param instance_types: instance types configured on the compute resource
        :param max_network_cards: maximum number of network interfaces accepted; the count check
            is skipped when this is None
        :param is_flexible: True when the compute resource uses flexible instance types
        """
        lt_data = self._fetch_launch_template_data(launch_template_id, version)
        if lt_data is not None:
            self._validate_launch_template_data(lt_data, launch_template_id, instance_types, max_network_cards)

        # This check does not depend on the launch template content, so it is reported even
        # when the lookup above failed. It is an error, not a warning.
        if is_flexible:
            self._add_failure(
                "LaunchTemplateOverrides cannot be used with flexible instance types.",
                FailureLevel.ERROR,
            )

    def _fetch_launch_template_data(self, launch_template_id, version):
        """Return the LaunchTemplateData of the given version, or None after recording a failure."""
        try:
            # describe_launch_template_version converts the version to a string itself, so it is
            # passed through unchanged, exactly as the overrides builder does.
            # NOTE(review): identical arguments presumably let the cached lookup be shared - confirm.
            return AWSApi.instance().ec2.describe_launch_template_version(launch_template_id, version)
        except AWSClientError as e:
            self._add_failure(
                f"Unable to retrieve launch template {launch_template_id} version {version}. {e}",
                FailureLevel.ERROR,
            )
            return None

    def _validate_launch_template_data(self, lt_data, launch_template_id, instance_types, max_network_cards):
        """Check the properties, network interface count and instance type of the launch template data."""
        # Check for properties not in allow list
        denied_found = [prop for prop in lt_data if prop not in self._ALLOWED_PROPERTIES]
        if denied_found:
            self._add_failure(
                f"Launch template {launch_template_id} contains unsupported properties: "
                f"{', '.join(sorted(denied_found))}. Only NetworkInterfaces, InstanceType "
                f"are supported in the override launch template.",
                FailureLevel.ERROR,
            )

        # Validate network interface count does not exceed max supported.
        # A missing limit (None) cannot be compared against, so the check is skipped in that case.
        network_interfaces = lt_data.get("NetworkInterfaces", [])
        if network_interfaces and max_network_cards is not None and len(network_interfaces) > max_network_cards:
            self._add_failure(
                f"Launch template {launch_template_id} configures {len(network_interfaces)} network interfaces, "
                f"but the instance type supports a maximum of {max_network_cards}.",
                FailureLevel.ERROR,
            )

        # Validate instance type in LT matches the compute resource if specified
        lt_instance_type = lt_data.get("InstanceType")
        if lt_instance_type and lt_instance_type not in instance_types:
            self._add_failure(
                f"Instance type '{lt_instance_type}' in launch template {launch_template_id} does not match "
                f"the compute resource instance type(s): {', '.join(instance_types)}.",
                FailureLevel.ERROR,
            )
3 changes: 3 additions & 0 deletions cli/tests/pcluster/example_configs/slurm.full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ Scheduling:
HttpProxyAddress: https://proxy-address:port
ComputeResources:
- Name: compute-resource-1
LaunchTemplateOverrides:
LaunchTemplateId: lt-0ab6123b7f1111111
Version: "2"
InstanceType: c4.2xlarge
- Name: compute-resource-2
InstanceType: c5.2xlarge
Expand Down
Loading
Loading