Skip to content

Commit

Permalink
Use AAD instead of key vault for Computer Vision API (#1062)
Browse files Browse the repository at this point in the history
* Initial changes to remove keyvault and use AAD instead

* rm keyvault

* Fix Bicep

* Role rename

* Make mypy happy
  • Loading branch information
pamelafox authored Feb 24, 2024
1 parent 3424475 commit 2e79777
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 77 deletions.
13 changes: 5 additions & 8 deletions app/backend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ async def setup_clients():
AZURE_SEARCH_SERVICE = os.environ["AZURE_SEARCH_SERVICE"]
AZURE_SEARCH_INDEX = os.environ["AZURE_SEARCH_INDEX"]
SEARCH_SECRET_NAME = os.getenv("SEARCH_SECRET_NAME")
VISION_SECRET_NAME = os.getenv("VISION_SECRET_NAME")
AZURE_KEY_VAULT_NAME = os.getenv("AZURE_KEY_VAULT_NAME")
# Shared by all OpenAI deployments
OPENAI_HOST = os.getenv("OPENAI_HOST", "azure")
Expand Down Expand Up @@ -257,13 +256,11 @@ async def setup_clients():
azure_credential = DefaultAzureCredential(exclude_shared_token_cache_credential=True)

# Fetch any necessary secrets from Key Vault
vision_key = None
search_key = None
if AZURE_KEY_VAULT_NAME and (VISION_SECRET_NAME or SEARCH_SECRET_NAME):
if AZURE_KEY_VAULT_NAME:
key_vault_client = SecretClient(
vault_url=f"https://{AZURE_KEY_VAULT_NAME}.vault.azure.net", credential=azure_credential
)
vision_key = VISION_SECRET_NAME and (await key_vault_client.get_secret(VISION_SECRET_NAME)).value
search_key = SEARCH_SECRET_NAME and (await key_vault_client.get_secret(SEARCH_SECRET_NAME)).value
await key_vault_client.close()

Expand Down Expand Up @@ -348,16 +345,16 @@ async def setup_clients():
)

if USE_GPT4V:
if vision_key is None:
raise ValueError("Vision key must be set (in Key Vault) to use the vision approach.")

token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default")

current_app.config[CONFIG_ASK_VISION_APPROACH] = RetrieveThenReadVisionApproach(
search_client=search_client,
openai_client=openai_client,
blob_container_client=blob_container_client,
auth_helper=auth_helper,
vision_endpoint=AZURE_VISION_ENDPOINT,
vision_key=vision_key,
vision_token_provider=token_provider,
gpt4v_deployment=AZURE_OPENAI_GPT4V_DEPLOYMENT,
gpt4v_model=AZURE_OPENAI_GPT4V_MODEL,
embedding_model=OPENAI_EMB_MODEL,
Expand All @@ -374,7 +371,7 @@ async def setup_clients():
blob_container_client=blob_container_client,
auth_helper=auth_helper,
vision_endpoint=AZURE_VISION_ENDPOINT,
vision_key=vision_key,
vision_token_provider=token_provider,
gpt4v_deployment=AZURE_OPENAI_GPT4V_DEPLOYMENT,
gpt4v_model=AZURE_OPENAI_GPT4V_MODEL,
embedding_model=OPENAI_EMB_MODEL,
Expand Down
18 changes: 13 additions & 5 deletions app/backend/approaches/approach.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from abc import ABC
from dataclasses import dataclass
from typing import Any, AsyncGenerator, List, Optional, Union, cast
from typing import Any, AsyncGenerator, Awaitable, Callable, List, Optional, Union, cast
from urllib.parse import urljoin

import aiohttp
from azure.search.documents.aio import SearchClient
Expand Down Expand Up @@ -74,7 +76,7 @@ class ThoughtStep:
props: Optional[dict[str, Any]] = None


class Approach:
class Approach(ABC):
def __init__(
self,
search_client: SearchClient,
Expand All @@ -85,6 +87,8 @@ def __init__(
embedding_deployment: Optional[str], # Not needed for non-Azure OpenAI or for retrieval_mode="text"
embedding_model: str,
openai_host: str,
vision_endpoint: str,
vision_token_provider: Callable[[], Awaitable[str]],
):
self.search_client = search_client
self.openai_client = openai_client
Expand All @@ -94,6 +98,8 @@ def __init__(
self.embedding_deployment = embedding_deployment
self.embedding_model = embedding_model
self.openai_host = openai_host
self.vision_endpoint = vision_endpoint
self.vision_token_provider = vision_token_provider

def build_filter(self, overrides: dict[str, Any], auth_claims: dict[str, Any]) -> Optional[str]:
exclude_category = overrides.get("exclude_category")
Expand Down Expand Up @@ -188,12 +194,14 @@ async def compute_text_embedding(self, q: str):
query_vector = embedding.data[0].embedding
return VectorizedQuery(vector=query_vector, k_nearest_neighbors=50, fields="embedding")

async def compute_image_embedding(self, q: str, vision_endpoint: str, vision_key: str):
endpoint = f"{vision_endpoint}computervision/retrieval:vectorizeText"
async def compute_image_embedding(self, q: str):
endpoint = urljoin(self.vision_endpoint, "computervision/retrieval:vectorizeText")
headers = {"Content-Type": "application/json"}
params = {"api-version": "2023-02-01-preview", "modelVersion": "latest"}
headers = {"Content-Type": "application/json", "Ocp-Apim-Subscription-Key": vision_key}
data = {"text": q}

headers["Authorization"] = "Bearer " + await self.vision_token_provider()

async with aiohttp.ClientSession() as session:
async with session.post(
url=endpoint, params=params, headers=headers, json=data, raise_for_status=True
Expand Down
8 changes: 4 additions & 4 deletions app/backend/approaches/chatreadretrievereadvision.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Coroutine, Optional, Union
from typing import Any, Awaitable, Callable, Coroutine, Optional, Union

from azure.search.documents.aio import SearchClient
from azure.storage.blob.aio import ContainerClient
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(
query_language: str,
query_speller: str,
vision_endpoint: str,
vision_key: str,
vision_token_provider: Callable[[], Awaitable[str]]
):
self.search_client = search_client
self.blob_container_client = blob_container_client
Expand All @@ -55,7 +55,7 @@ def __init__(
self.query_language = query_language
self.query_speller = query_speller
self.vision_endpoint = vision_endpoint
self.vision_key = vision_key
self.vision_token_provider = vision_token_provider
self.chatgpt_token_limit = get_token_limit(gpt4v_model)

@property
Expand Down Expand Up @@ -126,7 +126,7 @@ async def run_until_final_call(
vector = (
await self.compute_text_embedding(query_text)
if field == "embedding"
else await self.compute_image_embedding(query_text, self.vision_endpoint, self.vision_key)
else await self.compute_image_embedding(query_text)
)
vectors.append(vector)

Expand Down
8 changes: 4 additions & 4 deletions app/backend/approaches/retrievethenreadvision.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, AsyncGenerator, Optional, Union
from typing import Any, AsyncGenerator, Awaitable, Callable, Optional, Union

from azure.search.documents.aio import SearchClient
from azure.storage.blob.aio import ContainerClient
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(
query_language: str,
query_speller: str,
vision_endpoint: str,
vision_key: str,
vision_token_provider: Callable[[], Awaitable[str]]
):
self.search_client = search_client
self.blob_container_client = blob_container_client
Expand All @@ -68,7 +68,7 @@ def __init__(
self.query_language = query_language
self.query_speller = query_speller
self.vision_endpoint = vision_endpoint
self.vision_key = vision_key
self.vision_token_provider = vision_token_provider

async def run(
self,
Expand Down Expand Up @@ -100,7 +100,7 @@ async def run(
vector = (
await self.compute_text_embedding(q)
if field == "embedding"
else await self.compute_image_embedding(q, self.vision_endpoint, self.vision_key)
else await self.compute_image_embedding(q)
)
vectors.append(vector)

Expand Down
45 changes: 23 additions & 22 deletions infra/main.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ var actualSearchServiceSemanticRankerLevel = (searchServiceSkuName == 'free') ?
param useSearchServiceKey bool = searchServiceSkuName == 'free'

param storageAccountName string = ''
param keyVaultResourceGroupName string = ''
param storageResourceGroupName string = ''
param storageResourceGroupLocation string = location
param storageContainerName string = 'content'
Expand All @@ -47,12 +46,12 @@ param openAiServiceName string = ''
param openAiResourceGroupName string = ''
param useGPT4V bool = false

param keyVaultResourceGroupName string = ''
param keyVaultServiceName string = ''
param computerVisionSecretName string = 'computerVisionSecret'
param searchServiceSecretName string = 'searchServiceSecret'

@description('Location for the OpenAI resource group')
@allowed(['canadaeast', 'eastus', 'eastus2', 'francecentral', 'switzerlandnorth', 'uksouth', 'japaneast', 'northcentralus', 'australiaeast', 'swedencentral'])
@allowed([ 'canadaeast', 'eastus', 'eastus2', 'francecentral', 'switzerlandnorth', 'uksouth', 'japaneast', 'northcentralus', 'australiaeast', 'swedencentral' ])
@metadata({
azd: {
type: 'location'
Expand All @@ -70,7 +69,7 @@ param documentIntelligenceResourceGroupName string = ''
// Limited regions for new version:
// https://learn.microsoft.com/azure/ai-services/document-intelligence/concept-layout
@description('Location for the Document Intelligence resource group')
@allowed(['eastus', 'westus2', 'westeurope'])
@allowed([ 'eastus', 'westus2', 'westeurope' ])
@metadata({
azd: {
type: 'location'
Expand Down Expand Up @@ -129,7 +128,7 @@ var resourceToken = toLower(uniqueString(subscription().id, environmentName, loc
var tags = { 'azd-env-name': environmentName }
var computerVisionName = !empty(computerVisionServiceName) ? computerVisionServiceName : '${abbrs.cognitiveServicesComputerVision}${resourceToken}'

var useKeyVault = useGPT4V || useSearchServiceKey
var useKeyVault = useSearchServiceKey
var tenantIdForAuth = !empty(authTenantId) ? authTenantId : tenantId
var authenticationIssuerUri = '${environment().authentication.loginEndpoint}${tenantIdForAuth}/v2.0'

Expand Down Expand Up @@ -182,7 +181,6 @@ module monitoring 'core/monitor/monitoring.bicep' = if (useApplicationInsights)
}
}


module applicationInsightsDashboard 'backend-dashboard.bicep' = if (useApplicationInsights) {
name: 'application-insights-dashboard'
scope: resourceGroup
Expand All @@ -193,7 +191,6 @@ module applicationInsightsDashboard 'backend-dashboard.bicep' = if (useApplicati
}
}


// Create an App Service Plan to group applications under the same payment plan and SKU
module appServicePlan 'core/host/appserviceplan.bicep' = {
name: 'appserviceplan'
Expand Down Expand Up @@ -224,7 +221,7 @@ module backend 'core/host/appservice.bicep' = {
appCommandLine: 'python3 -m gunicorn main:app'
scmDoBuildDuringDeployment: true
managedIdentity: true
allowedOrigins: [allowedOrigin]
allowedOrigins: [ allowedOrigin ]
clientAppId: clientAppId
serverAppId: serverAppId
clientSecretSettingName: !empty(clientAppSecret) ? 'AZURE_CLIENT_APP_SECRET' : ''
Expand All @@ -238,7 +235,6 @@ module backend 'core/host/appservice.bicep' = {
AZURE_SEARCH_SERVICE: searchService.outputs.name
AZURE_SEARCH_SEMANTIC_RANKER: actualSearchServiceSemanticRankerLevel
AZURE_VISION_ENDPOINT: useGPT4V ? computerVision.outputs.endpoint : ''
VISION_SECRET_NAME: useGPT4V ? computerVisionSecretName: ''
SEARCH_SECRET_NAME: useSearchServiceKey ? searchServiceSecretName : ''
AZURE_KEY_VAULT_NAME: useKeyVault ? keyVault.outputs.name : ''
AZURE_SEARCH_QUERY_LANGUAGE: searchQueryLanguage
Expand Down Expand Up @@ -361,9 +357,8 @@ module computerVision 'core/ai/cognitiveservices.bicep' = if (useGPT4V) {
}
}


// Currently, we only need Key Vault for storing Computer Vision key,
// which is only used for GPT-4V.
// Currently, we only need Key Vault for storing Search service key,
// which is only used for free tier
module keyVault 'core/security/keyvault.bicep' = if (useKeyVault) {
name: 'keyvault'
scope: keyVaultResourceGroup
Expand All @@ -388,16 +383,12 @@ module secrets 'secrets.bicep' = if (useKeyVault) {
scope: keyVaultResourceGroup
params: {
keyVaultName: useKeyVault ? keyVault.outputs.name : ''
storeComputerVisionSecret: useGPT4V
computerVisionId: useGPT4V ? computerVision.outputs.id : ''
computerVisionSecretName: computerVisionSecretName
storeSearchServiceSecret: useSearchServiceKey
searchServiceId: useSearchServiceKey ? searchService.outputs.id : ''
searchServiceSecretName: searchServiceSecretName
}
}


module searchService 'core/search/search-services.bicep' = {
name: 'search-service'
scope: searchServiceResourceGroup
Expand Down Expand Up @@ -443,7 +434,7 @@ module storage 'core/storage/storage-account.bicep' = {
}

// USER ROLES
var principalType = empty(runningOnGh) && empty(runningOnAdo) ? 'User': 'ServicePrincipal'
var principalType = empty(runningOnGh) && empty(runningOnAdo) ? 'User' : 'ServicePrincipal'

module openAiRoleUser 'core/security/role.bicep' = if (openAiHost == 'azure') {
scope: openAiResourceGroup
Expand All @@ -455,9 +446,10 @@ module openAiRoleUser 'core/security/role.bicep' = if (openAiHost == 'azure') {
}
}

module documentIntelligenceRoleUser 'core/security/role.bicep' = {
scope: documentIntelligenceResourceGroup
name: 'documentintelligence-role-user'
// For both document intelligence and computer vision
module cognitiveServicesRoleUser 'core/security/role.bicep' = {
scope: resourceGroup
name: 'cognitiveservices-role-user'
params: {
principalId: principalId
roleDefinitionId: 'a97b65f3-24c7-4388-baec-2e87135dc908'
Expand Down Expand Up @@ -537,7 +529,6 @@ module openAiRoleSearchService 'core/security/role.bicep' = if (openAiHost == 'a
}
}


module storageRoleBackend 'core/security/role.bicep' = {
scope: storageResourceGroup
name: 'storage-role-backend'
Expand Down Expand Up @@ -582,6 +573,17 @@ module searchReaderRoleBackend 'core/security/role.bicep' = if (useAuthenticatio
}
}

// For computer vision access by the backend
module cognitiveServicesRoleBackend 'core/security/role.bicep' = if (useGPT4V) {
scope: resourceGroup
name: 'cognitiveservices-role-backend'
params: {
principalId: backend.outputs.identityPrincipalId
roleDefinitionId: 'a97b65f3-24c7-4388-baec-2e87135dc908'
principalType: 'ServicePrincipal'
}
}

output AZURE_LOCATION string = location
output AZURE_TENANT_ID string = tenantId
output AZURE_AUTH_TENANT_ID string = authTenantId
Expand All @@ -605,7 +607,6 @@ output OPENAI_API_KEY string = (openAiHost == 'openai') ? openAiApiKey : ''
output OPENAI_ORGANIZATION string = (openAiHost == 'openai') ? openAiApiOrganization : ''

output AZURE_VISION_ENDPOINT string = useGPT4V ? computerVision.outputs.endpoint : ''
output VISION_SECRET_NAME string = useGPT4V ? computerVisionSecretName : ''
output AZURE_KEY_VAULT_NAME string = useKeyVault ? keyVault.outputs.name : ''

output AZURE_DOCUMENTINTELLIGENCE_SERVICE string = documentIntelligence.outputs.name
Expand Down
11 changes: 0 additions & 11 deletions infra/secrets.bicep
Original file line number Diff line number Diff line change
@@ -1,19 +1,8 @@
param keyVaultName string
param storeComputerVisionSecret bool
param computerVisionId string
param computerVisionSecretName string
param storeSearchServiceSecret bool
param searchServiceId string
param searchServiceSecretName string

module computerVisionKVSecret 'core/security/keyvault-secret.bicep' = if (storeComputerVisionSecret) {
name: 'keyvault-secret'
params: {
keyVaultName: storeComputerVisionSecret ? keyVaultName : ''
name: computerVisionSecretName
secretValue: storeComputerVisionSecret ? listKeys(computerVisionId, '2023-05-01').key1 : ''
}
}

module searchServiceKVSecret 'core/security/keyvault-secret.bicep' = if (storeSearchServiceSecret) {
name: 'searchservice-secret'
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ lint.select = ["E", "F", "I", "UP"]
lint.ignore = ["E501", "E701"] # line too long, multiple statements on one line
src = ["app/backend", "scripts"]

[tool.ruff.isort]
[tool.ruff.lint.isort]
known-local-folder = ["scripts"]

[tool.black]
Expand Down
Loading

0 comments on commit 2e79777

Please sign in to comment.