Skip to content

Commit 0ee189d

Browse files
authored
Use token provider instead of token wrapper (#1228)
* Use token provider instead of token wrapper * Upgrade to class based syntax * Use Union
1 parent 15af3a8 commit 0ee189d

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

scripts/prepdocslib/embeddings.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
1-
import time
21
from abc import ABC
3-
from typing import List, Optional, Union
2+
from typing import Awaitable, Callable, List, Optional, Union
43
from urllib.parse import urljoin
54

65
import aiohttp
76
import tiktoken
8-
from azure.core.credentials import AccessToken, AzureKeyCredential
7+
from azure.core.credentials import AzureKeyCredential
98
from azure.core.credentials_async import AsyncTokenCredential
9+
from azure.identity.aio import get_bearer_token_provider
1010
from openai import AsyncAzureOpenAI, AsyncOpenAI, RateLimitError
1111
from tenacity import (
1212
AsyncRetrying,
1313
retry_if_exception_type,
1414
stop_after_attempt,
1515
wait_random_exponential,
1616
)
17+
from typing_extensions import TypedDict
1718

1819

1920
class EmbeddingBatch:
@@ -139,28 +140,29 @@ def __init__(
139140
self.open_ai_service = open_ai_service
140141
self.open_ai_deployment = open_ai_deployment
141142
self.credential = credential
142-
self.cached_token: Optional[AccessToken] = None
143143

144144
async def create_client(self) -> AsyncOpenAI:
145+
class AuthArgs(TypedDict, total=False):
146+
api_key: str
147+
azure_ad_token_provider: Callable[[], Union[str, Awaitable[str]]]
148+
149+
auth_args = AuthArgs()
150+
if isinstance(self.credential, AzureKeyCredential):
151+
auth_args["api_key"] = self.credential.key
152+
elif isinstance(self.credential, AsyncTokenCredential):
153+
auth_args["azure_ad_token_provider"] = get_bearer_token_provider(
154+
self.credential, "https://cognitiveservices.azure.com/.default"
155+
)
156+
else:
157+
raise TypeError("Invalid credential type")
158+
145159
return AsyncAzureOpenAI(
146160
azure_endpoint=f"https://{self.open_ai_service}.openai.azure.com",
147161
azure_deployment=self.open_ai_deployment,
148-
api_key=await self.wrap_credential(),
149162
api_version="2023-05-15",
163+
**auth_args,
150164
)
151165

152-
async def wrap_credential(self) -> str:
153-
if isinstance(self.credential, AzureKeyCredential):
154-
return self.credential.key
155-
156-
if isinstance(self.credential, AsyncTokenCredential):
157-
if not self.cached_token or self.cached_token.expires_on <= time.time():
158-
self.cached_token = await self.credential.get_token("https://cognitiveservices.azure.com/.default")
159-
160-
return self.cached_token.token
161-
162-
raise TypeError("Invalid credential type")
163-
164166

165167
class OpenAIEmbeddingService(OpenAIEmbeddings):
166168
"""

0 commit comments

Comments
 (0)