Skip to content

Commit

Permalink
add tritongpt validator
Browse files Browse the repository at this point in the history
  • Loading branch information
RockfordMankiniUCSD committed Oct 7, 2024
1 parent 6c14d8b commit 3babee6
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/dsmlp/app/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Resource name "nvidia.com/gpu" — presumably the Kubernetes resource key used
# to count NVIDIA GPUs on a pod; TODO confirm against the GPU validator.
GPU_LABEL = "nvidia.com/gpu"
# Key 'gpu-limit' — presumably an annotation holding a per-namespace GPU cap;
# TODO confirm against the code that reads it.
GPU_LIMIT_ANNOTATION = 'gpu-limit'
# Name "low" — presumably the priority class associated with low-priority
# pods; TODO confirm against the code that reads it.
# NOTE(review): the assignment appears twice in this diff view because the
# change is "1 addition & 1 deletion" of the same line; the value is unchanged.
LOW_PRIORITY_CLASS = "low"
LOW_PRIORITY_CLASS = "low"
31 changes: 31 additions & 0 deletions src/dsmlp/app/tritongpt_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dataclasses import dataclass
import json
from typing import List, Optional

from dataclasses_json import dataclass_json
from dsmlp.plugin.awsed import AwsedClient, UnsuccessfulRequest
from dsmlp.plugin.console import Console
from dsmlp.plugin.course import ConfigProvider
from dsmlp.plugin.kube import KubeClient, NotFound
import jsonify

from dsmlp.plugin.logger import Logger
from dsmlp.app.types import *
from dsmlp.app.config import *

# used in order to bypass awsed for tritonGPT while still maintaining UID security.
class TritonGPTValidator(ComponentValidator):
    """Validate a pod's ``runAsUser`` against a namespace's permitted UIDs.

    Intended for TritonGPT namespaces, where the AWSEd lookup is bypassed but
    the UID a pod runs as must still be restricted.
    """

    def __init__(self, kube: KubeClient, logger: Logger) -> None:
        """Store the Kubernetes client (source of the permitted UIDs) and logger."""
        self.kube = kube
        self.logger = logger

    def validate_pod(self, request: Request):
        """Deny the pod unless its ``runAsUser`` is permitted in its namespace.

        Args:
            request: admission request; ``request.namespace`` selects the
                namespace whose permitted UIDs are read, and
                ``request.object.spec.securityContext.runAsUser`` is the UID
                being checked.

        Raises:
            ValidationFailure: if the pod declares no ``runAsUser`` or the UID
                is not in the namespace's permitted list.
        """
        # The lookup is per-namespace, so the namespace name must be supplied.
        permitted_uids = self.kube.get_tgpt_uids(request.namespace)

        # The lookup may yield None when nothing is configured, and its entries
        # are strings split from a comma-delimited value, whereas runAsUser is
        # numeric. Normalise both sides to stripped strings before comparing.
        permitted = {str(uid).strip() for uid in (permitted_uids or [])}

        security_context = request.object.spec.securityContext
        requested_uid = None if security_context is None else security_context.runAsUser

        # A pod with no explicit UID cannot be shown to be permitted: deny it.
        if requested_uid is None or str(requested_uid) not in permitted:
            raise ValidationFailure(f"TritonGPT Validator: user with {permitted_uids} attempted to run a pod as {requested_uid}. Pod denied.")
14 changes: 13 additions & 1 deletion src/dsmlp/app/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
from abc import ABCMeta, abstractmethod
from dsmlp.app.id_validator import IDValidator
from dsmlp.app.gpu_validator import GPUValidator
from dsmlp.app.tritongpt_validator import TritonGPTValidator
from dsmlp.app.types import *
from dsmlp.app.config import *

class Validator:
def __init__(self, awsed: AwsedClient, kube: KubeClient, logger: Logger) -> None:
    """Keep the collaborating clients and build the default validator chain.

    The chain is an ``IDValidator`` followed by a ``GPUValidator``; it is
    stored on ``self.component_validators`` in that order.
    """
    self.awsed, self.logger, self.kube = awsed, logger, kube

    id_validator = IDValidator(awsed, logger)
    gpu_validator = GPUValidator(awsed, kube, logger)
    self.component_validators = [id_validator, gpu_validator]

def validate_request(self, admission_review_json):
Expand Down Expand Up @@ -51,6 +54,15 @@ def handle_request(self, request: Request):
return self.admission_response(request.uid, True, "Allowed")

def validate_pod(self, request: Request):
    """Run the validators that apply to the pod in ``request``.

    If the request's namespace has its TritonGPT label set to ``"enabled"``,
    only ``TritonGPTValidator`` runs. Otherwise every validator in
    ``self.component_validators`` runs in order.

    Raises:
        ValidationFailure: propagated from whichever validator rejects the
            pod, including the TritonGPT validator.
    """
    try:
        if self.kube.get_tgpt_label(request.namespace) == "enabled":
            self.logger.info("Triton GPT Mode Activated. Only running TritonGPT Validator.")
            TritonGPTValidator(self.kube, self.logger).validate_pod(request)
            return
    except ValidationFailure:
        # A deliberate denial from the TritonGPT validator must reach the
        # caller; swallowing it would let the regular validators re-decide.
        raise
    except Exception:
        # Best effort: an unexpected error while evaluating the TritonGPT path
        # falls back to the regular validators rather than failing the request.
        self.logger.info("Failed to evaluate TGPT label logic. Falling back on regular validator components.")

    for component_validator in self.component_validators:
        component_validator.validate_pod(request)

Expand All @@ -65,4 +77,4 @@ def admission_response(self, uid, allowed, message):
"message": message
}
}
}
}
18 changes: 18 additions & 0 deletions src/dsmlp/ext/kube.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ def get_gpus_in_namespace(self, name: str) -> int:

return gpu_count

def get_tgpt_label(self, name: str) -> str:
    """Return the ``tgpt-validator`` label of namespace ``name``.

    Returns ``None`` when the namespace has no metadata, no labels, or no
    ``tgpt-validator`` label.
    """
    namespace: V1Namespace = self.get_policy_api().read_namespace(name=name)
    meta: V1ObjectMeta = namespace.metadata

    # Guard clauses: nothing to report unless metadata and labels both exist.
    if meta is None or meta.labels is None:
        return None
    if "tgpt-validator" not in meta.labels:
        return None
    return meta.labels["tgpt-validator"]

# TODO: make arbitrary function of getting namespace labels.
def get_tgpt_uids(self, name: str) -> list:
    """Return the UIDs permitted to run pods in namespace ``name``.

    Reads the namespace's ``permitted-uids`` label, which should be a
    comma-delimited list such as ``2000,100,2,20``.

    Args:
        name: name of the namespace to read.

    Returns:
        A list of UID strings with surrounding whitespace stripped and empty
        entries dropped. The list is empty when the namespace has no
        metadata, no labels, or no ``permitted-uids`` label, so callers can
        always run a membership test on the result.
    """
    # NOTE(review): Kubernetes label values may not contain commas, so a
    # comma-delimited value may need to live in an annotation instead —
    # confirm how this value is actually provisioned.
    api = self.get_policy_api()
    v1namespace: V1Namespace = api.read_namespace(name=name)
    metadata: V1ObjectMeta = v1namespace.metadata

    if metadata is None or metadata.labels is None:
        return []

    raw_uids = metadata.labels.get("permitted-uids")
    if not raw_uids:
        return []

    # Tolerate stray spaces and trailing/double commas in the configured value.
    return [uid.strip() for uid in raw_uids.split(',') if uid.strip()]

# noinspection PyMethodMayBeStatic

def get_policy_api(self) -> CoreV1Api:
Expand Down
Empty file.

0 comments on commit 3babee6

Please sign in to comment.