diff --git a/pyproject.toml b/pyproject.toml index 11afcd8..2da9593 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,12 @@ changelog = "https://github.com/google/adk-python-community/blob/main/CHANGELOG. documentation = "https://google.github.io/adk-docs/" [project.optional-dependencies] +langextract = [ + "langextract>=0.1.0", +] + test = [ + "langextract>=0.1.0", # For LangExtractTool tests "pytest>=8.4.2", "pytest-asyncio>=1.2.0", ] diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py index 9a1dc35..21e8ab2 100644 --- a/src/google/adk_community/__init__.py +++ b/src/google/adk_community/__init__.py @@ -14,5 +14,6 @@ from . import memory from . import sessions +from . import tools from . import version __version__ = version.__version__ diff --git a/src/google/adk_community/tools/__init__.py b/src/google/adk_community/tools/__init__.py new file mode 100644 index 0000000..8be73aa --- /dev/null +++ b/src/google/adk_community/tools/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .langextract_tool import LangExtractTool +from .langextract_tool import LangExtractToolConfig + +__all__ = [ + 'LangExtractTool', + 'LangExtractToolConfig', +] diff --git a/src/google/adk_community/tools/langextract_tool.py b/src/google/adk_community/tools/langextract_tool.py new file mode 100644 index 0000000..7223a98 --- /dev/null +++ b/src/google/adk_community/tools/langextract_tool.py @@ -0,0 +1,209 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from dataclasses import field +import logging +from typing import Any +from typing import Optional + +from google.adk.tools import BaseTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +from typing_extensions import override + +try: + import langextract as lx +except ImportError as e: + raise ImportError( + 'LangExtract tools require pip install langextract.' + ) from e + +logger = logging.getLogger(__name__) + + +class LangExtractTool(BaseTool): + """A tool that extracts structured information from text using LangExtract. + + This tool wraps the langextract library to enable LLM agents to extract + structured data (entities, attributes, relationships) from unstructured + text. The agent provides the text to extract from and a description of + what to extract; other parameters are pre-configured at construction time. + + Args: + name: The name of the tool. Defaults to 'langextract'. + description: The description of the tool shown to the LLM. + examples: Optional list of langextract ExampleData for few-shot + extraction guidance. + model_id: The model ID for langextract to use internally. + Defaults to 'gemini-2.5-flash'. + api_key: Optional API key for langextract. If None, uses the + LANGEXTRACT_API_KEY environment variable. + extraction_passes: Number of extraction passes. Defaults to 1. + max_workers: Maximum worker threads for langextract. Defaults to 1. + max_char_buffer: Maximum character buffer size for text chunking. + Defaults to 4000. + + Examples:: + + from google.adk_community.tools import LangExtractTool + import langextract as lx + + tool = LangExtractTool( + name='extract_entities', + description='Extract named entities from text.', + examples=[ + lx.data.ExampleData( + text='John is a software engineer at Google.', + extractions=[ + lx.data.Extraction( + extraction_class='person', + extraction_text='John', + attributes={ + 'role': 'software engineer', + 'company': 'Google', + }, + ) + ], + ) + ], + ) + """ + + def __init__( + self, + *, + name: str = 'langextract', + description: str = ( + 'Extracts structured information from unstructured' + ' text. Provide the text and a description of what' + ' to extract.' + ), + examples: Optional[list[lx.data.ExampleData]] = None, + model_id: str = 'gemini-2.5-flash', + api_key: Optional[str] = None, + extraction_passes: int = 1, + max_workers: int = 1, + max_char_buffer: int = 4000, + ): + super().__init__(name=name, description=description) + self._examples = examples or [] + self._model_id = model_id + self._api_key = api_key + self._extraction_passes = extraction_passes + self._max_workers = max_workers + self._max_char_buffer = max_char_buffer + + @override + def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + 'text': types.Schema( + type=types.Type.STRING, + description=( + 'The unstructured text to extract information from.' + ), + ), + 'prompt_description': types.Schema( + type=types.Type.STRING, + description=( + 'A description of what kind of information to' + ' extract from the text.' + ), + ), + }, + required=['text', 'prompt_description'], + ), + ) + + @override + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + text = args.get('text') + prompt_description = args.get('prompt_description') + + if not text: + return {'error': 'The "text" parameter is required.'} + if not prompt_description: + return {'error': 'The "prompt_description" parameter is required.'} + + try: + extract_kwargs: dict[str, Any] = { + 'text_or_documents': text, + 'prompt_description': prompt_description, + 'examples': self._examples, + 'model_id': self._model_id, + 'extraction_passes': self._extraction_passes, + 'max_workers': self._max_workers, + 'max_char_buffer': self._max_char_buffer, + } + if self._api_key is not None: + extract_kwargs['api_key'] = self._api_key + + # lx.extract() is synchronous; run in a thread to avoid + # blocking the event loop. + result = await asyncio.to_thread(lx.extract, **extract_kwargs) + + extractions = [] + for extraction in result: + entry = { + 'extraction_class': extraction.extraction_class, + 'extraction_text': extraction.extraction_text, + } + if extraction.attributes: + entry['attributes'] = extraction.attributes + extractions.append(entry) + + return {'extractions': extractions} + + except Exception as e: + logger.error('LangExtract extraction failed: %s', e) + return {'error': f'Extraction failed: {e}'} + + +@dataclass +class LangExtractToolConfig: + """Configuration for LangExtractTool.""" + + name: str = 'langextract' + description: str = ( + 'Extracts structured information from unstructured text.' + ) + examples: list[lx.data.ExampleData] = field(default_factory=list) + model_id: str = 'gemini-2.5-flash' + api_key: Optional[str] = None + extraction_passes: int = 1 + max_workers: int = 1 + max_char_buffer: int = 4000 + + def build(self) -> LangExtractTool: + """Instantiate a LangExtractTool from this config.""" + return LangExtractTool( + name=self.name, + description=self.description, + examples=self.examples, + model_id=self.model_id, + api_key=self.api_key, + extraction_passes=self.extraction_passes, + max_workers=self.max_workers, + max_char_buffer=self.max_char_buffer, + ) diff --git a/tests/unittests/tools/__init__.py b/tests/unittests/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unittests/tools/test_langextract_tool.py b/tests/unittests/tools/test_langextract_tool.py new file mode 100644 index 0000000..8db1617 --- /dev/null +++ b/tests/unittests/tools/test_langextract_tool.py @@ -0,0 +1,174 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +pytest.importorskip('langextract', reason='Requires langextract') + +from google.adk_community.tools.langextract_tool import LangExtractTool +from google.adk_community.tools.langextract_tool import LangExtractToolConfig + + +def test_langextract_tool_default_initialization(): + """Test that LangExtractTool initializes with correct defaults.""" + tool = LangExtractTool() + assert tool.name == 'langextract' + assert 'structured information' in tool.description + assert tool._model_id == 'gemini-2.5-flash' + assert tool._examples == [] + assert tool._extraction_passes == 1 + assert tool._max_workers == 1 + assert tool._max_char_buffer == 4000 + assert tool._api_key is None + + +def test_langextract_tool_custom_initialization(): + """Test that LangExtractTool accepts custom parameters.""" + import langextract as lx + + examples = [ + lx.data.ExampleData( + text='test text', + extractions=[ + lx.data.Extraction( + extraction_class='entity', + extraction_text='test', + ) + ], + ) + ] + tool = LangExtractTool( + name='my_extractor', + description='Custom extractor', + examples=examples, + model_id='gemini-2.0-flash', + api_key='test-key', + extraction_passes=2, + max_workers=4, + max_char_buffer=8000, + ) + assert tool.name == 'my_extractor' + assert tool.description == 'Custom extractor' + assert len(tool._examples) == 1 + assert tool._model_id == 'gemini-2.0-flash' + assert tool._api_key == 'test-key' + assert tool._extraction_passes == 2 + assert tool._max_workers == 4 + assert tool._max_char_buffer == 8000 + + +def test_langextract_tool_get_declaration(): + """Test that _get_declaration returns the correct schema.""" + tool = LangExtractTool() + declaration = tool._get_declaration() + assert declaration is not None + assert declaration.name == 'langextract' + assert declaration.parameters is not None + props = declaration.parameters.properties + assert 'text' in props + assert 'prompt_description' in props + assert 'text' in declaration.parameters.required + assert 'prompt_description' in declaration.parameters.required + + +@pytest.mark.asyncio +@patch('google.adk_community.tools.langextract_tool.lx') +async def test_langextract_tool_run_async(mock_lx): + """Test that run_async calls lx.extract and returns results.""" + mock_extraction = MagicMock() + mock_extraction.extraction_class = 'person' + mock_extraction.extraction_text = 'John' + mock_extraction.attributes = {'role': 'engineer'} + mock_lx.extract.return_value = [mock_extraction] + + tool = LangExtractTool() + result = await tool.run_async( + args={ + 'text': 'John is an engineer.', + 'prompt_description': 'Extract people.', + }, + tool_context=MagicMock(), + ) + + assert 'extractions' in result + assert len(result['extractions']) == 1 + assert result['extractions'][0]['extraction_class'] == 'person' + assert result['extractions'][0]['extraction_text'] == 'John' + assert result['extractions'][0]['attributes'] == {'role': 'engineer'} + mock_lx.extract.assert_called_once() + + +@pytest.mark.asyncio +async def test_langextract_tool_missing_text(): + """Test that run_async returns error when text is missing.""" + tool = LangExtractTool() + result = await tool.run_async( + args={'prompt_description': 'Extract people.'}, + tool_context=MagicMock(), + ) + assert 'error' in result + assert 'text' in result['error'] + + +@pytest.mark.asyncio +async def test_langextract_tool_missing_prompt_description(): + """Test that run_async returns error when prompt_description is missing.""" + tool = LangExtractTool() + result = await tool.run_async( + args={'text': 'Some text.'}, + tool_context=MagicMock(), + ) + assert 'error' in result + assert 'prompt_description' in result['error'] + + +@pytest.mark.asyncio +@patch('google.adk_community.tools.langextract_tool.lx') +async def test_langextract_tool_extraction_error(mock_lx): + """Test that run_async handles extraction errors gracefully.""" + mock_lx.extract.side_effect = RuntimeError('API error') + + tool = LangExtractTool() + result = await tool.run_async( + args={ + 'text': 'Some text.', + 'prompt_description': 'Extract stuff.', + }, + tool_context=MagicMock(), + ) + assert 'error' in result + assert 'Extraction failed' in result['error'] + + +def test_langextract_tool_config_build(): + """Test that LangExtractToolConfig.build() returns a LangExtractTool.""" + config = LangExtractToolConfig( + name='my_tool', + description='My custom extractor', + model_id='gemini-2.0-flash', + extraction_passes=3, + max_workers=2, + max_char_buffer=6000, + ) + tool = config.build() + assert isinstance(tool, LangExtractTool) + assert tool.name == 'my_tool' + assert tool.description == 'My custom extractor' + assert tool._model_id == 'gemini-2.0-flash' + assert tool._extraction_passes == 3 + assert tool._max_workers == 2 + assert tool._max_char_buffer == 6000