diff --git a/llm_judge/common.py b/llm_judge/common.py index d4e8ded..789bf12 100644 --- a/llm_judge/common.py +++ b/llm_judge/common.py @@ -59,12 +59,16 @@ def judge(self, **kwargs): ] for _ in range(API_MAX_RETRY): try: - response = openai.ChatCompletion.create( - model=self.model, - messages=messages, - temperature=0, - max_tokens=2048, - ) + params = { + "messages": messages, + "temperature": 0, + "max_tokens": 2048, + } + if openai.api_type == "azure": + params["engine"] = self.model + else: + params["model"] = self.model + response = openai.ChatCompletion.create(**params) return response["choices"][0]["message"]["content"] except openai.error.OpenAIError as e: logger.warning(f"OpenAI API error: {e}")