Replace assert-based validation of ConditionalLLMBlock config_paths with a
public BlockConfigParserError exception.

Summary of the hunks below:
  * blocks/block.py defines BlockConfigParserError (an Exception subclass).
  * sdg/__init__.py imports it and lists it in __all__.
  * blocks/llmblock.py raises it from ConditionalLLMBlock.__init__ when
    config_paths is empty or when an entry does not have exactly two items;
    both messages now include the block name.
  * tests/test_llmblock.py adds test_invalid_config_paths covering both cases.

NOTE(review): the per-entry message no longer contains the words "each entry"
that the removed assert message had, so it now reads as if config_paths itself
must be the two-item list. Consider restoring that wording; the new test only
checks for the substring "config path and selector".

diff --git a/src/instructlab/sdg/__init__.py b/src/instructlab/sdg/__init__.py
index 6f091424..490df8e4 100644
--- a/src/instructlab/sdg/__init__.py
+++ b/src/instructlab/sdg/__init__.py
@@ -3,6 +3,7 @@
 # NOTE: This package imports Torch and other heavy packages.
 __all__ = (
     "Block",
+    "BlockConfigParserError",
     "BlockRegistry",
     "CombineColumnsBlock",
     "ConditionalLLMBlock",
@@ -31,7 +32,7 @@
 )
 
 # Local
-from .blocks.block import Block
+from .blocks.block import Block, BlockConfigParserError
 from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError
 from .blocks.iterblock import IterBlock
 from .blocks.llmblock import (
diff --git a/src/instructlab/sdg/blocks/block.py b/src/instructlab/sdg/blocks/block.py
index 3bceca65..bc955596 100644
--- a/src/instructlab/sdg/blocks/block.py
+++ b/src/instructlab/sdg/blocks/block.py
@@ -68,3 +68,8 @@ def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]:
             )
         with open(config_path, "r", encoding="utf-8") as config_file:
             return yaml.safe_load(config_file)
+
+
+# This is part of the public API.
+class BlockConfigParserError(Exception):
+    """An exception raised while parsing a block's configuration."""
diff --git a/src/instructlab/sdg/blocks/llmblock.py b/src/instructlab/sdg/blocks/llmblock.py
index 92e7402e..89d9a27e 100644
--- a/src/instructlab/sdg/blocks/llmblock.py
+++ b/src/instructlab/sdg/blocks/llmblock.py
@@ -15,7 +15,7 @@
 # Import prompts to register default chat templates
 from .. import prompts as default_prompts  # pylint: disable=unused-import
 from ..registry import BlockRegistry, PromptRegistry
-from .block import Block
+from .block import Block, BlockConfigParserError
 
 logger = logging.getLogger(__name__)
 
@@ -290,13 +290,15 @@ def __init__(
         parser_kwargs={},
         batch_kwargs={},
     ) -> None:
-        assert (
-            config_paths
-        ), "ConditionalLLMBlock config_paths requires at least one entry"
+        if not config_paths:
+            raise BlockConfigParserError(
+                f"ConditionalLLMBlock config_paths of block {block_name} requires at least one entry"
+            )
         for config_path in config_paths:
-            assert (
-                len(config_path) == 2
-            ), "ConditionalLLMBlock config_paths each entry should be a list of config path and selector column names"
+            if len(config_path) != 2:
+                raise BlockConfigParserError(
+                    f"ConditionalLLMBlock config_paths of block {block_name} should be a list of config path and selector column names"
+                )
         super().__init__(
             ctx,
             pipe,
diff --git a/tests/test_llmblock.py b/tests/test_llmblock.py
index 15519fb7..f9d78177 100644
--- a/tests/test_llmblock.py
+++ b/tests/test_llmblock.py
@@ -10,9 +10,11 @@
 from datasets import Dataset, Features, Value
 from httpx import URL
 from openai import InternalServerError, NotFoundError
+import pytest
 
 # First Party
 from src.instructlab.sdg import (
+    BlockConfigParserError,
     ConditionalLLMBlock,
     LLMBlock,
     LLMLogProbBlock,
@@ -258,6 +260,29 @@ def setUp(self):
         self.mock_ctx.model_id = "test_model"
         self.mock_pipe = MagicMock()
 
+    def test_invalid_config_paths(self, _mock_load_config):
+        with pytest.raises(BlockConfigParserError) as exc:
+            ConditionalLLMBlock(
+                ctx=self.mock_ctx,
+                pipe=self.mock_pipe,
+                block_name="gen_knowledge",
+                config_paths=[],
+                output_cols=[],
+                selector_column_name="selector",
+            )
+        assert "at least one entry" in str(exc.value)
+
+        with pytest.raises(BlockConfigParserError) as exc:
+            ConditionalLLMBlock(
+                ctx=self.mock_ctx,
+                pipe=self.mock_pipe,
+                block_name="gen_knowledge",
+                config_paths=[["foo"]],
+                output_cols=[],
+                selector_column_name="selector",
+            )
+        assert "config path and selector" in str(exc.value)
+
     def test_validate(self, mock_load_config):
         mock_load_config.return_value = {
             "system": "{{var1}} {{var2}}",