diff --git a/langkit/__init__.py b/langkit/__init__.py
index cb60145b..3778e346 100644
--- a/langkit/__init__.py
+++ b/langkit/__init__.py
@@ -22,6 +22,14 @@ class LangKitConfig:
             "support",
         ]
     )
+    nlp_scores: list = field(
+        default_factory=lambda: [
+            "bleu",
+            "rouge",
+            "meteor",
+        ]
+    )
+    reference_corpus: str = ""
 
 
 def package_version(package: str = __package__) -> str:
diff --git a/langkit/nlp_scores.py b/langkit/nlp_scores.py
new file mode 100644
index 00000000..c4ba6c81
--- /dev/null
+++ b/langkit/nlp_scores.py
@@ -0,0 +1,66 @@
+from whylogs.experimental.core.metrics.udf_metric import register_metric_udf
+import evaluate
+from . import LangKitConfig
+from logging import getLogger
+
+lang_config = LangKitConfig()
+_corpus = lang_config.reference_corpus
+_scores = lang_config.nlp_scores
+_rouge_type = "rouge1"
+
+diagnostic_logger = getLogger(__name__)
+
+
+def register_score_udfs():
+    if _corpus:
+        for score in _scores:
+            if "bleu" in score:
+                bleu = evaluate.load("bleu")
+
+                @register_metric_udf(col_name=lang_config.response_column)
+                def bleu_score(text: str) -> float:
+                    return bleu.compute(predictions=[text], references=[_corpus])[
+                        "bleu"
+                    ]
+
+            if "rouge" in score:
+                rouge = evaluate.load("rouge")
+
+                @register_metric_udf(col_name=lang_config.response_column)
+                def rouge_score(text: str) -> float:
+                    return rouge.compute(
+                        predictions=[text],
+                        references=[_corpus],
+                        rouge_types=[_rouge_type],
+                    )[_rouge_type]
+
+            if "meteor" in score:
+                meteor = evaluate.load("meteor")
+
+                @register_metric_udf(col_name=lang_config.response_column)
+                def meteor_score(text: str) -> float:
+                    return meteor.compute(predictions=[text], references=[_corpus])[
+                        "meteor"
+                    ]
+
+    else:
+        diagnostic_logger.warning(
+            "No reference corpus provided for NLP scores. Skipping NLP scores."
+        )
+
+
+def init(corpus=None, scores=[], rouge_type=None):
+    global _corpus
+    global _scores
+    global _rouge_type
+    if corpus:
+        _corpus = corpus
+    if scores:
+        _scores = scores
+    if rouge_type:
+        _rouge_type = rouge_type
+
+    register_score_udfs()
+
+
+init()
diff --git a/langkit/tests/test_nlp_scores.py b/langkit/tests/test_nlp_scores.py
new file mode 100644
index 00000000..a093be20
--- /dev/null
+++ b/langkit/tests/test_nlp_scores.py
@@ -0,0 +1,23 @@
+import whylogs as why
+import pytest
+
+
+@pytest.mark.load
+def test_bleu_score():
+    from langkit import nlp_scores  # noqa
+    from whylogs.experimental.core.udf_schema import udf_schema
+
+    nlp_scores.init(
+        scores=["bleu"], corpus="The quick brown fox jumps over the lazy dog"
+    )
+    text_schema = udf_schema()
+    profile = why.log(
+        {"response": "The quick dog jumps over the lazy brown fox"}, schema=text_schema
+    ).profile()
+    max_score = (
+        profile.view()
+        .get_column("response")
+        .get_metrics()[-1]
+        .to_summary_dict()["bleu_score:distribution/max"]
+    )
+    assert round(max_score, 2) == 0.42
diff --git a/poetry.lock b/poetry.lock
index c3af9520..21654dd0 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry and should not be changed by hand.
 
 [[package]]
 name = "aiohttp"
@@ -587,6 +587,44 @@ files = [
 
 [package.extras]
 graph = ["objgraph (>=1.7.2)"]
+
+[[package]]
+name = "evaluate"
+version = "0.4.0"
+description = "HuggingFace community-driven open-source library of evaluation"
+category = "main"
+optional = true
+python-versions = ">=3.7.0"
+files = [
+    {file = "evaluate-0.4.0-py3-none-any.whl", hash = "sha256:4b528de0f270cdfb077ca4877035dc17584d2c4b1cbc3fdd46afc3942ed557fd"},
+    {file = "evaluate-0.4.0.tar.gz", hash = "sha256:bd6a59879be9ae13b681684e56ae3e6ea657073c4413b30335e9efa9856e4f44"},
+]
+
+[package.dependencies]
+datasets = ">=2.0.0"
+dill = "*"
+fsspec = {version = ">=2021.05.0", extras = ["http"]}
+huggingface-hub = ">=0.7.0"
+multiprocess = "*"
+numpy = ">=1.17"
+packaging = "*"
+pandas = "*"
+requests = ">=2.19.0"
+responses = "<0.19"
+tqdm = ">=4.62.1"
+xxhash = "*"
+
+[package.extras]
+dev = ["Werkzeug (>=1.0.1)", "absl-py", "bert-score (>=0.3.6)", "black (>=22.0,<23.0)", "cer (>=1.2.0)", "charcut (>=1.1.1)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "jiwer", "mauve-text", "nltk", "pytest", "pytest-datadir", "pytest-xdist", "pyyaml (>=5.3.1)", "requests-file (>=1.5.1)", "rouge-score (>=0.1.2)", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1,<=2.10)", "texttable (>=1.6.3)", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "torch", "transformers", "trectools", "unidecode (>=1.3.4)"]
+docs = ["s3fs"]
+evaluator = ["scipy (>=1.7.1)", "transformers"]
+quality = ["black (>=22.0,<23.0)", "flake8 (>=3.8.3)", "isort (>=5.0.0)", "pyyaml (>=5.3.1)"]
+template = ["cookiecutter", "gradio (>=3.0.0)"]
+tensorflow = ["tensorflow (>=2.2.0,!=2.6.0,!=2.6.1)"]
+tensorflow-gpu = ["tensorflow-gpu (>=2.2.0,!=2.6.0,!=2.6.1)"]
+tests = ["Werkzeug (>=1.0.1)", "absl-py", "bert-score (>=0.3.6)", "cer (>=1.2.0)", "charcut (>=1.1.1)", "jiwer", "mauve-text", "nltk", "pytest", "pytest-datadir", "pytest-xdist", "requests-file (>=1.5.1)", "rouge-score (>=0.1.2)", "sacrebleu", "sacremoses", "scikit-learn", "scipy", "sentencepiece", "seqeval", "six (>=1.15.0,<1.16.0)", "tensorflow (>=2.3,!=2.6.0,!=2.6.1,<=2.10)", "texttable (>=1.6.3)", "tldextract (>=3.1.0)", "toml (>=0.10.1)", "torch", "transformers", "trectools", "unidecode (>=1.3.4)"]
+torch = ["torch"]
+
 [[package]]
 name = "distlib"
 version = "0.3.7"
@@ -1603,6 +1641,7 @@ files = [
     {file = "Pillow-10.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:3b08d4cc24f471b2c8ca24ec060abf4bebc6b144cb89cba638c720546b1cf538"},
     {file = "Pillow-10.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d737a602fbd82afd892ca746392401b634e278cb65d55c4b7a8f48e9ef8d008d"},
     {file = "Pillow-10.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:3a82c40d706d9aa9734289740ce26460a11aeec2d9c79b7af87bb35f0073c12f"},
+    {file = "Pillow-10.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:bc2ec7c7b5d66b8ec9ce9f720dbb5fa4bace0f545acd34870eff4a369b44bf37"},
     {file = "Pillow-10.0.0-cp312-cp312-macosx_10_10_x86_64.whl", hash = "sha256:d80cf684b541685fccdd84c485b31ce73fc5c9b5d7523bf1394ce134a60c6883"},
     {file = "Pillow-10.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:76de421f9c326da8f43d690110f0e79fe3ad1e54be811545d7d91898b4c8493e"},
     {file = "Pillow-10.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81ff539a12457809666fef6624684c008e00ff6bf455b4b89fd00a140eecd640"},
@@ -1612,6 +1651,7 @@ files = [
     {file = "Pillow-10.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d50b6aec14bc737742ca96e85d6d0a5f9bfbded018264b3b70ff9d8c33485551"},
     {file = "Pillow-10.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:00e65f5e822decd501e374b0650146063fbb30a7264b4d2744bdd7b913e0cab5"},
     {file = "Pillow-10.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:f31f9fdbfecb042d046f9d91270a0ba28368a723302786c0009ee9b9f1f60199"},
+    {file = "Pillow-10.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:1ce91b6ec08d866b14413d3f0bbdea7e24dfdc8e59f562bb77bc3fe60b6144ca"},
     {file = "Pillow-10.0.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:349930d6e9c685c089284b013478d6f76e3a534e36ddfa912cde493f235372f3"},
     {file = "Pillow-10.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3a684105f7c32488f7153905a4e3015a3b6c7182e106fe3c37fbb5ef3e6994c3"},
     {file = "Pillow-10.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4f69b3700201b80bb82c3a97d5e9254084f6dd5fb5b16fc1a7b974260f89f43"},
@@ -2227,6 +2267,25 @@ urllib3 = ">=1.21.1,<3"
 socks = ["PySocks (>=1.5.6,!=1.5.7)"]
 use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
 
+[[package]]
+name = "responses"
+version = "0.18.0"
+description = "A utility library for mocking out the `requests` Python library."
+category = "main"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "responses-0.18.0-py3-none-any.whl", hash = "sha256:15c63ad16de13ee8e7182d99c9334f64fd81f1ee79f90748d527c28f7ca9dd51"},
+    {file = "responses-0.18.0.tar.gz", hash = "sha256:380cad4c1c1dc942e5e8a8eaae0b4d4edf708f4f010db8b7bcfafad1fcd254ff"},
+]
+
+[package.dependencies]
+requests = ">=2.0,<3.0"
+urllib3 = ">=1.25.10"
+
+[package.extras]
+tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=4.6)", "pytest-cov", "pytest-localserver", "types-mock", "types-requests"]
+
 [[package]]
 name = "safetensors"
 version = "0.3.1"
@@ -3231,7 +3290,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
 testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
 
 [extras]
-all = ["datasets", "nltk", "openai", "sentence-transformers", "torch"]
+all = ["torch", "datasets", "openai", "nltk", "sentence-transformers"]
 
 [metadata]
 lock-version = "2.0"
diff --git a/pyproject.toml b/pyproject.toml
index 2d0533c8..db05589d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,6 +20,7 @@ datasets = {version ="^2.12.0", optional = true}
 openai = {version ="^0.27.6", optional = true}
 nltk = {version ="^3.8.1", optional = true}
 sentence-transformers = {version ="^2.2.2", optional = true}
+evaluate = {version = "^0.4.0", optional = true}
 
 
 [tool.poetry.group.dev.dependencies]
@@ -40,6 +41,7 @@ all = [
     "openai",
     "nltk",
     "sentence-transformers",
+    "evaluate",
 ]