Skip to content

Commit

Permalink
Move Block._validate to llmblock
Browse files Browse the repository at this point in the history
The custom "validate" function in ConditionalLLMBlock wasn't
being called and renaming it "_validate" is problematic as its
static. Instead move it to LLMBlock and change it to a instance
method as only LLMBlock uses it.

Closes #186

Co-authored-by: abhi1092 <[email protected]>

Signed-off-by: Derek Higgins <[email protected]>
  • Loading branch information
derekhiggins committed Jul 23, 2024
1 parent e40f0f4 commit fef8f25
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
21 changes: 0 additions & 21 deletions src/instructlab/sdg/block.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from abc import ABC
from collections import ChainMap
from typing import Any, Dict, Union
import os.path

Expand All @@ -21,26 +20,6 @@ def __init__(self, ctx, pipe, block_name: str) -> None:
self.pipe = pipe
self.block_name = block_name

@staticmethod
def _validate(prompt_template: str, input_dict: Dict[str, Any]) -> bool:
"""
Validate the input data for this block. This method should be implemented by subclasses
to define how the block validates its input data.
:return: True if the input data is valid, False otherwise.
"""

class Default(dict):
def __missing__(self, key: str) -> None:
raise KeyError(key)

try:
prompt_template.format_map(ChainMap(input_dict, Default()))
return True
except KeyError as e:
logger.error("Missing key: {}".format(e))
return False

def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]:
"""
Load the configuration file for this block.
Expand Down
22 changes: 21 additions & 1 deletion src/instructlab/sdg/llmblock.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# Standard
from collections import ChainMap
from typing import Any, Dict
import re

Expand Down Expand Up @@ -224,6 +225,25 @@ def generate(self, samples: Dataset) -> Dataset:

return Dataset.from_list(new_data)

def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool:
"""
Validate the input data for this block. This method should be implemented by subclasses
to define how the block validates its input data.
:return: True if the input data is valid, False otherwise.
"""

class Default(dict):
def __missing__(self, key: str) -> None:
raise KeyError(key)

try:
prompt_template.format_map(ChainMap(input_dict, Default()))
return True
except KeyError as e:
logger.error("Missing key: {}".format(e))
return False


# This is part of the public API.
class ConditionalLLMBlock(LLMBlock):
Expand Down Expand Up @@ -269,7 +289,7 @@ def _format_prompt(self, sample: Dict) -> str:

return self.prompt_template.format(**sample).strip()

def validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool:
def _validate(self, prompt_template: str, input_dict: Dict[str, Any]) -> bool:
if isinstance(prompt_template, dict):
prompt_template = prompt_template[input_dict[self.selector_column_name]]
return super()._validate(prompt_template, input_dict)

0 comments on commit fef8f25

Please sign in to comment.