Skip to content

Commit

Permalink
Validate blocks by raising BlockConfigParserError instead of asserts
Browse files Browse the repository at this point in the history
Asserts outside of tests should only be used to catch programming errors in
our own code, not to validate user-facing input such as block configuration.

Signed-off-by: Ben Browning <[email protected]>
  • Loading branch information
bbrowning committed Dec 10, 2024
1 parent 42fb71d commit db3a1ad
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/instructlab/sdg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# NOTE: This package imports Torch and other heavy packages.
__all__ = (
"Block",
"BlockConfigParserError",
"BlockRegistry",
"CombineColumnsBlock",
"ConditionalLLMBlock",
Expand Down Expand Up @@ -31,7 +32,7 @@
)

# Local
from .blocks.block import Block
from .blocks.block import Block, BlockConfigParserError
from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError
from .blocks.iterblock import IterBlock
from .blocks.llmblock import (
Expand Down
5 changes: 5 additions & 0 deletions src/instructlab/sdg/blocks/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,8 @@ def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]:
)
with open(config_path, "r", encoding="utf-8") as config_file:
return yaml.safe_load(config_file)


# This is part of the public API.
class BlockConfigParserError(Exception):
    """Signal that a block's configuration could not be parsed or is invalid.

    Blocks raise this (rather than relying on ``assert``) when the
    configuration they are given fails validation, so callers can catch a
    dedicated exception type for configuration problems.
    """
16 changes: 9 additions & 7 deletions src/instructlab/sdg/blocks/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Import prompts to register default chat templates
from .. import prompts as default_prompts # pylint: disable=unused-import
from ..registry import BlockRegistry, PromptRegistry
from .block import Block
from .block import Block, BlockConfigParserError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -290,13 +290,15 @@ def __init__(
parser_kwargs={},
batch_kwargs={},
) -> None:
assert (
config_paths
), "ConditionalLLMBlock config_paths requires at least one entry"
if not config_paths:
raise BlockConfigParserError(
f"ConditionalLLMBlock config_paths of block {block_name} requires at least one entry"
)
for config_path in config_paths:
assert (
len(config_path) == 2
), "ConditionalLLMBlock config_paths each entry should be a list of config path and selector column names"
if len(config_path) != 2:
raise BlockConfigParserError(
f"ConditionalLLMBlock config_paths of block {block_name} should be a list of config path and selector column names"
)
super().__init__(
ctx,
pipe,
Expand Down
25 changes: 25 additions & 0 deletions tests/test_llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from datasets import Dataset, Features, Value
from httpx import URL
from openai import InternalServerError, NotFoundError
import pytest

# First Party
from src.instructlab.sdg import (
BlockConfigParserError,
ConditionalLLMBlock,
LLMBlock,
LLMLogProbBlock,
Expand Down Expand Up @@ -258,6 +260,29 @@ def setUp(self):
self.mock_ctx.model_id = "test_model"
self.mock_pipe = MagicMock()

def test_invalid_config_paths(self, _mock_load_config):
    """ConditionalLLMBlock must reject malformed ``config_paths``.

    Each case pairs an invalid ``config_paths`` value with a fragment that
    has to appear in the resulting BlockConfigParserError message.
    """
    invalid_cases = (
        # No entries at all.
        ([], "at least one entry"),
        # An entry that is not a (config path, selector) pair.
        ([["foo"]], "config path and selector"),
    )
    for bad_config_paths, expected_fragment in invalid_cases:
        with pytest.raises(BlockConfigParserError) as raised:
            ConditionalLLMBlock(
                ctx=self.mock_ctx,
                pipe=self.mock_pipe,
                block_name="gen_knowledge",
                config_paths=bad_config_paths,
                output_cols=[],
                selector_column_name="selector",
            )
        # Checked after the ``with`` block so the assertion actually runs.
        assert expected_fragment in str(raised.value)

def test_validate(self, mock_load_config):
mock_load_config.return_value = {
"system": "{{var1}} {{var2}}",
Expand Down

0 comments on commit db3a1ad

Please sign in to comment.