
infomaniak Pipeline <think> tag #16


Description

@Shayano

When using the reasoning model (DeepSeek-R1-Distill-Qwen-32B), the opening <think> tag is not transmitted in the model's response, so the reasoning process is displayed in the chat as a regular answer, and the reasoning then ends with </think> before the actual answer to the request is given.
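For illustration, the streamed deltas look roughly like this (a hypothetical excerpt, not captured output): the reasoning text arrives as plain content with no opening marker, and only the closing tag appears before the final answer.

data: {"choices": [{"delta": {"content": "First, let me break the problem down..."}}]}
data: {"choices": [{"delta": {"content": "</think>"}}]}
data: {"choices": [{"delta": {"content": "The answer is..."}}]}
data: [DONE]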

I've worked around the problem in a “dirty” way by adding the <think> tag to the response in the reasoning use case. I haven't opened a PR because I'm not sure this is the best way to proceed.
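The core of the workaround, distilled from the full pipe below: emit an opening <think> tag before forwarding the stream, and append the closing tag at the end unless the model already produced its own. A minimal sketch (the generator name and the shape of the upstream iterator are mine):

async def wrap_with_think(upstream):
    # upstream is assumed to be an async iterable of plain-text content deltas
    model_closed = False
    yield "<think>\n"                      # force thinking mode at the start
    async for content in upstream:
        if "</think>" in content:          # the model supplied its own closing tag
            model_closed = True
        yield content
    if not model_closed:
        yield "\n</think>\n\n"             # close the tag ourselves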


from typing import List, Union, Generator, Iterator, Optional, Dict, Any, AsyncGenerator
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, GetCoreSchemaHandler
from starlette.background import BackgroundTask
from open_webui.env import AIOHTTP_CLIENT_TIMEOUT, SRC_LOG_LEVELS
from cryptography.fernet import Fernet
import aiohttp
import json
import os
import logging
import base64
import hashlib
from pydantic_core import core_schema

# Simplified encryption implementation with automatic handling
class EncryptedStr(str):
    """A string type that automatically handles encryption/decryption"""
    
    @classmethod
    def _get_encryption_key(cls) -> Optional[bytes]:
        """
        Generate encryption key from WEBUI_SECRET_KEY if available
        Returns None if no key is configured
        """
        secret = os.getenv("WEBUI_SECRET_KEY")
        if not secret:
            return None
            
        hashed_key = hashlib.sha256(secret.encode()).digest()
        return base64.urlsafe_b64encode(hashed_key)
    
    @classmethod
    def encrypt(cls, value: str) -> str:
        """
        Encrypt a string value if a key is available
        Returns the original value if no key is available
        """
        if not value or value.startswith("encrypted:"):
            return value
        
        key = cls._get_encryption_key()
        if not key:  # No encryption if no key
            return value
            
        f = Fernet(key)
        encrypted = f.encrypt(value.encode())
        return f"encrypted:{encrypted.decode()}"
    
    @classmethod
    def decrypt(cls, value: str) -> str:
        """
        Decrypt an encrypted string value if a key is available
        Returns the original value if no key is available or decryption fails
        """
        if not value or not value.startswith("encrypted:"):
            return value
        
        key = cls._get_encryption_key()
        if not key:  # No decryption if no key
            return value[len("encrypted:"):]  # Return without prefix
        
        try:
            encrypted_part = value[len("encrypted:"):]
            f = Fernet(key)
            decrypted = f.decrypt(encrypted_part.encode())
            return decrypted.decode()
        except Exception:  # InvalidToken, a bad key, or any other failure
            return value
            
    # Pydantic integration
    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
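        # Plain str inputs are encrypted during validation; existing
        # EncryptedStr instances pass through unchanged, and serialization
        # emits the stored (encrypted) string.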
        return core_schema.union_schema([
            core_schema.is_instance_schema(cls),
            core_schema.chain_schema([
                core_schema.str_schema(),
                core_schema.no_info_plain_validator_function(
                    lambda value: cls(cls.encrypt(value) if value else value)
                ),
            ]),
        ],
        serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: str(instance))
        )
    
    def get_decrypted(self) -> str:
        """Get the decrypted value"""
        return self.decrypt(self)


# Helper functions
async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
) -> None:
    """
    Clean up the response and session objects.
    
    Args:
        response: The ClientResponse object to close
        session: The ClientSession object to close
    """
    if response:
        response.close()
    if session:
        await session.close()

class Pipe:
    # Environment variables for API key, endpoint, and optional model
    class Valves(BaseModel):
        # API key for Infomaniak - automatically encrypted
        INFOMANIAK_API_KEY: EncryptedStr = Field(
            default=os.getenv("INFOMANIAK_API_KEY", "API_KEY"),
            description="API key for Infomaniak AI TOOLS API"
        )
        # Product ID for Infomaniak
        INFOMANIAK_PRODUCT_ID: int = Field(
            default=int(os.getenv("INFOMANIAK_PRODUCT_ID", "50070")),
            description="Product ID for Infomaniak AI TOOLS API"
        )
        # Base URL for Infomaniak API
        INFOMANIAK_BASE_URL: str = Field(
            default=os.getenv("INFOMANIAK_BASE_URL", "https://api.infomaniak.com"),
            description="Base URL for Infomaniak API"
        )
        # Prefix for model names
        NAME_PREFIX: str = Field(
            default="Infomaniak: ",
            description="Prefix to be added before model names"
        )
        # Enable reasoning detection for DeepSeek model
        ENABLE_REASONING: bool = Field(
            default=True,
            description="Enable reasoning detection for DeepSeek model"
        )

    def __init__(self):
        self.type = "manifold"
        self.valves = self.Valves()
        self.name: str = self.valves.NAME_PREFIX
        self.data_prefix = "data:"

    def validate_environment(self) -> None:
        """
        Validates that required environment variables are set.
        
        Raises:
            ValueError: If required environment variables are not set.
        """
        # Access the decrypted API key
        api_key = self.valves.INFOMANIAK_API_KEY.get_decrypted()
        if not api_key:
            raise ValueError("INFOMANIAK_API_KEY is not set!")
        if not self.valves.INFOMANIAK_PRODUCT_ID:
            raise ValueError("INFOMANIAK_PRODUCT_ID is not set!")
        if not self.valves.INFOMANIAK_BASE_URL:
            raise ValueError("INFOMANIAK_BASE_URL is not set!")

    def get_headers(self) -> Dict[str, str]:
        """
        Constructs the headers for the API request.
        
        Returns:
            Dictionary containing the required headers for the API request.
        """
        # Access the decrypted API key
        api_key = self.valves.INFOMANIAK_API_KEY.get_decrypted()
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        return headers

    def get_api_url(self, endpoint: str = "chat/completions") -> str:
        """
        Constructs the API URL for Infomaniak requests.
        
        Args:
            endpoint: The API endpoint to use
            
        Returns:
            Full API URL
        """
        return f"{self.valves.INFOMANIAK_BASE_URL}/1/ai/{self.valves.INFOMANIAK_PRODUCT_ID}/openai/{endpoint}"

    def validate_body(self, body: Dict[str, Any]) -> None:
        """
        Validates the request body to ensure required fields are present.
        
        Args:
            body: The request body to validate
            
        Raises:
            ValueError: If required fields are missing or invalid.
        """
        if "messages" not in body or not isinstance(body["messages"], list):
            raise ValueError("The 'messages' field is required and must be a list.")

    async def get_infomaniak_models(self) -> List[Dict[str, str]]:
        """
        Returns a list of Infomaniak AI LLM models.
    
        Returns:
            List of dictionaries containing model id and name.
        """
        log = logging.getLogger("infomaniak_ai_tools.get_models")
        log.setLevel(SRC_LOG_LEVELS["OPENAI"])
    
        headers = self.get_headers()
        models = []
    
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(
                    url=f"{self.valves.INFOMANIAK_BASE_URL}/1/ai/models",
                    headers=headers
                ) as resp:
                    if resp.status == 200:
                        data = await resp.json()
                        if data.get("result") == "success" and "data" in data:
                            models_data = data["data"]
                            if isinstance(models_data, list):
                                for item in models_data:
                                    if not isinstance(item, dict):
                                        log.error(f"Expected item to be dict but got: {type(item).__name__}")
                                        continue
                                    if item.get("type") == "llm":  # only include llm models
                                        models.append({
                                            "id": item.get("name", ""),
                                            "name": item.get("description", item.get("name", "")),
                                            # Profile image and description are currently not working in Open WebUI
                                            "meta": {
                                                "profile_image_url": item.get("logo_url", ""),
                                                "description": item.get("documentation_link", "")
                                            }
                                        })
                                return models
                            else:
                                log.error("Expected 'data' to be a list but received a non-list value.")
                    log.error(f"Failed to get Infomaniak models: {await resp.text()}")
        except Exception as e:
            log.exception(f"Error getting Infomaniak models: {str(e)}")
        
        # Default model if API call fails
        return [{"id": f"{self.valves.INFOMANIAK_PRODUCT_ID}", "name": "Infomaniak: LLM API"}]

    async def pipes(self) -> List[Dict[str, str]]:
        """
        Returns a list of available pipes based on configuration.
        
        Returns:
            List of dictionaries containing pipe id and name.
        """
        self.validate_environment()
        return await self.get_infomaniak_models()

    async def pipe(self, body: Dict[str, Any], __event_emitter__=None) -> Union[str, Generator, Iterator, Dict[str, Any], StreamingResponse]:
        """
        Main method for sending requests to the Infomaniak AI endpoint.
        
        Args:
            body: The request body containing messages and other parameters
            __event_emitter__: Optional event emitter for status updates (not used but needed for compatibility)
            
        Returns:
            Response from Infomaniak AI API, which could be a string, dictionary or streaming response
        """
        log = logging.getLogger("infomaniak_ai_tools.pipe")
        log.setLevel(SRC_LOG_LEVELS["OPENAI"])

        # Validate the request body
        self.validate_body(body)

        # Construct headers
        headers = self.get_headers()

        # Check if this is the reasoning model
        model_id = body.get("model", "")
        # More precise verification to correctly identify the reasoning model
        is_reasoning_model = self.valves.ENABLE_REASONING and (
            "reasoning" in model_id.lower() or 
            model_id.lower() == "infomaniak_ai_tools.reasoning" or
            "deepseek-reasoner" in model_id.lower()
        )
        
        if is_reasoning_model:
            log.info(f"Detected reasoning model: {model_id}")
        
        # Filter allowed parameters
        allowed_params = {
            "frequency_penalty",
            "logit_bias",
            "logprobs",
            "max_tokens",
            "messages",
            "model",
            "n",
            "presence_penalty",
            "profile_type",
            "seed",
            "stop",
            "stream",
            "temperature",
            "top_logprobs",
            "top_p"
        }
        filtered_body = {k: v for k, v in body.items() if k in allowed_params}

        # Strip the manifold prefix ("<pipe id>.<model name>") so only the
        # bare model name is sent to Infomaniak
        if "model" in filtered_body and filtered_body["model"]:
            filtered_body["model"] = (
                filtered_body["model"].split(".", 1)[1]
                if "." in filtered_body["model"]
                else filtered_body["model"]
            )

        # Special handling for reasoning model in streaming mode
        if is_reasoning_model and filtered_body.get("stream", False):
            # Create a session for the reasoning model
            session = aiohttp.ClientSession(
                trust_env=True,
                timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT),
            )
            
            # Wrapper to convert processing logic into streamable, compatible content
            async def reasoning_content_generator():
                try:
                    gen = self._handle_reasoning_stream(filtered_body, headers, session)
                    async for chunk in gen:
                        if isinstance(chunk, str):
                            if chunk.startswith('{"error":'):
                                # If it's a JSON error, forward it directly and stop
                                yield f"data: {chunk}\n\n"
                                return
                            else:
                                # Format compatible with the format expected by StreamingResponse
                                yield f"data: {json.dumps({'choices': [{'delta': {'content': chunk}}]})}\n\n"
                    yield "data: [DONE]\n\n"
                except Exception as e:
                    log.exception(f"Error in reasoning content generator: {e}")
                    error_msg = json.dumps({"error": f"Error processing reasoning stream: {str(e)}"})
                    yield f"data: {error_msg}\n\n"
                    yield "data: [DONE]\n\n"
            
            return StreamingResponse(
                reasoning_content_generator(),
                media_type="text/event-stream",
                background=BackgroundTask(cleanup_response, response=None, session=session)
            )

        # Convert the modified body back to JSON
        payload = json.dumps(filtered_body)

        request = None
        session = None
        streaming = False
        response = None

        try:
            session = aiohttp.ClientSession(
                trust_env=True,
                timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT),
            )

            api_url = self.get_api_url()
            request = await session.request(
                method="POST",
                url=api_url,
                data=payload,
                headers=headers,
            )

            # Check if response is SSE
            if "text/event-stream" in request.headers.get("Content-Type", ""):
                streaming = True
                return StreamingResponse(
                    request.content,
                    status_code=request.status,
                    headers=dict(request.headers),
                    background=BackgroundTask(
                        cleanup_response, response=request, session=session
                    ),
                )
            else:
                try:
                    response = await request.json()
                except Exception as e:
                    log.error(f"Error parsing JSON response: {e}")
                    response = await request.text()

                request.raise_for_status()
                return response

        except Exception as e:
            log.exception(f"Error in Infomaniak AI request: {e}")

            detail = f"Exception: {str(e)}"
            if isinstance(response, dict):
                err = response.get("error")
                if isinstance(err, dict):
                    detail = str(err.get("message", err))
                elif err:
                    detail = str(err)
            elif isinstance(response, str):
                detail = response

            return f"Error: {detail}"
        finally:
            if not streaming and session:
                if request:
                    request.close()
                await session.close()

    async def _handle_reasoning_stream(self, body: Dict[str, Any], headers: Dict[str, str], session: aiohttp.ClientSession) -> AsyncGenerator[str, None]:
        """
        Special handler for the DeepSeek Reasoning model with thinking functionality.
        
        Args:
            body: The request body
            headers: Request headers
            session: The ClientSession object for the reasoning model
            
        Yields:
            Processed stream content with thinking tags
        """
        log = logging.getLogger("infomaniak_ai_tools.reasoning")
        log.setLevel(SRC_LOG_LEVELS["OPENAI"])
        
        log.info(f"Processing reasoning stream for model: {body.get('model', 'unknown')}")
        
        # Convert body to JSON
        payload = json.dumps(body)
        
        # State variables to track thinking process
        in_thinking = False
        model_added_think_tag = False  # To detect if the model added its own </think> tag
        
        try:
            async with session.post(
                url=self.get_api_url(),
                data=payload,
                headers=headers
            ) as response:
                if response.status != 200:
                    error = await response.text()
                    log.error(f"Error from API: {error}")
                    yield json.dumps({"error": f"HTTP {response.status}: {error}"})
                    return
                
                # Force thinking mode at the start for the reasoning model
                in_thinking = True
                yield "<think>\n"
                log.info("Forcing thinking mode at start")
                
                # Process streaming response
                # aiohttp's StreamReader iterates the body line by line here
                async for raw_line in response.content:
                    line = raw_line.decode("utf-8")
                    if not line.startswith(self.data_prefix):
                        continue
                        
                    # Extract JSON data
                    json_str = line[len(self.data_prefix):].strip()
                    if json_str == "[DONE]":
                        # Make sure the thinking tag is closed at the end, but only if the model hasn't already done it
                        if in_thinking and not model_added_think_tag:
                            log.info("Closing thinking tag at [DONE]")
                            yield "\n</think>\n\n"
                        return
                        
                    try:
                        data = json.loads(json_str)
                    except json.JSONDecodeError:
                        continue
                        
                    # Extract content from the chunk
                    content = ""
                    
                    # Parse from choices[0].delta.content
                    choices = data.get("choices", [])
                    if choices:
                        choice = choices[0]
                        
                        # Check for finish reason
                        if choice.get("finish_reason"):
                            # If we're still in thinking mode when response ends, close the thinking tag
                            if in_thinking and not model_added_think_tag:
                                log.info("Closing thinking tag at finish_reason")
                                yield "\n</think>\n\n"
                                in_thinking = False
                            return
                        
                        # Try to get content from delta
                        delta = choice.get("delta", {})
                        if isinstance(delta, dict):
                            content = delta.get("content", "")
                        
                        # Try other possible locations for content
                        if not content:
                            # If delta itself is a string
                            if isinstance(delta, str):
                                content = delta
                            # Check if choice has text field
                            elif "text" in choice:
                                content = choice["text"]
                            # Check if choice has a message with content
                            elif "message" in choice:
                                message = choice["message"]
                                if isinstance(message, dict):
                                    content = message.get("content", "")
                            # Check each field in delta for string value
                            elif isinstance(delta, dict):
                                for key, value in delta.items():
                                    if isinstance(value, str) and value:
                                        content = value
                                        break
                    
                    if not content:
                        continue
                    
                    # Check if the content already contains a </think> tag
                    if "</think>" in content:
                        log.info("Model added its own </think> tag")
                        model_added_think_tag = True
                        in_thinking = False
                        # No need to add our own closing tag
                    
                    # Output the content directly
                    yield content
                    
        except Exception as e:
            log.exception(f"Error in reasoning stream: {e}")
            # Make sure the thinking tag is closed in case of error
            if in_thinking and not model_added_think_tag:
                log.info("Closing thinking tag on error")
                yield "\n</think>\n\n"
            yield json.dumps({"error": f"Exception: {str(e)}"})
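For a quick sanity check of the EncryptedStr helper, a round trip only requires WEBUI_SECRET_KEY to be set, with the class above in scope (the secret below is an arbitrary example):

import os
os.environ["WEBUI_SECRET_KEY"] = "example-secret"  # arbitrary example value

token = EncryptedStr.encrypt("my-api-key")
assert token.startswith("encrypted:")
assert EncryptedStr.decrypt(token) == "my-api-key"
# Without WEBUI_SECRET_KEY, encrypt/decrypt pass values through unchanged.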
