|
1 |
| -import time |
2 | 1 | from abc import ABC
|
3 |
| -from typing import List, Optional, Union |
| 2 | +from typing import Awaitable, Callable, List, Optional, Union |
4 | 3 | from urllib.parse import urljoin
|
5 | 4 |
|
6 | 5 | import aiohttp
|
7 | 6 | import tiktoken
|
8 |
| -from azure.core.credentials import AccessToken, AzureKeyCredential |
| 7 | +from azure.core.credentials import AzureKeyCredential |
9 | 8 | from azure.core.credentials_async import AsyncTokenCredential
|
| 9 | +from azure.identity.aio import get_bearer_token_provider |
10 | 10 | from openai import AsyncAzureOpenAI, AsyncOpenAI, RateLimitError
|
11 | 11 | from tenacity import (
|
12 | 12 | AsyncRetrying,
|
13 | 13 | retry_if_exception_type,
|
14 | 14 | stop_after_attempt,
|
15 | 15 | wait_random_exponential,
|
16 | 16 | )
|
| 17 | +from typing_extensions import TypedDict |
17 | 18 |
|
18 | 19 |
|
19 | 20 | class EmbeddingBatch:
|
@@ -139,28 +140,29 @@ def __init__(
|
139 | 140 | self.open_ai_service = open_ai_service
|
140 | 141 | self.open_ai_deployment = open_ai_deployment
|
141 | 142 | self.credential = credential
|
142 |
| - self.cached_token: Optional[AccessToken] = None |
143 | 143 |
|
144 | 144 | async def create_client(self) -> AsyncOpenAI:
|
| 145 | + class AuthArgs(TypedDict, total=False): |
| 146 | + api_key: str |
| 147 | + azure_ad_token_provider: Callable[[], Union[str, Awaitable[str]]] |
| 148 | + |
| 149 | + auth_args = AuthArgs() |
| 150 | + if isinstance(self.credential, AzureKeyCredential): |
| 151 | + auth_args["api_key"] = self.credential.key |
| 152 | + elif isinstance(self.credential, AsyncTokenCredential): |
| 153 | + auth_args["azure_ad_token_provider"] = get_bearer_token_provider( |
| 154 | + self.credential, "https://cognitiveservices.azure.com/.default" |
| 155 | + ) |
| 156 | + else: |
| 157 | + raise TypeError("Invalid credential type") |
| 158 | + |
145 | 159 | return AsyncAzureOpenAI(
|
146 | 160 | azure_endpoint=f"https://{self.open_ai_service}.openai.azure.com",
|
147 | 161 | azure_deployment=self.open_ai_deployment,
|
148 |
| - api_key=await self.wrap_credential(), |
149 | 162 | api_version="2023-05-15",
|
| 163 | + **auth_args, |
150 | 164 | )
|
151 | 165 |
|
152 |
| - async def wrap_credential(self) -> str: |
153 |
| - if isinstance(self.credential, AzureKeyCredential): |
154 |
| - return self.credential.key |
155 |
| - |
156 |
| - if isinstance(self.credential, AsyncTokenCredential): |
157 |
| - if not self.cached_token or self.cached_token.expires_on <= time.time(): |
158 |
| - self.cached_token = await self.credential.get_token("https://cognitiveservices.azure.com/.default") |
159 |
| - |
160 |
| - return self.cached_token.token |
161 |
| - |
162 |
| - raise TypeError("Invalid credential type") |
163 |
| - |
164 | 166 |
|
165 | 167 | class OpenAIEmbeddingService(OpenAIEmbeddings):
|
166 | 168 | """
|
|
0 commit comments