Skip to content

Commit

Permalink
Use instructlab-schema package to parse qna.yaml files
Browse files Browse the repository at this point in the history
We remove code duplicated with instructlab to use the shared code
from instructlab-schema package.

A test case is also fixed which incorrectly passed a config string when
the name of a config file is the proper argument.

Signed-off-by: BJ Hargrave <[email protected]>
  • Loading branch information
bjhargrave committed Aug 2, 2024
1 parent 0bdedbc commit 9a1a748
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 244 deletions.
3 changes: 2 additions & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,8 @@ int-import-graph=
known-standard-library=

# Force import order to recognize a module as part of a third party library.
known-third-party=enchant
known-third-party=enchant,
instructlab.schema,

# Couples of modules and preferred modules, separated by a comma.
preferred-modules=
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
click>=8.1.7,<9.0.0
httpx>=0.25.0,<1.0.0
instructlab-schema>=0.4.0
langchain-text-splitters
openai>=1.13.3,<2.0.0
platformdirs>=4.2
Expand All @@ -9,4 +10,3 @@ platformdirs>=4.2
# do not use 8.4.0 due to a bug in the library
# https://github.com/instructlab/instructlab/issues/1389
tenacity>=8.3.0,!=8.4.0
instructlab-schema>=0.3.1
294 changes: 55 additions & 239 deletions src/instructlab/sdg/utils/taxonomy.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,52 @@
# SPDX-License-Identifier: Apache-2.0

# Standard
from functools import cache
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Dict, List, Union
import glob
import json
import logging
import os
import re
import subprocess
import tempfile

# Third Party
from instructlab.schema.taxonomy import DEFAULT_TAXONOMY_FOLDERS as TAXONOMY_FOLDERS
from instructlab.schema.taxonomy import (
TaxonomyMessageFormat,
TaxonomyParser,
TaxonomyReadingException,
)
import git
import gitdb
import yaml

# First Party
from instructlab.sdg.utils import chunking
# Local
from . import chunking

logger = logging.getLogger(__name__)

MIN_KNOWLEDGE_VERSION = 3

DEFAULT_YAML_RULES = """\
extends: relaxed
rules:
line-length:
max: 120
"""


# This is part of the public API.
class TaxonomyReadingException(Exception):
"""An exception raised during reading of the taxonomy."""


TAXONOMY_FOLDERS: List[str] = ["compositional_skills", "knowledge"]
"""Taxonomy folders which are also the schema names"""


def _istaxonomyfile(fn):
def _is_taxonomy_file(fn: str) -> bool:
path = Path(fn)
if path.suffix == ".yaml" and path.parts[0] in TAXONOMY_FOLDERS:
if path.parts[0] not in TAXONOMY_FOLDERS:
return False
if path.name == "qna.yaml":
return True
if path.name.casefold() in {"qna.yml", "qna.yaml"}:
# warning for incorrect extension or case variants
logger.warning(
"Found a '%s' file: %s: taxonomy files must be named 'qna.yaml'. File will not be checked.",
path.name,
path,
)
return False


def _get_taxonomy_diff(repo="taxonomy", base="origin/main"):
repo = git.Repo(repo)
untracked_files = [u for u in repo.untracked_files if _istaxonomyfile(u)]
def _get_taxonomy_diff(
repo_path: str | Path = "taxonomy", base: str = "origin/main"
) -> list[str]:
repo = git.Repo(repo_path)
untracked_files = [u for u in repo.untracked_files if _is_taxonomy_file(u)]

branches = [b.name for b in repo.branches]

Expand Down Expand Up @@ -90,7 +85,7 @@ def _get_taxonomy_diff(repo="taxonomy", base="origin/main"):
modified_files = [
d.b_path
for d in head_commit.diff(None)
if not d.deleted_file and _istaxonomyfile(d.b_path)
if not d.deleted_file and _is_taxonomy_file(d.b_path)
]

updated_taxonomy_files = list(set(untracked_files + modified_files))
Expand Down Expand Up @@ -135,214 +130,25 @@ def _get_documents(
raise e


@cache
def _load_schema(path: "importlib.resources.abc.Traversable") -> "referencing.Resource":
"""Load the schema from the path into a Resource object.
Args:
path (Traversable): Path to the schema to be loaded.
Raises:
NoSuchResource: If the resource cannot be loaded.
Returns:
Resource: A Resource containing the requested schema.
"""
# pylint: disable=C0415
# Third Party
from referencing import Resource
from referencing.exceptions import NoSuchResource
from referencing.jsonschema import DRAFT202012

try:
contents = json.loads(path.read_text(encoding="utf-8"))
resource = Resource.from_contents(
contents=contents, default_specification=DRAFT202012
)
except Exception as e:
raise NoSuchResource(ref=str(path)) from e
return resource


def _validate_yaml(contents: Mapping[str, Any], taxonomy_path: Path) -> int:
"""Validate the parsed yaml document using the taxonomy path to
determine the proper schema.
Args:
contents (Mapping): The parsed yaml document to validate against the schema.
taxonomy_path (Path): Relative path of the taxonomy yaml document where the
first element is the schema to use.
Returns:
int: The number of errors found during validation.
Messages for each error have been logged.
"""
# pylint: disable=C0415
# Standard
from importlib import resources

# Third Party
from jsonschema.protocols import Validator
from jsonschema.validators import validator_for
from referencing import Registry, Resource
from referencing.exceptions import NoSuchResource
from referencing.typing import URI

errors = 0
version = _get_version(contents)
schemas_path = resources.files("instructlab.schema").joinpath(f"v{version}")

def retrieve(uri: URI) -> Resource:
path = schemas_path.joinpath(uri)
return _load_schema(path)

schema_name = taxonomy_path.parts[0]
if schema_name not in TAXONOMY_FOLDERS:
schema_name = "knowledge" if "document" in contents else "compositional_skills"
logger.info(
f"Cannot determine schema name from path {taxonomy_path}. Using {schema_name} schema."
)

if schema_name == "knowledge" and version < MIN_KNOWLEDGE_VERSION:
logger.error(
f"Version {version} is not supported for knowledge taxonomy. Minimum supported version is {MIN_KNOWLEDGE_VERSION}."
)
errors += 1
return errors

try:
schema_resource = retrieve(f"{schema_name}.json")
schema = schema_resource.contents
validator_cls = validator_for(schema)
validator: Validator = validator_cls(
schema, registry=Registry(retrieve=retrieve)
)

for validation_error in validator.iter_errors(contents):
errors += 1
yaml_path = validation_error.json_path[1:]
if not yaml_path:
yaml_path = "."
if validation_error.validator == "minItems":
# Special handling for minItems which can have a long message for seed_examples
message = (
f"Value must have at least {validation_error.validator_value} items"
)
else:
message = validation_error.message[-200:]
logger.error(
f"Validation error in {taxonomy_path}: [{yaml_path}] {message}"
)
except NoSuchResource as e:
cause = e.__cause__ if e.__cause__ is not None else e
errors += 1
logger.error(f"Cannot load schema file {e.ref}. {cause}")

return errors

# pylint: disable=broad-exception-caught
def _read_taxonomy_file(file_path: str | Path, yamllint_config: str | None = None):
seed_instruction_data = []

def _get_version(contents: Mapping) -> int:
version = contents.get("version", 1)
if not isinstance(version, int):
# schema validation will complain about the type
try:
version = int(version)
except ValueError:
version = 1 # fallback to version 1
return version
parser = TaxonomyParser(
schema_version=0, # Use version value in yaml
message_format=TaxonomyMessageFormat.LOGGING, # Report warnings and errors to the logger
yamllint_config=yamllint_config,
yamllint_strict=True, # Report yamllint warnings as errors
)
taxonomy = parser.parse(file_path)

if taxonomy.warnings or taxonomy.errors:
return seed_instruction_data, taxonomy.warnings, taxonomy.errors

# pylint: disable=broad-exception-caught
def _read_taxonomy_file(file_path: str, yaml_rules: Optional[str] = None):
seed_instruction_data = []
warnings = 0
errors = 0
file_path = Path(file_path).resolve()
# file should end with ".yaml" explicitly
if file_path.suffix != ".yaml":
logger.warning(
f"Skipping {file_path}! Use lowercase '.yaml' extension instead."
)
warnings += 1
return None, warnings, errors
for i in range(len(file_path.parts) - 1, -1, -1):
if file_path.parts[i] in TAXONOMY_FOLDERS:
taxonomy_path = Path(*file_path.parts[i:])
break
else:
taxonomy_path = file_path
# read file if extension is correct
try:
with open(file_path, "r", encoding="utf-8") as file:
contents = yaml.safe_load(file)
if not contents:
logger.warning(f"Skipping {file_path} because it is empty!")
warnings += 1
return None, warnings, errors
if not isinstance(contents, Mapping):
logger.error(
f"{file_path} is not valid. The top-level element is not an object with key-value pairs."
)
errors += 1
return None, warnings, errors

# do general YAML linting if specified
version = _get_version(contents)
if version > 1: # no linting for version 1 yaml
if yaml_rules is not None:
is_file = os.path.isfile(yaml_rules)
if is_file:
logger.debug(f"Using YAML rules from {yaml_rules}")
yamllint_cmd = [
"yamllint",
"-f",
"parsable",
"-c",
yaml_rules,
file_path,
"-s",
]
else:
logger.debug(f"Cannot find {yaml_rules}. Using default rules.")
yamllint_cmd = [
"yamllint",
"-f",
"parsable",
"-d",
DEFAULT_YAML_RULES,
file_path,
"-s",
]
else:
yamllint_cmd = [
"yamllint",
"-f",
"parsable",
"-d",
DEFAULT_YAML_RULES,
file_path,
"-s",
]
try:
subprocess.check_output(yamllint_cmd, text=True)
except subprocess.SubprocessError as e:
lint_messages = [f"Problems found in file {file_path}"]
parsed_output = e.output.splitlines()
for p in parsed_output:
errors += 1
delim = str(file_path) + ":"
parsed_p = p.split(delim)[1]
lint_messages.append(parsed_p)
logger.error("\n".join(lint_messages))
return None, warnings, errors

validation_errors = _validate_yaml(contents, taxonomy_path)
if validation_errors:
errors += validation_errors
return None, warnings, errors

# get seed instruction data
tax_path = "->".join(taxonomy_path.parent.parts)
tax_path = "->".join(taxonomy.path.parent.parts)
contents = taxonomy.contents
task_description = contents.get("task_description", None)
domain = contents.get("domain")
documents = contents.get("document")
Expand Down Expand Up @@ -380,18 +186,28 @@ def _read_taxonomy_file(file_path: str, yaml_rules: Optional[str] = None):
}
)
except Exception as e:
errors += 1
raise TaxonomyReadingException(f"Exception {e} raised in {file_path}") from e

return seed_instruction_data, warnings, errors
return seed_instruction_data, 0, 0


def read_taxonomy(
taxonomy: str | Path, taxonomy_base: str, yaml_rules: str | None = None
):
yamllint_config = None # If no custom rules file, use default config
if yaml_rules is not None: # user attempted to pass custom rules file
yaml_rules_path = Path(yaml_rules)
if yaml_rules_path.is_file(): # file was found, use specified config
logger.debug("Using YAML rules from %s", yaml_rules)
yamllint_config = yaml_rules_path.read_text(encoding="utf-8")
else:
logger.debug("Cannot find %s. Using default rules.", yaml_rules)

def read_taxonomy(taxonomy, taxonomy_base, yaml_rules):
seed_instruction_data = []
is_file = os.path.isfile(taxonomy)
if is_file: # taxonomy is file
seed_instruction_data, warnings, errors = _read_taxonomy_file(
taxonomy, yaml_rules
taxonomy, yamllint_config
)
if warnings:
logger.warning(
Expand All @@ -410,7 +226,7 @@ def read_taxonomy(taxonomy, taxonomy_base, yaml_rules):
logger.debug(f"* {e}")
for f in updated_taxonomy_files:
file_path = os.path.join(taxonomy, f)
data, warnings, errors = _read_taxonomy_file(file_path, yaml_rules)
data, warnings, errors = _read_taxonomy_file(file_path, yamllint_config)
total_warnings += warnings
total_errors += errors
if data:
Expand Down
Loading

0 comments on commit 9a1a748

Please sign in to comment.