diff --git a/pyproject.toml b/pyproject.toml
index 9bec96cb8c..54416d1dec 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -123,6 +123,7 @@ test = [
   "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool tests; chromadb/pypika fail on 3.12+
   "kubernetes>=29.0.0", # For GkeCodeExecutor
   "langchain-community>=0.3.17",
+  "langextract>=0.1.0", # For LangExtractTool tests
   "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent
   "litellm>=1.75.5, <1.80.17", # For LiteLLM tests
   "llama-index-readers-file>=0.4.0", # For retrieval tests
@@ -155,6 +156,7 @@ extensions = [
   "crewai[tools];python_version>='3.11' and python_version<'3.12'", # For CrewaiTool; chromadb/pypika fail on 3.12+
   "docker>=7.0.0", # For ContainerCodeExecutor
   "kubernetes>=29.0.0", # For GkeCodeExecutor
+  "langextract>=0.1.0", # For LangExtractTool
   "langgraph>=0.2.60, <0.4.8", # For LangGraphAgent
   "litellm>=1.75.5, <1.80.17", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
   "llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex.
diff --git a/src/google/adk/tools/langextract_tool.py b/src/google/adk/tools/langextract_tool.py
new file mode 100644
index 0000000000..1a0e162ab7
--- /dev/null
+++ b/src/google/adk/tools/langextract_tool.py
@@ -0,0 +1,267 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Any
+from typing import Optional
+
+from google.genai import types
+from typing_extensions import override
+
+from ..features import FeatureName
+from ..features import is_feature_enabled
+from .base_tool import BaseTool
+from .tool_configs import BaseToolConfig
+from .tool_configs import ToolArgsConfig
+from .tool_context import ToolContext
+
+try:
+  import langextract as lx
+except ImportError as e:
+  raise ImportError(
+      "LangExtract tools require pip install 'google-adk[extensions]'."
+  ) from e
+
+logger = logging.getLogger('google_adk.' + __name__)
+
+
+class LangExtractTool(BaseTool):
+  """A tool that extracts structured information from text using LangExtract.
+
+  This tool wraps the langextract library to enable LLM agents to extract
+  structured data (entities, attributes, relationships) from unstructured
+  text. The agent provides the text to extract from and a description of
+  what to extract; other parameters are pre-configured at construction time.
+
+  Args:
+    name: The name of the tool. Defaults to 'langextract'.
+    description: The description of the tool shown to the LLM.
+    examples: Optional list of langextract ExampleData for few-shot
+      extraction guidance.
+    model_id: The model ID for langextract to use internally.
+      Defaults to 'gemini-2.5-flash'.
+    api_key: Optional API key for langextract. If None, uses the
+      LANGEXTRACT_API_KEY environment variable.
+    extraction_passes: Number of extraction passes. Defaults to 1.
+    max_workers: Maximum worker threads for langextract.
+      Defaults to 1.
+    max_char_buffer: Maximum character buffer size for text
+      chunking. Defaults to 4000.
+
+  Examples::
+
+    from google.adk.tools.langextract_tool import LangExtractTool
+    import langextract as lx
+
+    examples = [
+        lx.data.ExampleData(
+            text="John is a software engineer at Google.",
+            extractions=[
+                lx.data.Extraction(
+                    extraction_class="person",
+                    extraction_text="John",
+                    attributes={
+                        "role": "software engineer",
+                        "company": "Google",
+                    },
+                )
+            ],
+        )
+    ]
+
+    tool = LangExtractTool(
+        name='extract_people',
+        description='Extract person entities from text.',
+        examples=examples,
+    )
+  """
+
+  def __init__(
+      self,
+      *,
+      name: str = 'langextract',
+      description: str = (
+          'Extracts structured information from unstructured'
+          ' text. Provide the text and a description of what'
+          ' to extract.'
+      ),
+      examples: Optional[list[lx.data.ExampleData]] = None,
+      model_id: str = 'gemini-2.5-flash',
+      api_key: Optional[str] = None,
+      extraction_passes: int = 1,
+      max_workers: int = 1,
+      max_char_buffer: int = 4000,
+  ):
+    super().__init__(name=name, description=description)
+    self._examples = examples or []
+    self._model_id = model_id
+    self._api_key = api_key
+    self._extraction_passes = extraction_passes
+    self._max_workers = max_workers
+    self._max_char_buffer = max_char_buffer
+
+  @override
+  def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
+    if is_feature_enabled(FeatureName.JSON_SCHEMA_FOR_FUNC_DECL):
+      return types.FunctionDeclaration(
+          name=self.name,
+          description=self.description,
+          parameters_json_schema={
+              'type': 'object',
+              'properties': {
+                  'text': {
+                      'type': 'string',
+                      'description': (
+                          'The unstructured text to extract information from.'
+                      ),
+                  },
+                  'prompt_description': {
+                      'type': 'string',
+                      'description': (
+                          'A description of what kind of'
+                          ' information to extract from'
+                          ' the text.'
+                      ),
+                  },
+              },
+              'required': ['text', 'prompt_description'],
+          },
+      )
+    return types.FunctionDeclaration(
+        name=self.name,
+        description=self.description,
+        parameters=types.Schema(
+            type=types.Type.OBJECT,
+            properties={
+                'text': types.Schema(
+                    type=types.Type.STRING,
+                    description=(
+                        'The unstructured text to extract information from.'
+                    ),
+                ),
+                'prompt_description': types.Schema(
+                    type=types.Type.STRING,
+                    description=(
+                        'A description of what kind of'
+                        ' information to extract from'
+                        ' the text.'
+                    ),
+                ),
+            },
+            required=['text', 'prompt_description'],
+        ),
+    )
+
+  @override
+  async def run_async(
+      self, *, args: dict[str, Any], tool_context: ToolContext
+  ) -> Any:
+    text = args.get('text')
+    prompt_description = args.get('prompt_description')
+
+    if not text:
+      return {'error': 'The "text" parameter is required.'}
+    if not prompt_description:
+      return {
+          'error': 'The "prompt_description" parameter is required.',
+      }
+
+    try:
+      extract_kwargs: dict[str, Any] = {
+          'text_or_documents': text,
+          'prompt_description': prompt_description,
+          'examples': self._examples,
+          'model_id': self._model_id,
+          'extraction_passes': self._extraction_passes,
+          'max_workers': self._max_workers,
+          'max_char_buffer': self._max_char_buffer,
+      }
+      if self._api_key is not None:
+        extract_kwargs['api_key'] = self._api_key
+
+      # lx.extract() is synchronous; run it in a thread to avoid
+      # blocking the event loop.
+      result = await asyncio.to_thread(lx.extract, **extract_kwargs)
+
+      extractions = []
+      for extraction in result.extractions:
+        entry = {
+            'extraction_class': extraction.extraction_class,
+            'extraction_text': extraction.extraction_text,
+        }
+        if extraction.attributes:
+          entry['attributes'] = extraction.attributes
+        extractions.append(entry)
+
+      return {'extractions': extractions}
+
+    except Exception as e:
+      logger.error('LangExtract extraction failed: %s', e)
+      return {'error': f'Extraction failed: {e}'}
+
+  @override
+  @classmethod
+  def from_config(
+      cls: type[LangExtractTool],
+      config: ToolArgsConfig,
+      config_abs_path: str,
+  ) -> LangExtractTool:
+    from ..agents import config_agent_utils
+
+    langextract_config = LangExtractToolConfig.model_validate(
+        config.model_dump()
+    )
+
+    init_kwargs = langextract_config.model_dump()
+    if examples_path := init_kwargs.get('examples'):
+      init_kwargs['examples'] = (
+          config_agent_utils.resolve_fully_qualified_name(examples_path)
+      )
+    else:
+      init_kwargs['examples'] = []
+
+    return cls(**init_kwargs)
+
+
+class LangExtractToolConfig(BaseToolConfig):
+  """Configuration for LangExtractTool when loaded from YAML config."""
+
+  name: str = 'langextract'
+  """The name of the tool."""
+
+  description: str = 'Extracts structured information from unstructured text.'
+  """The description of the tool."""
+
+  examples: str = ''
+  """Fully qualified path to a list of ExampleData instances."""
+
+  model_id: str = 'gemini-2.5-flash'
+  """The model ID for langextract."""
+
+  api_key: Optional[str] = None
+  """Optional API key for langextract."""
+
+  extraction_passes: int = 1
+  """Number of extraction passes."""
+
+  max_workers: int = 1
+  """Maximum worker threads."""
+
+  max_char_buffer: int = 4000
+  """Maximum character buffer size."""
diff --git a/tests/unittests/tools/test_langextract_tool.py b/tests/unittests/tools/test_langextract_tool.py
new file mode 100644
index 0000000000..e26301077e
--- /dev/null
+++ b/tests/unittests/tools/test_langextract_tool.py
@@ -0,0 +1,190 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock
+from unittest.mock import patch
+
+import pytest
+
+pytest.importorskip('langextract', reason='Requires langextract')
+
+from google.adk.features import FeatureName
+from google.adk.features._feature_registry import temporary_feature_override
+from google.adk.tools.langextract_tool import LangExtractTool
+from google.adk.tools.langextract_tool import LangExtractToolConfig
+
+
+def test_langextract_tool_default_initialization():
+  """Test that LangExtractTool initializes with correct defaults."""
+  tool = LangExtractTool()
+  assert tool.name == 'langextract'
+  assert 'structured information' in tool.description
+  assert tool._model_id == 'gemini-2.5-flash'
+  assert tool._examples == []
+  assert tool._extraction_passes == 1
+  assert tool._max_workers == 1
+  assert tool._max_char_buffer == 4000
+  assert tool._api_key is None
+
+
+def test_langextract_tool_custom_initialization():
+  """Test that LangExtractTool accepts custom parameters."""
+  import langextract as lx
+
+  examples = [
+      lx.data.ExampleData(
+          text='test text',
+          extractions=[
+              lx.data.Extraction(
+                  extraction_class='entity',
+                  extraction_text='test',
+              )
+          ],
+      )
+  ]
+  tool = LangExtractTool(
+      name='my_extractor',
+      description='Custom extractor',
+      examples=examples,
+      model_id='gemini-2.0-flash',
+      api_key='test-key',
+      extraction_passes=2,
+      max_workers=4,
+      max_char_buffer=8000,
+  )
+  assert tool.name == 'my_extractor'
+  assert tool.description == 'Custom extractor'
+  assert len(tool._examples) == 1
+  assert tool._model_id == 'gemini-2.0-flash'
+  assert tool._api_key == 'test-key'
+  assert tool._extraction_passes == 2
+  assert tool._max_workers == 4
+  assert tool._max_char_buffer == 8000
+
+
+@pytest.mark.parametrize('json_schema_enabled', [True, False])
+def test_langextract_tool_get_declaration(json_schema_enabled):
+  """Test that _get_declaration returns the correct schema."""
+  with temporary_feature_override(
+      FeatureName.JSON_SCHEMA_FOR_FUNC_DECL, json_schema_enabled
+  ):
+    tool = LangExtractTool()
+    declaration = tool._get_declaration()
+    assert declaration is not None
+    assert declaration.name == 'langextract'
+
+    if json_schema_enabled:
+      params = declaration.parameters_json_schema
+      assert params is not None
+      props = params['properties']
+      required = params['required']
+    else:
+      params = declaration.parameters
+      assert params is not None
+      props = params.properties
+      required = params.required
+
+    assert 'text' in props
+    assert 'prompt_description' in props
+    assert 'text' in required
+    assert 'prompt_description' in required
+
+
+@pytest.mark.asyncio
+@patch('google.adk.tools.langextract_tool.lx')
+async def test_langextract_tool_run_async(mock_lx):
+  """Test that run_async calls lx.extract and returns results."""
+  mock_extraction = MagicMock()
+  mock_extraction.extraction_class = 'person'
+  mock_extraction.extraction_text = 'John'
+  mock_extraction.attributes = {'role': 'engineer'}
+  mock_lx.extract.return_value = MagicMock(extractions=[mock_extraction])
+
+  tool = LangExtractTool()
+  result = await tool.run_async(
+      args={
+          'text': 'John is an engineer.',
+          'prompt_description': 'Extract people.',
+      },
+      tool_context=MagicMock(),
+  )
+
+  assert 'extractions' in result
+  assert len(result['extractions']) == 1
+  assert result['extractions'][0]['extraction_class'] == 'person'
+  assert result['extractions'][0]['extraction_text'] == 'John'
+  assert result['extractions'][0]['attributes'] == {'role': 'engineer'}
+  mock_lx.extract.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_langextract_tool_missing_text():
+  """Test that run_async returns error when text is missing."""
+  tool = LangExtractTool()
+  result = await tool.run_async(
+      args={'prompt_description': 'Extract people.'},
+      tool_context=MagicMock(),
+  )
+  assert 'error' in result
+  assert 'text' in result['error']
+
+
+@pytest.mark.asyncio
+async def test_langextract_tool_missing_prompt_description():
+  """Test that run_async returns error when prompt_description is missing."""
+  tool = LangExtractTool()
+  result = await tool.run_async(
+      args={'text': 'Some text.'},
+      tool_context=MagicMock(),
+  )
+  assert 'error' in result
+  assert 'prompt_description' in result['error']
+
+
+@pytest.mark.asyncio
+@patch('google.adk.tools.langextract_tool.lx')
+async def test_langextract_tool_extraction_error(mock_lx):
+  """Test that run_async handles extraction errors gracefully."""
+  mock_lx.extract.side_effect = RuntimeError('API error')
+
+  tool = LangExtractTool()
+  result = await tool.run_async(
+      args={
+          'text': 'Some text.',
+          'prompt_description': 'Extract stuff.',
+      },
+      tool_context=MagicMock(),
+  )
+  assert 'error' in result
+  assert 'Extraction failed' in result['error']
+
+
+def test_langextract_tool_config():
+  """Test that LangExtractToolConfig validates correctly."""
+  config = LangExtractToolConfig(
+      name='my_tool',
+      description='My custom extractor',
+      model_id='gemini-2.0-flash',
+      extraction_passes=3,
+      max_workers=2,
+      max_char_buffer=6000,
+  )
+  assert config.name == 'my_tool'
+  assert config.description == 'My custom extractor'
+  assert config.model_id == 'gemini-2.0-flash'
+  assert config.extraction_passes == 3
+  assert config.max_workers == 2
+  assert config.max_char_buffer == 6000
+  assert config.examples == ''
+  assert config.api_key is None
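
Usage note (not part of the diff above): a minimal sketch of how the new tool could be attached to an ADK agent. It assumes the standard google.adk.agents.Agent constructor; the agent name, model, and instruction are illustrative placeholders, and the few-shot example mirrors the one in the class docstring.

    # Minimal usage sketch, assuming the standard google.adk.agents.Agent API.
    # Agent name, model, and instruction below are illustrative placeholders.
    import langextract as lx

    from google.adk.agents import Agent
    from google.adk.tools.langextract_tool import LangExtractTool

    # Few-shot guidance for the extraction, mirroring the class docstring example.
    examples = [
        lx.data.ExampleData(
            text='John is a software engineer at Google.',
            extractions=[
                lx.data.Extraction(
                    extraction_class='person',
                    extraction_text='John',
                    attributes={'role': 'software engineer', 'company': 'Google'},
                )
            ],
        )
    ]

    extract_people = LangExtractTool(
        name='extract_people',
        description='Extract person entities from text.',
        examples=examples,
    )

    # Hypothetical agent wiring; the LLM calls the tool with the 'text' and
    # 'prompt_description' arguments declared by _get_declaration().
    root_agent = Agent(
        name='extraction_agent',
        model='gemini-2.5-flash',
        instruction='Use extract_people to pull structured entities from user text.',
        tools=[extract_people],
    )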