Skip to content

Commit

Permalink
Simplify Jinja Template undefined behavior
Browse files Browse the repository at this point in the history
Only deal with setting `undefined=StrictUndefined` in a single place
for our Jinja Templates by using a Jinja Environment across all
Template instances. This is Jinja's preferred way to deal with
Templates anyway, as opposed to the automatic implicit Environment that
gets created the first time you construct a bare Template.

This also simplifies the validation logic to not deal with `ChainMap`
or any of the `__missing__` Dict stuff, and instead let Jinja deal
with that entirely for us.

Signed-off-by: Ben Browning <[email protected]>
  • Loading branch information
bbrowning committed Dec 5, 2024
1 parent 7abde97 commit f6b9ffe
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
10 changes: 4 additions & 6 deletions src/instructlab/sdg/blocks/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# Standard
from abc import ABC
from collections import ChainMap
from typing import Any, Dict, Union
import logging
import os.path
Expand Down Expand Up @@ -39,15 +38,14 @@ def _validate(self, prompt_template: Template, input_dict: Dict[str, Any]) -> bo
True if the input data is valid (i.e., no missing variables), False otherwise.
"""

class Default(dict):
def __missing__(self, key: str) -> None:
raise KeyError(key)

try:
# Try rendering the template with the input_dict
prompt_template.render(ChainMap(input_dict, Default()))
prompt_template.render(input_dict)
return True
except UndefinedError as e:
# Jinja throws an UndefinedError for any undefined template variables,
# assuming the prompt_template was created using StrictUndefined. This
# is the case for anything using PromptRegistry.template_from_string.
logger.error(f"Missing key: {e}")
return False

Expand Down
5 changes: 2 additions & 3 deletions src/instructlab/sdg/blocks/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

# Third Party
from datasets import Dataset
from jinja2 import StrictUndefined, Template
from tqdm import tqdm
import httpx
import openai
Expand Down Expand Up @@ -60,7 +59,7 @@ def server_supports_batched(client, model_id: str) -> bool:
def template_from_struct_and_config(struct, config):
# replace None with empty strings
filtered_config = {k: (v if v is not None else "") for k, v in config.items()}
return Template(struct.format(**filtered_config), undefined=StrictUndefined)
return PromptRegistry.template_from_string(struct.format(**filtered_config))


# This is part of the public API.
Expand Down Expand Up @@ -163,7 +162,7 @@ def _format_prompt(self, sample: Dict) -> str:
if self.model_prompt is None:
model_prompt = PromptRegistry.get_template(self.ctx.model_family)
elif self.model_prompt:
model_prompt = Template(self.model_prompt, undefined=StrictUndefined)
model_prompt = PromptRegistry.template_from_string(self.model_prompt)
else:
# Our model prompt is an empty string, which we'll render
# verbatim without wrapping in the messages format
Expand Down
18 changes: 16 additions & 2 deletions src/instructlab/sdg/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging

# Third Party
from jinja2 import StrictUndefined, Template
from jinja2 import Environment, StrictUndefined, Template

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,6 +46,7 @@ class PromptRegistry:
"""Registry for managing Jinja2 prompt templates."""

_registry: Dict[str, Template] = {}
_template_env: Environment = Environment(undefined=StrictUndefined)

@classmethod
def register(cls, *names: str):
Expand All @@ -60,7 +61,7 @@ def register(cls, *names: str):

def decorator(func):
template_str = func()
template = Template(template_str, undefined=StrictUndefined)
template = cls.template_from_string(template_str)
for name in names:
cls._registry[name] = template
logger.debug(f"Registered prompt template '{name}'")
Expand Down Expand Up @@ -91,3 +92,16 @@ def get_registry(cls):
Dictionary of registered prompt template names and their Jinja templates.
"""
return cls._registry

@classmethod
def template_from_string(cls, template_str):
"""
Create a Jinja Template using our Environment from the source string
Args:
template_str: The template source, as a string-like thing
Returns:
Jinja Template
"""
return cls._template_env.from_string(template_str)

0 comments on commit f6b9ffe

Please sign in to comment.