Skip to content

Commit

Permalink
Add PCS support
Browse files Browse the repository at this point in the history
Resolves #286
  • Loading branch information
cartalla committed Nov 26, 2024
1 parent 0148190 commit 2fb2a54
Show file tree
Hide file tree
Showing 6 changed files with 575 additions and 200 deletions.
567 changes: 371 additions & 196 deletions source/cdk/cdk_slurm_stack.py

Large diffs are not rendered by default.

28 changes: 24 additions & 4 deletions source/cdk/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@
}

def get_parallel_cluster_version(config):
parallel_cluster_version = config['slurm']['ParallelClusterConfig']['Version']
parallel_cluster_version = config['slurm'].get('ParallelClusterConfig', {}).get('Version', PARALLEL_CLUSTER_VERSIONS[-1])
if parallel_cluster_version not in PARALLEL_CLUSTER_VERSIONS:
logger.error(f"Unsupported ParallelCluster version: {parallel_cluster_version}\nSupported versions are:\n{json.dumps(PARALLEL_CLUSTER_VERSIONS, indent=4)}")
raise KeyError(parallel_cluster_version)
Expand Down Expand Up @@ -329,6 +329,19 @@ def get_PARALLEL_CLUSTER_LAMBDA_RUNTIME(parallel_cluster_version):
else:
return aws_lambda.Runtime.PYTHON_3_12

PCS_SLURM_VERSIONS = [
'23.11'
]

PCS_CONTROLLER_SIZES = [
'Small',
'Medium',
'Large'
]

def get_PCS_LAMBDA_RUNTIME():
return aws_lambda.Runtime.PYTHON_3_12

# Determine all AWS regions available on the account.
default_region = environ.get("AWS_DEFAULT_REGION", "us-east-1")
ec2_client = boto3.client("ec2", region_name=default_region)
Expand All @@ -350,7 +363,7 @@ def get_PARALLEL_CLUSTER_LAMBDA_RUNTIME(parallel_cluster_version):
DEFAULT_X86_CONTROLLER_INSTANCE_TYPE = 'c6a.large'

def default_controller_instance_type(config):
architecture = config['slurm']['ParallelClusterConfig'].get('Architecture', DEFAULT_ARCHITECTURE)
architecture = config['slurm'].get('ParallelClusterConfig', {}).get('Architecture', DEFAULT_ARCHITECTURE)
if architecture == 'x86_64':
return DEFAULT_X86_CONTROLLER_INSTANCE_TYPE
elif architecture == 'arm64':
Expand All @@ -363,7 +376,7 @@ def default_controller_instance_type(config):
DEFAULT_X86_OS = 'rhel8'

def DEFAULT_OS(config):
architecture = config['slurm']['ParallelClusterConfig'].get('Architecture', DEFAULT_ARCHITECTURE)
architecture = config['slurm'].get('ParallelClusterConfig', {}).get('Architecture', DEFAULT_ARCHITECTURE)
if architecture == 'x86_64':
return DEFAULT_X86_OS
elif architecture == 'arm64':
Expand Down Expand Up @@ -1252,9 +1265,16 @@ def get_config_schema(config):
]
}
},
Optional('PcsConfig'): {
'SlurmVersion': And(str, lambda s: s in PCS_SLURM_VERSIONS),
'ControllerSize': And(str, lambda s: s in PCS_CONTROLLER_SIZES),
Optional('Tags'): [
{'Key': str, 'Value': str}
]
},
#
# ClusterName:
# Name of the ParallelCluster cluster.
# Name of the ParallelCluster or PCS cluster.
# Default:
# If StackName ends with "-config" then ClusterName is StackName with "-config" stripped off.
# Otherwise add "-cl" to end of StackName.
Expand Down
90 changes: 90 additions & 0 deletions source/resources/lambdas/PcsCluster/PcsCluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"), to deal in the Software
without restriction, including without limitation the rights to use, copy, modify,
merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

'''
Create/update/delete ParallelCluster cluster and save config to S3 as json and yaml.
'''
import boto3
import cfnresponse
import json
import logging
from os import environ as environ
from time import sleep

logger=logging.getLogger(__file__)
logger_formatter = logging.Formatter('%(levelname)s: %(message)s')
logger_streamHandler = logging.StreamHandler()
logger_streamHandler.setFormatter(logger_formatter)
logger.addHandler(logger_streamHandler)
logger.setLevel(logging.INFO)
logger.propagate = False

def lambda_handler(event, context):
try:
logger.info(f"event:\n{json.dumps(event, indent=4)}")

# Create sns client so can send notifications on any errors.
sns_client = boto3.client('sns')

cluster_name = None
requestType = event['RequestType']
properties = event['ResourceProperties']
required_properties = [
]
error_message = ""
for property in required_properties:
try:
value = properties[property]
except:
error_message += f"Missing {property} property. "
if error_message:
logger.info(error_message)
if requestType == 'Delete':
cfnresponse.send(event, context, cfnresponse.SUCCESS, {}, physicalResourceId=cluster_name)
return
else:
raise KeyError(error_message)

cluster_name = environ['ClusterName']
cluster_region = environ['Region']
logger.info(f"{requestType} request for {cluster_name} in {cluster_region}")

# pcs_client = boto3.client('pcs')

if requestType == 'Create':
logger.info(f"Creating {cluster_name}")
elif requestType == 'Update':
logger.info(f"Updating {cluster_name}")
elif requestType == 'Delete':
logger.info(f"Deleting {cluster_name}")

cfnresponse.send(event, context, cfnresponse.SUCCESS, {}, physicalResourceId="resource_id")
return

except Exception as e:
logger.exception(str(e))
cfnresponse.send(event, context, cfnresponse.FAILED, {'error': str(e)}, physicalResourceId=cluster_name)
sns_client.publish(
TopicArn = environ['ErrorSnsTopicArn'],
Subject = f"{cluster_name} CreateParallelCluster failed",
Message = str(e)
)
logger.info(f"Published error to {environ['ErrorSnsTopicArn']}")
raise

cfnresponse.send(event, context, cfnresponse.SUCCESS, {}, physicalResourceId=cluster_name)
1 change: 1 addition & 0 deletions source/resources/lambdas/PcsCluster/cfnresponse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"), to deal in the Software
without restriction, including without limitation the rights to use, copy, modify,
merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

'''
Create/update/delete ParallelCluster cluster and save config to S3 as json and yaml.
'''
import boto3
import cfnresponse
import json
import logging
from os import environ as environ
from time import sleep

logger=logging.getLogger(__file__)
logger_formatter = logging.Formatter('%(levelname)s: %(message)s')
logger_streamHandler = logging.StreamHandler()
logger_streamHandler.setFormatter(logger_formatter)
logger.addHandler(logger_streamHandler)
logger.setLevel(logging.INFO)
logger.propagate = False

def lambda_handler(event, context):
try:
logger.info(f"event:\n{json.dumps(event, indent=4)}")

# Create sns client so can send notifications on any errors.
sns_client = boto3.client('sns')

cluster_name = None
requestType = event['RequestType']
properties = event['ResourceProperties']
required_properties = [
]
error_message = ""
for property in required_properties:
try:
value = properties[property]
except:
error_message += f"Missing {property} property. "
if error_message:
logger.info(error_message)
if requestType == 'Delete':
cfnresponse.send(event, context, cfnresponse.SUCCESS, {}, physicalResourceId=cluster_name)
return
else:
raise KeyError(error_message)

cluster_name = environ['ClusterName']
cluster_region = environ['Region']
logger.info(f"{requestType} request for {cluster_name} in {cluster_region}")

if requestType == 'Create':
logger.info(f"Creating {cluster_name}")
elif requestType == 'Update':
logger.info(f"Updating {cluster_name}")
elif requestType == 'Delete':
logger.info(f"Deleting {cluster_name}")

cfnresponse.send(event, context, cfnresponse.SUCCESS, {}, physicalResourceId=cluster_name)
return

except Exception as e:
logger.exception(str(e))
cfnresponse.send(event, context, cfnresponse.FAILED, {'error': str(e)}, physicalResourceId=cluster_name)
sns_client.publish(
TopicArn = environ['ErrorSnsTopicArn'],
Subject = f"{cluster_name} CreateParallelCluster failed",
Message = str(e)
)
logger.info(f"Published error to {environ['ErrorSnsTopicArn']}")
raise

cfnresponse.send(event, context, cfnresponse.SUCCESS, {}, physicalResourceId=cluster_name)

0 comments on commit 2fb2a54

Please sign in to comment.