content moderation and metaguideline
somnathkumar7 committed Jul 26, 2024
1 parent b661354 commit 844e5e1
Showing 12 changed files with 228 additions and 19 deletions.
1 change: 1 addition & 0 deletions configs/llm_config.yaml
@@ -1,5 +1,6 @@
azure_open_ai:
api_key: "**"
use_azure_ad: True
api_version: "2023-03-15-preview"
api_type: "azure"
azure_endpoint: "https://**.openai.azure.com/"
12 changes: 11 additions & 1 deletion configs/setup_config.yaml
@@ -16,12 +16,22 @@ description:
# content_moderation is optional. If you don't want to specify content_moderation,
# remove this entire field & its sub-fields.
content_moderation:
# If set to true, content moderation is enabled.
enable_moderation: true
# If content_severity crosses this threshold, processing is blocked.
content_severity_threshold: 2
# Whether jailbreak detection is required.
jailbreak_detection: true
# Additional guidelines to constrain any malicious behavior.
include_metaprompt_guidelines: true
# Azure AI Content Safety (AACS) Endpoint
aacs:
subscription_id: "**"
resource_group: "**"
name: "vellm_aacs"
name: "resource_name"
location: "east us"
sku_name: "S0"
use_azure_ad: true
## Specify the authentication method to be used for AACS.
## If use_azure_ad is set to true, the resource will be accessed using Azure AD;
## otherwise the resource will be accessed using key-based authentication. The key is fetched automatically, so there is no need to pass it.
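
For reference, a minimal sketch of how this block maps onto the configuration dataclasses added in base_classes.py below. It assumes the AACS dataclass also declares subscription_id and resource_group (they are read in aacs.py) and that ContentModeration.__post_init__ needs no extra input; the values are the placeholders from the YAML above.

# Minimal sketch, not part of the commit: build the content-moderation config
# objects by hand with the placeholder values shown above.
from glue.common.base_classes import AACS, ContentModeration

aacs_cfg = AACS(
    subscription_id="**",        # placeholder, as in the YAML
    resource_group="**",         # placeholder, as in the YAML
    name="resource_name",
    location="east us",
    sku_name="S0",
    use_azure_ad=True,
)

cm_cfg = ContentModeration(
    content_severity_threshold=2,
    enable_moderation=True,
    jailbreak_detection=True,
    include_metaprompt_guidelines=True,
    aacs=aacs_cfg,
)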
5 changes: 5 additions & 0 deletions src/glue-common/glue/common/base_classes.py
@@ -55,6 +55,7 @@ class AzureAOIModels(LLMModel, UniversalBaseClass):
@dataclass
class AzureAOILM(UniversalBaseClass):
api_key: str
use_azure_ad: bool
api_version: str
api_type: str
azure_endpoint: str
@@ -117,12 +118,16 @@ class AACS:
name: str
location: str
sku_name: str
use_azure_ad: bool


@dataclass
class ContentModeration(UniversalBaseClass):
# Class for all content moderation handles
content_severity_threshold: int
enable_moderation: bool
jailbreak_detection: bool
include_metaprompt_guidelines: bool
aacs: AACS

def __post_init__(self):
1 change: 1 addition & 0 deletions src/glue-common/glue/common/constants/str_literals.py
@@ -56,3 +56,4 @@ class DirNames:
@dataclass
class URLs:
AZ_CREDENTIAL_URL= "https://management.azure.com/.default"
AZ_COGNITIVE_SERVICES_URL = "https://cognitiveservices.azure.com/.default"
19 changes: 15 additions & 4 deletions src/glue-common/glue/common/content_moderation/__init__.py
@@ -1,10 +1,21 @@
from glue.common.base_classes import SetupConfig
from glue.common.content_moderation.aacs import AACSContentModeration
from glue.common.content_moderation.base_class import ContentModeration

class ByPassContentModeration(ContentModeration):
def __init__(self, setup_config: SetupConfig):
self.setup_config = setup_config
self.include_metaprompt_guidelines = setup_config.content_moderation.include_metaprompt_guidelines

def is_text_safe(self, text) -> bool:
return True

def get_content_moderator_handle(setup_config: SetupConfig):
content_moderator_handle = {}
if setup_config.content_moderation.aacs:
content_moderator_handle["aacs"] = AACSContentModeration(setup_config)

def get_content_moderator_handle(setup_config: SetupConfig) -> ContentModeration:
content_moderator_handle = None
if setup_config.content_moderation.enable_moderation:
if setup_config.content_moderation.aacs:
content_moderator_handle = AACSContentModeration(setup_config)
if content_moderator_handle is None:
content_moderator_handle = ByPassContentModeration(setup_config)

return content_moderator_handle
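
A rough usage sketch of the factory above (assuming SetupConfig is loaded with the yaml_to_class helper, as instantiate.py does for other configs). Because a ByPassContentModeration instance is returned whenever moderation is disabled or no AACS block is present, callers can invoke is_text_safe unconditionally.

# Sketch only: obtain a moderator handle and gate further processing on it.
from glue.common.base_classes import SetupConfig
from glue.common.content_moderation import get_content_moderator_handle
from glue.common.utils.file import yaml_to_class

setup_config = yaml_to_class("configs/setup_config.yaml", SetupConfig)
moderator = get_content_moderator_handle(setup_config)

user_text = "Summarise the attached report."
if moderator.is_text_safe(user_text):
    pass  # safe: continue with the LLM call
else:
    pass  # flagged: block or sanitise the request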
163 changes: 155 additions & 8 deletions src/glue-common/glue/common/content_moderation/aacs.py
@@ -1,3 +1,7 @@
from typing import Dict
import requests
import json

from glue.common.base_classes import AACS, SetupConfig, OperationMode
from glue.common.constants.str_literals import InstallLibs, URLs
from glue.common.exceptions import GlueAuthenticationException
@@ -18,31 +22,174 @@
from azure.mgmt.cognitiveservices.models import Account, Sku, AccountProperties


def sliding_window(text, max_chars=1024, overlap_words=2):
words = text.split()
windows = []
start = 0
while start < len(words):
window = []
current_length = 0
for i in range(start, len(words)):
word_length = len(words[i]) + 1
if current_length + word_length > max_chars:
break
window.append(words[i])
current_length += word_length
windows.append(' '.join(window))
start += len(window) - overlap_words
return windows

def merge_dicts(dict_a, dict_b):
if not isinstance(dict_a, dict) or not isinstance(dict_b, dict):
return [dict_a, dict_b] if dict_a != dict_b else dict_a

merged = dict_a.copy()
for key, value in dict_b.items():
if key in dict_a:
if isinstance(dict_a[key], dict) and isinstance(value, dict):
merged[key] = merge_dicts(dict_a[key], value)
elif isinstance(dict_a[key], list):
merged[key] += value if isinstance(value, list) else [value]
elif isinstance(value, list):
merged[key] = [dict_a[key]] + value
elif isinstance(dict_a[key], (int, float, str)) and isinstance(value, (int, float, str)):
merged[key] = [dict_a[key], value] if dict_a[key] != value else dict_a[key]
else:
merged[key] = value
else:
merged[key] = value
return merged

class AACSContentModeration(ContentModeration):
def __init__(self, setup_config: SetupConfig):
self.setup_config = setup_config
aacs_config = setup_config.content_moderation.aacs
self.include_metaprompt_guidelines = setup_config.content_moderation.include_metaprompt_guidelines
try:
credential = DefaultAzureCredential()
credential.get_token(URLs.AZ_CREDENTIAL_URL)
self.auth_token = credential.get_token(URLs.AZ_COGNITIVE_SERVICES_URL).token
except Exception as e:
if setup_config.mode == OperationMode.OFFLINE.value:
credential = InteractiveBrowserCredential()
self.auth_token = credential.get_token(URLs.AZ_COGNITIVE_SERVICES_URL).token
else:
raise GlueAuthenticationException(f"To authenticate using DefaultAzureCredential, config.json needs to "
f"be present. Refer: https://learn.microsoft.com/en-us/azure/machine-learning/how-to-configure-environment?view=azureml-api-2 "
f"If running in offline mode, InteractiveBrowserCredential() can be used to authenticate via a pop-up window in the web browser.\n", e)

aacs_client = CognitiveServicesManagementClient(credential, aacs_config.subscription_id)
aacs = self.find_or_create_aacs(aacs_config, aacs_client)
aacs_access_key = aacs_client.accounts.list_keys(
resource_group_name=aacs_config.resource_group, account_name=aacs.name).key1
self.aacs_client = ContentSafetyClient(aacs.properties.endpoint, AzureKeyCredential(aacs_access_key))
self.aacs = self.find_or_create_aacs(aacs_config, aacs_client)
if aacs_config.use_azure_ad:
self.aacs_access_key = None
else:
self.aacs_access_key = aacs_client.accounts.list_keys(
resource_group_name=aacs_config.resource_group, account_name=self.aacs.name).key1
self.auth_token = None

def check_attack_detected(self, result):
if isinstance(result, dict):
for value in result.values():
if isinstance(value, bool) and value:
return True
elif isinstance(value, dict):
if self.check_attack_detected(value):
return True
elif isinstance(result, list):
for item in result:
if self.check_attack_detected(item):
return True
return False

def is_text_safe(self, text) -> bool:
response = self.aacs_client.analyze_text(AnalyzeTextOptions(text=text))
return self.is_below_threshold(response["categoriesAnalysis"])

def shield_prompt(self,
user_prompt: str,
documents: list
) -> dict:
"""
Detects unsafe content using the Content Safety API.

Args:
- user_prompt (str): The user prompt to analyze.
- documents (list): The documents to analyze.

Returns:
- dict: The response from the Content Safety API.
"""

api_version = "2024-02-15-preview"
url = f"{self.aacs.properties.endpoint}/contentsafety/text:shieldPrompt?api-version={api_version}"
headers = self.build_headers()
if len(user_prompt) <= 110:
user_prompt += " " + "_"*(110-len(user_prompt))
if isinstance(documents, list) and len(documents) > 0:
for i in range(len(documents)):
if len(documents[i]) <= 110:
documents[i] += " " + "_"*(110-len(documents[i]))
else:
if len(documents) <= 110:
documents += " " + "_"*(110-len(documents))

data = {
"userPrompt": user_prompt,
"documents": documents
}
response = requests.post(url, headers=headers, json=data)
return self.check_attack_detected(response.json())

def text_analyze(self, content, blocklists=[]):
api_version = "2023-10-01"
url = f"{self.aacs.properties.endpoint}/contentsafety/text:analyze?api-version={api_version}"
headers = self.build_headers()
results = None
for windows in sliding_window(content, 9990, 0):
if len(windows)<=110:
windows += " " + "_"*(110-len(windows))
request_body = {
"text": windows,
"blocklistNames": blocklists,
}
payload = json.dumps(request_body)

response = requests.post(url, headers=headers, data=payload)
res_content = response.json()
if response.status_code != 200:
raise Exception(
res_content
)

if results is None:
results = res_content
else:
results = merge_dicts(results, res_content)
return results


def build_headers(self) -> Dict[str, str]:
"""
Builds the headers for the Content Safety API request.
Returns:
- Dict[str, str]: The headers for the Content Safety API request.
"""
if not self.aacs_access_key:
return {
"Authorization": "Bearer "+ self.auth_token,
"Content-Type": "application/json",
}
else:
return {
"Ocp-Apim-Subscription-Key": self.aacs_access_key,
"Content-Type": "application/json",
}

def is_text_safe(self, text) -> bool:
harm_response = self.text_analyze(text)
safe_bool = self.is_below_threshold(harm_response["categoriesAnalysis"])
if not safe_bool:
return safe_bool
jailbreak_bool = self.shield_prompt(text, [])
safe_bool = safe_bool and (not jailbreak_bool)
return safe_bool

def is_below_threshold(self, category_list) -> bool:
"""
Based on the response received from AACS, check whether the content_severity threshold set by the user is crossed.
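
The chunking helpers above are plain module-level functions, so their behaviour can be illustrated without an AACS resource; a small sketch with illustrative values only (it assumes the module's Azure dependencies are importable).

# Illustration only: chunk a long string the way text_analyze does, then merge
# two per-window responses the way the results are accumulated.
from glue.common.content_moderation.aacs import merge_dicts, sliding_window

text = "one two three four five six"
print(sliding_window(text, max_chars=10, overlap_words=0))
# ['one two', 'three', 'four five', 'six']  (each word is counted with a
# trailing space, so every window stays within max_chars)

a = {"categoriesAnalysis": [{"category": "Hate", "severity": 0}]}
b = {"categoriesAnalysis": [{"category": "Hate", "severity": 2}]}
print(merge_dicts(a, b))
# {'categoriesAnalysis': [{'category': 'Hate', 'severity': 0},
#                         {'category': 'Hate', 'severity': 2}]}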
@@ -8,6 +8,7 @@ class ContentModeration:
"""
def __init__(self, setup_config: SetupConfig):
self.setup_config = setup_config
self.include_metaprompt_guidelines = setup_config.content_moderation.include_metaprompt_guidelines
pass

def is_text_safe(self, text) -> bool:
12 changes: 10 additions & 2 deletions src/glue-common/glue/common/llm/llm_mgr.py
@@ -4,11 +4,13 @@
from llama_index.core.llms import LLM
from tenacity import retry, stop_after_attempt, wait_fixed, wait_random

from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential

from glue.common.base_classes import LLMConfig
from glue.common.constants.str_literals import InstallLibs, OAILiterals, \
OAILiterals, LLMLiterals, LLMOutputTypes
OAILiterals, LLMLiterals, LLMOutputTypes, URLs
from glue.common.llm.llm_helper import get_token_counter
from glue.common.exceptions import GlueLLMException
from glue.common.exceptions import GlueLLMException, GlueAuthenticationException
from glue.common.utils.runtime_tasks import install_lib_if_missing
from glue.common.utils.logging import get_glue_logger
from glue.common.utils.runtime_tasks import str_to_class
@@ -64,6 +66,9 @@ def get_llm_pool(llm_config: LLMConfig) -> Dict[str, LLM]:
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.multi_modal_llms.azure_openai import AzureOpenAIMultiModal

if az_llm_config.use_azure_ad:
az_llm_config.api_key = None

for azure_oai_model in az_llm_config.azure_oai_models:
callback_mgr = None
if azure_oai_model.track_tokens:
@@ -79,6 +84,7 @@ def get_llm_pool(llm_config: LLMConfig) -> Dict[str, LLM]:
AzureOpenAI(model=azure_oai_model.model_name_in_azure,
deployment_name=azure_oai_model.deployment_name_in_azure,
api_key=az_llm_config.api_key,
use_azure_ad=az_llm_config.use_azure_ad,
azure_endpoint=az_llm_config.azure_endpoint,
api_version=az_llm_config.api_version,
callback_manager=callback_mgr
@@ -88,6 +94,7 @@ def get_llm_pool(llm_config: LLMConfig) -> Dict[str, LLM]:
AzureOpenAIEmbedding(model=azure_oai_model.model_name_in_azure,
deployment_name=azure_oai_model.deployment_name_in_azure,
api_key=az_llm_config.api_key,
use_azure_ad=az_llm_config.use_azure_ad,
azure_endpoint=az_llm_config.azure_endpoint,
api_version=az_llm_config.api_version,
callback_manager=callback_mgr
@@ -98,6 +105,7 @@ def get_llm_pool(llm_config: LLMConfig) -> Dict[str, LLM]:
AzureOpenAIMultiModal(model=azure_oai_model.model_name_in_azure,
deployment_name=azure_oai_model.deployment_name_in_azure,
api_key=az_llm_config.api_key,
use_azure_ad=az_llm_config.use_azure_ad,
azure_endpoint=az_llm_config.azure_endpoint,
api_version=az_llm_config.api_version,
max_new_tokens=4096
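
The pattern added throughout get_llm_pool can be summarised with a small sketch (make_azure_llm is a hypothetical helper, not part of the commit): when use_azure_ad is true the API key is dropped and the llama-index client is asked to authenticate with Azure AD instead.

# Hypothetical helper for illustration; it mirrors the keyword arguments used above.
from llama_index.llms.azure_openai import AzureOpenAI

def make_azure_llm(az_llm_config, model_cfg):
    api_key = None if az_llm_config.use_azure_ad else az_llm_config.api_key
    return AzureOpenAI(
        model=model_cfg.model_name_in_azure,
        deployment_name=model_cfg.deployment_name_in_azure,
        api_key=api_key,
        use_azure_ad=az_llm_config.use_azure_ad,
        azure_endpoint=az_llm_config.azure_endpoint,
        api_version=az_llm_config.api_version,
    )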
7 changes: 5 additions & 2 deletions src/glue-promptopt/glue/promptopt/instantiate.py
@@ -6,6 +6,7 @@
from glue.common.base_classes import LLMConfig, SetupConfig
from glue.common.constants.log_strings import CommonLogsStr
from glue.common.llm.llm_mgr import LLMMgr
from glue.common.content_moderation import get_content_moderator_handle
from glue.common.utils.logging import get_glue_logger, set_logging_config
from glue.common.utils.file import read_jsonl, yaml_to_class, yaml_to_dict, read_jsonl_row
from paramlogger import ParamLogger
@@ -72,7 +73,9 @@ def __init__(self, llm_config_path: str,
self.prompt_pool = yaml_to_class(prompt_pool_path, promptpool_cls, default_yaml_path)
llm_config = yaml_to_class(llm_config_path, LLMConfig)
llm_pool = LLMMgr.get_llm_pool(llm_config)


content_moderator = get_content_moderator_handle(self.setup_config)

dataset = read_jsonl(dataset_jsonl)
training_dataset = dataset[:self.prompt_opt_param.seen_set_size]
self.prompt_opt_param.answer_format += self.prompt_pool.ans_delimiter_instruction
@@ -89,7 +92,7 @@ def __init__(self, llm_config_path: str,
# This iolog is going to be used when doing complete evaluation over test-dataset
self.iolog.reset_eval_glue(join(base_path, "evaluation"))

self.prompt_opt = prompt_opt_cls(training_dataset, base_path, llm_pool, self.setup_config,
self.prompt_opt = prompt_opt_cls(training_dataset, base_path, llm_pool, content_moderator, self.setup_config,
self.prompt_pool, self.data_processor, self.logger)

def get_best_prompt(self) -> (str, Any):
@@ -24,6 +24,7 @@ class CritiqueNRefinePromptPool(PromptPool):
expert_template: str
generate_reason_template: str
reason_optimization_template: str
metaprompt_guidelines: str


@dataclass