Skip to content

Commit 6141ce9

Browse files
authored
fix(graders): fix typo in code_execution filename and imports (#54)
* fix(graders): fix typo in code_execution filename and imports
  - Rename code_excution.py to code_execution.py (fix typo)
  - Fix imports in code/__init__.py: open_judge -> openjudge
  - Remove commented-out import in base_grader.py
* refactor(core): improve type hints, docstrings and error handling
  - Fix doctest examples to use proper subclass instantiation
  - Add type hints to concurrency.py run_with_concurrency_control
  - Support list templates in LLMGrader initialization
  - Fix DashScope formatter to handle empty lists correctly
  - Improve PromptTemplate error messages for missing languages
  - Use more specific exception types in utils modules
  - Update test assertion for new template error message
* fix(graders): use relative imports and fix empty content handling
1 parent 2f6beee commit 6141ce9

File tree

14 files changed

+68
-47
lines changed

14 files changed

+68
-47
lines changed

openjudge/graders/base_grader.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
either scores or rankings.
77
"""
88

9-
# import inspect
109
from abc import ABC, abstractmethod
1110
from typing import Any, Dict
1211

@@ -60,7 +59,10 @@ def __init__(
6059
accessible to subclasses.
6160
6261
Example:
63-
>>> grader = BaseGrader(
62+
>>> class MyGrader(BaseGrader):
63+
... async def aevaluate(self, **kwargs):
64+
... pass
65+
>>> grader = MyGrader(
6466
... name="accuracy_grader",
6567
... mode=GraderMode.POINTWISE,
6668
... description="Evaluates answer accuracy"
@@ -189,7 +191,12 @@ def from_config(
189191
# Extract standard grader properties from a copy to avoid mutating the input config
190192
config_copy = dict(config)
191193
name = config_copy.pop("name", "")
192-
mode = config_copy.pop("mode", GraderMode.POINTWISE)
194+
mode_value = config_copy.pop("mode", GraderMode.POINTWISE)
195+
# Convert string to GraderMode if necessary
196+
if isinstance(mode_value, str):
197+
mode = GraderMode(mode_value)
198+
else:
199+
mode = mode_value
193200
description = config_copy.pop("description", "")
194201

195202
# Create and return new instance with remaining config items as kwargs

openjudge/graders/code/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,10 @@
88
extensible evaluation mechanisms for AI-generated content.
99
"""
1010

11-
from open_judge.graders.code.code_excution import CodeExecutionGrader
12-
from open_judge.graders.code.code_style import CodeStyleGrader
13-
from open_judge.graders.code.patch_similarity import PatchSimilarityGrader
14-
from open_judge.graders.code.syntax_checker import SyntaxCheckGrader
11+
from .code_execution import CodeExecutionGrader
12+
from .code_style import CodeStyleGrader
13+
from .patch_similarity import PatchSimilarityGrader
14+
from .syntax_checker import SyntaxCheckGrader
1515

1616
__all__ = [
1717
"CodeExecutionGrader",

openjudge/graders/llm_grader.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -133,10 +133,13 @@ def __init__(
133133
)
134134
elif isinstance(template, PromptTemplate):
135135
self.template = template
136+
elif isinstance(template, list):
137+
# Support list of message dicts or ChatMessage objects
138+
self.template = PromptTemplate.from_prompt(template)
136139
elif isinstance(template, dict):
137140
self.template = PromptTemplate(**template)
138141
else:
139-
raise ValueError("Template must be a str, dict or PromptTemplate object")
142+
raise ValueError("Template must be a str, list, dict or PromptTemplate object")
140143

141144
# Initialize model
142145
if isinstance(model, dict):

openjudge/graders/multimodal/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,10 @@
88
- Text-to-image generation quality
99
"""
1010

11-
from openjudge.graders.multimodal._internal import MLLMImage
12-
from openjudge.graders.multimodal.image_coherence import ImageCoherenceGrader
13-
from openjudge.graders.multimodal.image_helpfulness import ImageHelpfulnessGrader
14-
from openjudge.graders.multimodal.text_to_image import TextToImageGrader
11+
from ._internal import MLLMImage
12+
from .image_coherence import ImageCoherenceGrader
13+
from .image_helpfulness import ImageHelpfulnessGrader
14+
from .text_to_image import TextToImageGrader
1515

1616
__all__ = [
1717
# Graders

openjudge/graders/schema.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -142,9 +142,10 @@ class GraderRank(GraderResult):
142142

143143

144144
class GraderRankCallback(BaseModel):
145-
"""Callback for grader rank result, used for .
145+
"""Callback schema for LLM structured output in listwise grading.
146146
147-
Represents a ranking of items assigned by a grader along with a reason.
147+
Used as the structured_model parameter in LLMGrader for LISTWISE mode.
148+
The LLM returns this schema which is then converted to GraderRank.
148149
149150
Attributes:
150151
rank (List[int]): The ranking of items.

openjudge/models/base_chat_model.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ class BaseChatModel(ABC):
2929
... async def achat(self, *args, **kwargs):
3030
... # Implementation here
3131
... pass
32-
>>> model = MyChatModel(model="qwen3-max", stream=False)
32+
>>> model = MyChatModel(model="qwen3-32b", stream=False)
3333
>>> print(model.model)
3434
qwen3-32b
3535
"""
@@ -52,11 +52,12 @@ def __init__(
5252
stream: Whether the model output is streaming or not.
5353
5454
Example:
55-
>>> model = BaseChatModel(model="qwen3-32b", stream=True)
55+
>>> class MyChatModel(BaseChatModel):
56+
... async def achat(self, *args, **kwargs):
57+
... pass
58+
>>> model = MyChatModel(model="qwen3-32b", stream=True)
5659
>>> print(model.model)
5760
qwen3-32b
58-
>>> print(model.stream)
59-
True
6061
"""
6162
self.model = model
6263
self.stream = stream
@@ -102,9 +103,11 @@ def _validate_tool_choice(
102103
ValueError: If tool_choice is invalid.
103104
104105
Example:
105-
>>> model = BaseChatModel(model="test", stream=False)
106+
>>> class MyChatModel(BaseChatModel):
107+
... async def achat(self, *args, **kwargs):
108+
... pass
109+
>>> model = MyChatModel(model="test", stream=False)
106110
>>> model._validate_tool_choice("auto", None) # Valid
107-
>>> # model._validate_tool_choice(123, None) # Would raise TypeError
108111
"""
109112
if not isinstance(tool_choice, str):
110113
raise TypeError(

openjudge/models/formatter/dashscope_formatter.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,7 @@ def _convert_content_to_openai(
8383
if isinstance(content, str):
8484
return content
8585

86-
# If content is a list, process each part
86+
# If content is a list, process each part (including empty list)
8787
if isinstance(content, list):
8888
openai_content = []
8989
for part in content:
@@ -143,7 +143,7 @@ def _convert_content_to_dashscope(
143143
if isinstance(content, str):
144144
return content
145145

146-
# If content is a list, process each part
146+
# If content is a list, process each part (including empty list)
147147
if isinstance(content, list):
148148
dashscope_content = []
149149
for part in content:

openjudge/models/schema/prompt_template.py

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -167,9 +167,8 @@ def to_messages(
167167
List[ChatMessage]: The messages for the specified language.
168168
169169
Raises:
170-
AssertionError: If the specified language is not available in a
171-
multilingual template.
172-
ValueError: If messages format is invalid.
170+
ValueError: If the specified language is not available in a
171+
multilingual template, or if messages format is invalid.
173172
174173
Examples:
175174
>>> template = PromptTemplate(messages=[ChatMessage(role="user", content="Hello")])
@@ -182,16 +181,17 @@ def to_messages(
182181
[ChatMessage(role="user", content="Hello")]
183182
"""
184183
if isinstance(self.messages, list):
185-
messages = self.messages
186-
elif isinstance(self.messages, dict):
184+
return self.messages
185+
186+
if isinstance(self.messages, dict):
187187
if not language:
188188
language = LanguageEnum.EN
189-
assert language in self.messages
190-
messages = self.messages.get(language, [])
191-
else:
192-
raise ValueError("Invalid messages")
189+
if language not in self.messages:
190+
available = [lang.value for lang in self.messages.keys()]
191+
raise ValueError(f"Language '{language.value}' not found. Available: {available}")
192+
return self.messages[language]
193193

194-
return messages
194+
raise ValueError("Invalid messages format")
195195

196196
@classmethod
197197
def from_prompt(cls, prompt: Prompt) -> "PromptTemplate":
@@ -280,7 +280,7 @@ def format(
280280
messages = [message.format(**kwargs).to_dict() for message in messages]
281281
return messages
282282

283-
def get_prompt(self, language: LanguageEnum = None) -> Dict[str, List[Dict[str, str]]]:
283+
def get_prompt(self, language: LanguageEnum | None = None) -> Dict[str, List[Dict[str, str]]]:
284284
"""Return the core prompts (role, content) information of the messages,
285285
in a {language: list[{'role': txt, 'content': txt}]} dictionary.
286286
"""

openjudge/utils/concurrency.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,9 @@
66
"""
77

88
import asyncio
9+
from typing import Awaitable, TypeVar
10+
11+
T = TypeVar("T")
912

1013

1114
class ConcurrencyManager:
@@ -61,15 +64,15 @@ def get_max_concurrency(self) -> int:
6164
"""
6265
return self._max_concurrency
6366

64-
async def run_with_concurrency_control(self, coro):
67+
async def run_with_concurrency_control(self, coro: Awaitable[T]) -> T:
6568
"""
6669
Run a coroutine with concurrency control.
6770
6871
Args:
69-
coro: The coroutine to run
72+
coro: The coroutine to run.
7073
7174
Returns:
72-
The result of the coroutine
75+
T: The result of the coroutine.
7376
"""
7477
async with self._semaphore:
7578
return await coro

0 commit comments

Comments
 (0)