Skip to content

Commit

Permalink
Add support for Azure OpenAI hosted models (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
jamie256 authored Nov 22, 2023
1 parent 61ca254 commit 5b9383e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 2 deletions.
2 changes: 2 additions & 0 deletions langkit/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
send_prompt,
Conversation,
LLMInvocationParams,
OpenAIAzure,
OpenAIDavinci,
OpenAIGPT4,
OpenAIDefault,
Expand All @@ -13,6 +14,7 @@
send_prompt,
Conversation,
LLMInvocationParams,
OpenAIAzure,
OpenAIDavinci,
OpenAIDefault,
OpenAIGPT4,
Expand Down
38 changes: 37 additions & 1 deletion langkit/openai/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def completion(self, messages: List[Dict[str, str]]):
f"last message must exist and contain a content key but got {last_message}"
)
params = asdict(self)
openai.api_key = os.getenv("OPENAI_API_KEY")
text_completion_respone = openai.Completion.create(prompt=prompt, **params)
content = text_completion_respone.choices[0].text
response = type(
Expand Down Expand Up @@ -138,7 +139,7 @@ class OpenAIDefault(LLMInvocationParams):

def completion(self, messages: List[Dict[str, str]], **kwargs):
    """Send *messages* to the OpenAI chat completion endpoint.

    The dataclass fields of this instance (model, temperature, etc.) are
    forwarded as request parameters via ``asdict``.

    :param messages: chat messages, each a dict with "role"/"content" keys
        (shape assumed from the OpenAI chat API — confirm against callers).
    :return: the raw ``openai.ChatCompletion.create`` response object.
    """
    params = asdict(self)
    # Resolve the API key from the environment at call time so that a key
    # set after import is still picked up.
    openai.api_key = os.getenv("OPENAI_API_KEY")
    # NOTE: removed a stray bare `openai.ChatCompletion.create` expression
    # statement that evaluated the attribute without calling it — a no-op.
    return openai.ChatCompletion.create(messages=messages, **params)

def copy(self) -> LLMInvocationParams:
Expand All @@ -151,6 +152,40 @@ def copy(self) -> LLMInvocationParams:
)


@dataclass
class OpenAIAzure(LLMInvocationParams):
    """Invocation parameters for chat models hosted on Azure OpenAI.

    ``engine`` names the Azure deployment to call; ``api_type`` and
    ``api_version`` override the module-level defaults ("azure" and
    "2023-05-15"). Credentials and endpoint are read from the
    AZURE_OPENAI_KEY and AZURE_OPENAI_ENDPOINT environment variables
    on every call.
    """

    temperature: float = field(default_factory=lambda: _llm_model_temperature)
    max_tokens: int = field(default_factory=lambda: _llm_model_max_tokens)
    frequency_penalty: float = field(
        default_factory=lambda: _llm_model_frequency_penalty
    )
    presence_penalty: float = field(default_factory=lambda: _llm_model_presence_penalty)
    engine: Optional[str] = None
    api_type: Optional[str] = None
    api_version: Optional[str] = None

    def completion(self, messages: List[Dict[str, str]], **kwargs):
        """Call the Azure-hosted chat completion API with *messages*."""
        # Configure the module-level Azure settings before issuing the
        # request, falling back to defaults when the fields are unset.
        openai.api_type = self.api_type if self.api_type else "azure"
        openai.api_version = self.api_version if self.api_version else "2023-05-15"
        azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        if azure_endpoint:
            openai.api_base = azure_endpoint
        openai.api_key = os.getenv("AZURE_OPENAI_KEY")
        request_params = asdict(self)
        return openai.ChatCompletion.create(messages=messages, **request_params)

    def copy(self) -> LLMInvocationParams:
        """Return a new ``OpenAIAzure`` carrying the same settings."""
        return OpenAIAzure(
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            frequency_penalty=self.frequency_penalty,
            presence_penalty=self.presence_penalty,
            engine=self.engine,
            api_type=self.api_type,
            api_version=self.api_version,
        )


@dataclass
class OpenAIGPT4(LLMInvocationParams):
model: str = field(default_factory=lambda: "gpt-4")
Expand All @@ -163,6 +198,7 @@ class OpenAIGPT4(LLMInvocationParams):

def completion(self, messages: List[Dict[str, str]], **kwargs):
    """Send *messages* to OpenAI's chat completion API, using this
    instance's dataclass fields as the request parameters.
    """
    # Read the key from the environment on each call rather than caching it.
    openai.api_key = os.getenv("OPENAI_API_KEY")
    request_params = asdict(self)
    return openai.ChatCompletion.create(messages=messages, **request_params)

def copy(self) -> LLMInvocationParams:
Expand Down
1 change: 0 additions & 1 deletion langkit/response_hallucination.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def __init__(self, llm: LLMInvocationParams, num_samples, embeddings_encoder):
self.sample_generator_llm = sample_generator
consistency_checker_llm = llm.copy()
consistency_checker_llm.temperature = 0
consistency_checker_llm.max_tokens = 10
self.consistency_checker_llm = consistency_checker_llm
self.embeddings_encoder = embeddings_encoder

Expand Down

0 comments on commit 5b9383e

Please sign in to comment.