From a31ecaef618aac89467d544c894d6edf6cbd7953 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Wed, 27 Aug 2025 16:37:02 +0100 Subject: [PATCH 01/31] switch to api_base --- src/client/content/config/tabs/models.py | 4 +-- src/client/content/tools/tabs/split_embed.py | 2 +- src/client/mcp/rag/optimizer_utils/config.py | 4 +-- src/common/schema.py | 2 +- src/server/api/core/models.py | 2 +- src/server/api/utils/models.py | 10 +++--- src/server/bootstrap/models.py | 36 +++++++++---------- .../client/content/config/tabs/test_models.py | 2 +- .../content/tools/tabs/test_split_embed.py | 24 ++++++------- tests/server/test_endpoints_models.py | 12 +++---- 10 files changed, 49 insertions(+), 49 deletions(-) diff --git a/src/client/content/config/tabs/models.py b/src/client/content/config/tabs/models.py index 8fccd744..97fa76f4 100644 --- a/src/client/content/config/tabs/models.py +++ b/src/client/content/config/tabs/models.py @@ -116,7 +116,7 @@ def edit_model(model_type: str, action: Literal["add", "edit"], model_id: str = key="add_model_provider", disabled=action == "edit", ) - model["url"] = st.text_input( + model["api_base"] = st.text_input( "Provider URL:", help=help_text.help_dict["model_url"], key="add_model_url", @@ -234,7 +234,7 @@ def render_model_rows(model_type: str) -> None: ) col4.text_input( "Server", - value=model["url"], + value=model["api_base"], key=f"{model_type}_{model_id}_server", label_visibility="collapsed", disabled=True, diff --git a/src/client/content/tools/tabs/split_embed.py b/src/client/content/tools/tabs/split_embed.py index cfa9a47e..8704e0d9 100644 --- a/src/client/content/tools/tabs/split_embed.py +++ b/src/client/content/tools/tabs/split_embed.py @@ -154,7 +154,7 @@ def display_split_embed() -> None: index=0, key="selected_embed_model", ) - embed_url = embed_models_enabled[embed_request.model]["url"] + embed_url = embed_models_enabled[embed_request.model]["api_base"] st.write(f"Embedding Server: {embed_url}") is_embed_accessible, embed_err_msg = functions.is_url_accessible(embed_url) if not is_embed_accessible: diff --git a/src/client/mcp/rag/optimizer_utils/config.py b/src/client/mcp/rag/optimizer_utils/config.py index ea6ca93b..4f113a37 100644 --- a/src/client/mcp/rag/optimizer_utils/config.py +++ b/src/client/mcp/rag/optimizer_utils/config.py @@ -26,7 +26,7 @@ def get_llm(data): llm = {} llm_config = data["ll_model_config"][data["user_settings"]["ll_model"]["model"]] provider = llm_config["provider"] - url = llm_config["url"] + url = llm_config["api_base"] api_key = llm_config["api_key"] model = data["user_settings"]["ll_model"]["model"] logging.info(f"CHAT_MODEL: {model} {provider} {url} {api_key}") @@ -42,7 +42,7 @@ def get_embeddings(data): embeddings = {} model = data["user_settings"]["vector_search"]["model"] provider = data["embed_model_config"][model]["provider"] - url = data["embed_model_config"][model]["url"] + url = data["embed_model_config"][model]["api_base"] api_key = data["embed_model_config"][model]["api_key"] logging.info(f"EMBEDDINGS: {model} {provider} {url} {api_key}") embeddings = {} diff --git a/src/common/schema.py b/src/common/schema.py index dc7917a1..9fac3584 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -130,7 +130,7 @@ class ModelAccess(BaseModel): """Patch'able Model Parameters""" enabled: Optional[bool] = Field(default=False, description="Model is available for use.") - url: Optional[str] = Field(default=None, description="URL to Model API.") + api_base: Optional[str] = Field(default=None, description="Model API 
Base URL.") api_key: Optional[str] = Field(default=None, description="Model API Key.", json_schema_extra={"sensitive": True}) diff --git a/src/server/api/core/models.py b/src/server/api/core/models.py index 5f289c56..b9cdda89 100644 --- a/src/server/api/core/models.py +++ b/src/server/api/core/models.py @@ -79,7 +79,7 @@ def create_model(model: Model, check_url: bool = True) -> Model: False, ) model.openai_compat = openai_compat - if check_url and model.url and not is_url_accessible(model.url)[0]: + if check_url and model.api_base and not is_url_accessible(model.api_base)[0]: model.enabled = False model_objects.append(model) diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py index 0b0958d4..acc9e6c4 100644 --- a/src/server/api/utils/models.py +++ b/src/server/api/utils/models.py @@ -29,7 +29,7 @@ def update_model(model_id: schema.ModelIdType, payload: schema.Model) -> schema. """Update an existing Model definition""" model_upd = core_models.get_model(model_id=model_id) - if payload.enabled and not is_url_accessible(model_upd.url)[0]: + if payload.enabled and not is_url_accessible(model_upd.api_base)[0]: model_upd.enabled = False raise core_models.URLUnreachableError("Model: Unable to update. API URL is inaccessible.") @@ -67,7 +67,7 @@ def create_genai_models(config: schema.OracleCloudSettings) -> list[schema.Model model_dict["id"] = model["model_name"] model_dict["enabled"] = True - model_dict["url"] = f"https://inference.generativeai.{config.genai_region}.oci.oraclecloud.com" + model_dict["api_base"] = f"https://inference.generativeai.{config.genai_region}.oci.oraclecloud.com" # if model["vendor"] == "cohere": model_dict["openai_compat"] = False # Create the Model @@ -102,7 +102,7 @@ def get_client(model_config: dict, oci_config: schema.OracleCloudSettings, giska kwargs = { "model_provider": "openai" if provider == "openai_compatible" else provider, "model": full_model_config["id"], - "base_url": full_model_config["url"], + "base_url": full_model_config["api_base"], "temperature": full_model_config["temperature"], "max_tokens": full_model_config["max_completion_tokens"], **common_params, @@ -129,7 +129,7 @@ def get_client(model_config: dict, oci_config: schema.OracleCloudSettings, giska kwargs = { "provider": "openai" if provider == "openai_compatible" else provider, "model": full_model_config["id"], - "base_url": full_model_config["url"], + "base_url": full_model_config["api_base"], } if full_model_config.get("api_key"): # only add if set kwargs["api_key"] = full_model_config["api_key"] @@ -144,7 +144,7 @@ def get_client(model_config: dict, oci_config: schema.OracleCloudSettings, giska if giskard: logger.debug("Creating Giskard Client") giskard_key = full_model_config["api_key"] or "giskard" - _client = OpenAI(api_key=giskard_key, base_url=full_model_config["url"]) + _client = OpenAI(api_key=giskard_key, base_url=full_model_config["api_base"]) client = OpenAIClient(model=full_model_config["id"], client=_client) logger.debug("Configured Client: %s", vars(client)) diff --git a/src/server/bootstrap/models.py b/src/server/bootstrap/models.py index e135b66e..6f92b9d2 100644 --- a/src/server/bootstrap/models.py +++ b/src/server/bootstrap/models.py @@ -30,7 +30,7 @@ def main() -> list[Model]: "provider": "cohere", "api_key": os.environ.get("COHERE_API_KEY", default=""), "openai_compat": False, - "url": "https://api.cohere.ai", + "api_base": "https://api.cohere.ai", "context_length": 127072, "temperature": 0.3, "max_completion_tokens": 4096, @@ -43,7 +43,7 @@ def 
main() -> list[Model]: "provider": "openai", "api_key": os.environ.get("OPENAI_API_KEY", default=""), "openai_compat": True, - "url": "https://api.openai.com/v1", + "api_base": "https://api.openai.com/v1", "context_length": 127072, "temperature": 1.0, "max_completion_tokens": 4096, @@ -56,7 +56,7 @@ def main() -> list[Model]: "provider": "perplexity", "api_key": os.environ.get("PPLX_API_KEY", default=""), "openai_compat": True, - "url": "https://api.perplexity.ai", + "api_base": "https://api.perplexity.ai", "context_length": 127072, "temperature": 0.2, "max_completion_tokens": 28000, @@ -69,7 +69,7 @@ def main() -> list[Model]: "provider": "openai_compatible", "api_key": "", "openai_compat": True, - "url": "http://localhost:1234/v1", + "api_base": "http://localhost:1234/v1", "context_length": 131072, "temperature": 1.0, "max_completion_tokens": 4096, @@ -82,7 +82,7 @@ def main() -> list[Model]: "provider": "ollama", "api_key": "", "openai_compat": True, - "url": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"), + "api_base": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"), "context_length": 131072, "temperature": 1.0, "max_completion_tokens": 2048, @@ -96,7 +96,7 @@ def main() -> list[Model]: "provider": "ollama", "api_key": "", "openai_compat": True, - "url": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"), + "api_base": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"), "context_length": 131072, "temperature": 1.0, "max_completion_tokens": 2048, @@ -107,7 +107,7 @@ def main() -> list[Model]: "enabled": os.getenv("ON_PREM_HF_URL") is not None, "type": "embed", "provider": "huggingface", - "url": os.environ.get("ON_PREM_HF_URL", default="http://127.0.0.1:8080"), + "api_base": os.environ.get("ON_PREM_HF_URL", default="http://127.0.0.1:8080"), "api_key": "", "openai_compat": True, "max_chunk_size": 512, @@ -117,7 +117,7 @@ def main() -> list[Model]: "enabled": os.getenv("OPENAI_API_KEY") is not None, "type": "embed", "provider": "openai_compatible", - "url": "https://api.openai.com/v1", + "api_base": "https://api.openai.com/v1", "api_key": os.environ.get("OPENAI_API_KEY", default=""), "openai_compat": True, "max_chunk_size": 8191, @@ -127,7 +127,7 @@ def main() -> list[Model]: "enabled": os.getenv("COHERE_API_KEY") is not None, "type": "embed", "provider": "cohere", - "url": "https://api.cohere.ai", + "api_base": "https://api.cohere.ai", "api_key": os.environ.get("COHERE_API_KEY", default=""), "openai_compat": False, "max_chunk_size": 512, @@ -137,7 +137,7 @@ def main() -> list[Model]: "enabled": False, "type": "embed", "provider": "openai_compatible", - "url": "http://localhost:1234/v1", + "api_base": "http://localhost:1234/v1", "api_key": "", "openai_compat": True, "max_chunk_size": 8192, @@ -148,7 +148,7 @@ def main() -> list[Model]: "enabled": os.getenv("ON_PREM_OLLAMA_URL") is not None, "type": "embed", "provider": "ollama", - "url": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"), + "api_base": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"), "api_key": "", "openai_compat": True, "max_chunk_size": 8192, @@ -216,31 +216,31 @@ def values_differ(a, b): model["enabled"] = True elif provider == "oci" and os.getenv("OCI_GENAI_SERVICE_ENDPOINT"): - old_url = model.get("url", "") + old_url = model.get("api_base", "") new_url = os.environ["OCI_GENAI_SERVICE_ENDPOINT"] if old_url != new_url: logger.info( "Overriding 'url' for model '%s' with 
OCI_GENAI_SERVICE_ENDPOINT environment variable", model_id ) - model["url"] = new_url + model["api_base"] = new_url overridden = True model["enabled"] = True elif provider == "ollama" and os.getenv("ON_PREM_OLLAMA_URL"): - old_url = model.get("url", "") + old_url = model.get("api_base", "") new_url = os.environ["ON_PREM_OLLAMA_URL"] if old_url != new_url: logger.info("Overriding 'url' for model '%s' with ON_PREM_OLLAMA_URL environment variable", model_id) - model["url"] = new_url + model["api_base"] = new_url overridden = True model["enabled"] = True elif provider == "huggingface" and os.getenv("ON_PREM_HF_URL"): - old_url = model.get("url", "") + old_url = model.get("api_base", "") new_url = os.environ["ON_PREM_HF_URL"] if old_url != new_url: logger.info("Overriding 'url' for model '%s' with ON_PREM_HF_URL environment variable", model_id) - model["url"] = new_url + model["api_base"] = new_url overridden = True model["enabled"] = True @@ -251,7 +251,7 @@ def values_differ(a, b): url_access_cache = {} for model in models_list: - url = model["url"] + url = model["api_base"] if model["enabled"]: if url not in url_access_cache: logger.debug("Testing %s URL: %s", model["id"], url) diff --git a/tests/client/content/config/tabs/test_models.py b/tests/client/content/config/tabs/test_models.py index 2833efb9..3c52070a 100644 --- a/tests/client/content/config/tabs/test_models.py +++ b/tests/client/content/config/tabs/test_models.py @@ -22,7 +22,7 @@ def test_model_tables(self, app_server, app_test): for model in at.session_state.model_configs: assert at.text_input(key=f"{model['type']}_{model['id']}_enabled").value == "⚪" assert at.text_input(key=f"{model['type']}_{model['id']}_provider").value == model["provider"] - assert at.text_input(key=f"{model['type']}_{model['id']}_server").value == model["url"] + assert at.text_input(key=f"{model['type']}_{model['id']}_server").value == model["api_base"] assert at.button(key=f"{model['type']}_{model['id']}_edit") is not None for model_type in {item["type"] for item in at.session_state.model_configs}: diff --git a/tests/client/content/tools/tabs/test_split_embed.py b/tests/client/content/tools/tabs/test_split_embed.py index 2f1af5bc..eecb360d 100644 --- a/tests/client/content/tools/tabs/test_split_embed.py +++ b/tests/client/content/tools/tabs/test_split_embed.py @@ -30,7 +30,7 @@ def mock_get(endpoint=None, **kwargs): "id": "test-model", "type": "embed", "enabled": True, - "url": "http://test.url", + "api_base": "http://test.url", "max_chunk_size": 1000, } ] @@ -42,7 +42,7 @@ def mock_get(endpoint=None, **kwargs): at = app_test(self.ST_FILE) # Mock functions that make external calls to avoid failures - monkeypatch.setattr("common.functions.is_url_accessible", lambda url: (True, "")) + monkeypatch.setattr("common.functions.is_url_accessible", lambda api_base: (True, "")) monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True) # Run the app - this is critical to initialize all widgets! 
@@ -90,7 +90,7 @@ def mock_get(endpoint=None, **kwargs): "id": "test-model", "type": "embed", "enabled": True, - "url": "http://test.url", + "api_base": "http://test.url", "max_chunk_size": 1000, } ] @@ -99,7 +99,7 @@ def mock_get(endpoint=None, **kwargs): monkeypatch.setattr("client.utils.api_call.get", mock_get) # Mock functions that make external calls - monkeypatch.setattr("common.functions.is_url_accessible", lambda url: (True, "")) + monkeypatch.setattr("common.functions.is_url_accessible", lambda api_base: (True, "")) monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True) # Initialize app_test @@ -136,7 +136,7 @@ def mock_get(endpoint=None, **kwargs): "id": "test-model", "type": "embed", "enabled": True, - "url": "http://test.url", + "api_base": "http://test.url", "max_chunk_size": 1000, } ] @@ -145,7 +145,7 @@ def mock_get(endpoint=None, **kwargs): monkeypatch.setattr("client.utils.api_call.get", mock_get) # Mock functions that make external calls - monkeypatch.setattr("common.functions.is_url_accessible", lambda url: (True, "")) + monkeypatch.setattr("common.functions.is_url_accessible", lambda api_base: (True, "")) monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True) # Initialize app_test @@ -181,7 +181,7 @@ def mock_get(endpoint=None, **kwargs): # Test successful assert True - def test_web_url_validation(self, app_server, app_test, monkeypatch): + def test_web_api_base_validation(self, app_server, app_test, monkeypatch): """Test web URL validation""" assert app_server is not None @@ -193,7 +193,7 @@ def mock_get(endpoint=None, **kwargs): "id": "test-model", "type": "embed", "enabled": True, - "url": "http://test.url", + "api_base": "http://test.url", "max_chunk_size": 1000, } ] @@ -202,7 +202,7 @@ def mock_get(endpoint=None, **kwargs): monkeypatch.setattr("client.utils.api_call.get", mock_get) # Mock functions that make external calls - monkeypatch.setattr("common.functions.is_url_accessible", lambda url: (True, "")) + monkeypatch.setattr("common.functions.is_url_accessible", lambda api_base: (True, "")) monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True) # Initialize app_test @@ -237,7 +237,7 @@ def mock_get(endpoint=None, **kwargs): "id": "test-model", "type": "embed", "enabled": True, - "url": "http://test.url", + "api_base": "http://test.url", "max_chunk_size": 1000, } ] @@ -246,7 +246,7 @@ def mock_get(endpoint=None, **kwargs): monkeypatch.setattr("client.utils.api_call.get", mock_get) # Mock functions that make external calls - monkeypatch.setattr("common.functions.is_url_accessible", lambda url: (True, "")) + monkeypatch.setattr("common.functions.is_url_accessible", lambda api_base: (True, "")) monkeypatch.setattr("client.utils.st_common.is_db_configured", lambda: True) # Initialize app_test @@ -307,7 +307,7 @@ def mock_get_response(endpoint=None, **kwargs): "id": "test-model", "type": "embed", "enabled": True, - "url": "http://test.url", + "api_base": "http://test.url", "max_chunk_size": 1000, } ] diff --git a/tests/server/test_endpoints_models.py b/tests/server/test_endpoints_models.py index 8bc12ade..809c9b6d 100644 --- a/tests/server/test_endpoints_models.py +++ b/tests/server/test_endpoints_models.py @@ -107,8 +107,8 @@ def test_models_add_dupl(self, client, auth_headers): "type": "ll", "provider": "openai", "api_key": "test-key", + "api_base": "https://api.openai.com/v1", "openai_compat": True, - "url": "https://api.openai.com/v1", "context_length": 127072, "temperature": 1.0, 
"max_completion_tokens": 4096, @@ -133,7 +133,7 @@ def test_models_add_dupl(self, client, auth_headers): "enabled": False, "type": "embed", "provider": "huggingface", - "url": "http://127.0.0.1:8080", + "api_base": "http://127.0.0.1:8080", "api_key": "", "openai_compat": True, "max_chunk_size": 512, @@ -144,18 +144,18 @@ def test_models_add_dupl(self, client, auth_headers): ), pytest.param( { - "id": "unreachable_url_model", + "id": "unreachable_api_base_model", "enabled": True, "type": "embed", "provider": "huggingface", - "url": "http://127.0.0.1:112233", + "api_base": "http://127.0.0.1:112233", "api_key": "", "openai_compat": True, "max_chunk_size": 512, }, 201, 422, - id="unreachable_url_model", + id="unreachable_api_base_model", ), ] @@ -165,7 +165,7 @@ def test_model_create(self, client, auth_headers, payload, add_status_code, _, r response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) assert response.status_code == add_status_code if add_status_code == 201: - if request.node.callspec.id == "unreachable_url_model": + if request.node.callspec.id == "unreachable_api_base_model": assert response.json()["enabled"] is False else: assert all(item in response.json().items() for item in payload.items()) From 8c0cfdacda3e7370eac0c0e3fb2d0b8dc36fa1e6 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Wed, 27 Aug 2025 16:39:18 +0100 Subject: [PATCH 02/31] starting unit tests --- .../client/content/config/tabs/test_databases.py | 0 tests/{ => integration}/client/content/config/tabs/test_models.py | 0 tests/{ => integration}/client/content/config/tabs/test_oci.py | 0 .../{ => integration}/client/content/config/tabs/test_settings.py | 0 tests/{ => integration}/client/content/test_api_server.py | 0 tests/{ => integration}/client/content/test_chatbot.py | 0 tests/{ => integration}/client/content/test_st_footer.py | 0 tests/{ => integration}/client/content/test_testbed.py | 0 .../client/content/tools/tabs/test_prompt_eng.py | 0 .../client/content/tools/tabs/test_split_embed.py | 0 tests/{ => integration}/server/test_endpoints_chat.py | 0 tests/{ => integration}/server/test_endpoints_databases.py | 0 tests/{ => integration}/server/test_endpoints_embed.py | 0 tests/{ => integration}/server/test_endpoints_health.py | 0 tests/{ => integration}/server/test_endpoints_models.py | 0 tests/{ => integration}/server/test_endpoints_oci.py | 0 tests/{ => integration}/server/test_endpoints_prompts.py | 0 tests/{ => integration}/server/test_endpoints_settings.py | 0 tests/{ => integration}/server/test_endpoints_testbed.py | 0 19 files changed, 0 insertions(+), 0 deletions(-) rename tests/{ => integration}/client/content/config/tabs/test_databases.py (100%) rename tests/{ => integration}/client/content/config/tabs/test_models.py (100%) rename tests/{ => integration}/client/content/config/tabs/test_oci.py (100%) rename tests/{ => integration}/client/content/config/tabs/test_settings.py (100%) rename tests/{ => integration}/client/content/test_api_server.py (100%) rename tests/{ => integration}/client/content/test_chatbot.py (100%) rename tests/{ => integration}/client/content/test_st_footer.py (100%) rename tests/{ => integration}/client/content/test_testbed.py (100%) rename tests/{ => integration}/client/content/tools/tabs/test_prompt_eng.py (100%) rename tests/{ => integration}/client/content/tools/tabs/test_split_embed.py (100%) rename tests/{ => integration}/server/test_endpoints_chat.py (100%) rename tests/{ => integration}/server/test_endpoints_databases.py (100%) rename tests/{ => 
integration}/server/test_endpoints_embed.py (100%) rename tests/{ => integration}/server/test_endpoints_health.py (100%) rename tests/{ => integration}/server/test_endpoints_models.py (100%) rename tests/{ => integration}/server/test_endpoints_oci.py (100%) rename tests/{ => integration}/server/test_endpoints_prompts.py (100%) rename tests/{ => integration}/server/test_endpoints_settings.py (100%) rename tests/{ => integration}/server/test_endpoints_testbed.py (100%) diff --git a/tests/client/content/config/tabs/test_databases.py b/tests/integration/client/content/config/tabs/test_databases.py similarity index 100% rename from tests/client/content/config/tabs/test_databases.py rename to tests/integration/client/content/config/tabs/test_databases.py diff --git a/tests/client/content/config/tabs/test_models.py b/tests/integration/client/content/config/tabs/test_models.py similarity index 100% rename from tests/client/content/config/tabs/test_models.py rename to tests/integration/client/content/config/tabs/test_models.py diff --git a/tests/client/content/config/tabs/test_oci.py b/tests/integration/client/content/config/tabs/test_oci.py similarity index 100% rename from tests/client/content/config/tabs/test_oci.py rename to tests/integration/client/content/config/tabs/test_oci.py diff --git a/tests/client/content/config/tabs/test_settings.py b/tests/integration/client/content/config/tabs/test_settings.py similarity index 100% rename from tests/client/content/config/tabs/test_settings.py rename to tests/integration/client/content/config/tabs/test_settings.py diff --git a/tests/client/content/test_api_server.py b/tests/integration/client/content/test_api_server.py similarity index 100% rename from tests/client/content/test_api_server.py rename to tests/integration/client/content/test_api_server.py diff --git a/tests/client/content/test_chatbot.py b/tests/integration/client/content/test_chatbot.py similarity index 100% rename from tests/client/content/test_chatbot.py rename to tests/integration/client/content/test_chatbot.py diff --git a/tests/client/content/test_st_footer.py b/tests/integration/client/content/test_st_footer.py similarity index 100% rename from tests/client/content/test_st_footer.py rename to tests/integration/client/content/test_st_footer.py diff --git a/tests/client/content/test_testbed.py b/tests/integration/client/content/test_testbed.py similarity index 100% rename from tests/client/content/test_testbed.py rename to tests/integration/client/content/test_testbed.py diff --git a/tests/client/content/tools/tabs/test_prompt_eng.py b/tests/integration/client/content/tools/tabs/test_prompt_eng.py similarity index 100% rename from tests/client/content/tools/tabs/test_prompt_eng.py rename to tests/integration/client/content/tools/tabs/test_prompt_eng.py diff --git a/tests/client/content/tools/tabs/test_split_embed.py b/tests/integration/client/content/tools/tabs/test_split_embed.py similarity index 100% rename from tests/client/content/tools/tabs/test_split_embed.py rename to tests/integration/client/content/tools/tabs/test_split_embed.py diff --git a/tests/server/test_endpoints_chat.py b/tests/integration/server/test_endpoints_chat.py similarity index 100% rename from tests/server/test_endpoints_chat.py rename to tests/integration/server/test_endpoints_chat.py diff --git a/tests/server/test_endpoints_databases.py b/tests/integration/server/test_endpoints_databases.py similarity index 100% rename from tests/server/test_endpoints_databases.py rename to 
tests/integration/server/test_endpoints_databases.py diff --git a/tests/server/test_endpoints_embed.py b/tests/integration/server/test_endpoints_embed.py similarity index 100% rename from tests/server/test_endpoints_embed.py rename to tests/integration/server/test_endpoints_embed.py diff --git a/tests/server/test_endpoints_health.py b/tests/integration/server/test_endpoints_health.py similarity index 100% rename from tests/server/test_endpoints_health.py rename to tests/integration/server/test_endpoints_health.py diff --git a/tests/server/test_endpoints_models.py b/tests/integration/server/test_endpoints_models.py similarity index 100% rename from tests/server/test_endpoints_models.py rename to tests/integration/server/test_endpoints_models.py diff --git a/tests/server/test_endpoints_oci.py b/tests/integration/server/test_endpoints_oci.py similarity index 100% rename from tests/server/test_endpoints_oci.py rename to tests/integration/server/test_endpoints_oci.py diff --git a/tests/server/test_endpoints_prompts.py b/tests/integration/server/test_endpoints_prompts.py similarity index 100% rename from tests/server/test_endpoints_prompts.py rename to tests/integration/server/test_endpoints_prompts.py diff --git a/tests/server/test_endpoints_settings.py b/tests/integration/server/test_endpoints_settings.py similarity index 100% rename from tests/server/test_endpoints_settings.py rename to tests/integration/server/test_endpoints_settings.py diff --git a/tests/server/test_endpoints_testbed.py b/tests/integration/server/test_endpoints_testbed.py similarity index 100% rename from tests/server/test_endpoints_testbed.py rename to tests/integration/server/test_endpoints_testbed.py From 1ce73b81c728865d6386b4617ad776e82ba96fe8 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Wed, 27 Aug 2025 17:23:11 +0100 Subject: [PATCH 03/31] remove unused import --- src/launch_client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/launch_client.py b/src/launch_client.py index 0ee360c4..4f587346 100644 --- a/src/launch_client.py +++ b/src/launch_client.py @@ -7,7 +7,6 @@ """ # spell-checker:ignore streamlit, scriptrunner -import asyncio import os from uuid import uuid4 From bd23d085108020e4369e78cc8b8280b611c981b0 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Wed, 27 Aug 2025 17:36:32 +0100 Subject: [PATCH 04/31] starting litellm client --- src/server/api/utils/models.py | 29 +++++++++++++++++++++++++++ tests/unit/server/api/utils/models.py | 0 2 files changed, 29 insertions(+) create mode 100644 tests/unit/server/api/utils/models.py diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py index acc9e6c4..670ac319 100644 --- a/src/server/api/utils/models.py +++ b/src/server/api/utils/models.py @@ -4,6 +4,7 @@ """ # spell-checker:ignore ollama pplx huggingface genai giskard litellm ocigenai +from litellm import get_supported_openai_params from openai import OpenAI from langchain_core.language_models.chat_models import BaseChatModel @@ -79,6 +80,34 @@ def create_genai_models(config: schema.OracleCloudSettings) -> list[schema.Model return genai_models +def get_litellm_client( + model_config: dict, oci_config: schema.OracleCloudSettings = None, giskard: bool = False +) -> BaseChatModel: + """Establish client""" + logger.debug("Model Client: %s; OCI Config: %s; Giskard: %s", model_config, oci_config, giskard) + + try: + defined_model = core_models.get_model( + model_id=model_config["id"], + include_disabled=False, + ).model_dump() + except core_models.UnknownModelError: + return None + + # 
Merge configurations, skipping None values + full_model_config = {**defined_model, **{k: v for k, v in model_config.items() if v is not None}} + print(f"*********** {full_model_config}") + + # Determine provider and model name + provider = "openai" if model_config["provider"] == "openai_compatible" else model_config["provider"] + model_name = f"{provider}/{model_config['id']}" + + # Get supported parameters and initialize config + supported_params = get_supported_openai_params(model=model_name) + litellm_config = {k: full_model_config[k] for k in supported_params if k in full_model_config and full_model_config[k] is not None} + litellm_config["model"] = model_name + litellm_config["api_base"] = full_model_config["api_base"] + print(f"*********** {litellm_config}") def get_client(model_config: dict, oci_config: schema.OracleCloudSettings, giskard: bool = False) -> BaseChatModel: """Retrieve model configuration""" diff --git a/tests/unit/server/api/utils/models.py b/tests/unit/server/api/utils/models.py new file mode 100644 index 00000000..e69de29b From 2fb3dc72c817c1435c2caad8a966dcd03830f873 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Wed, 27 Aug 2025 17:42:09 +0100 Subject: [PATCH 05/31] dedupe code --- src/server/bootstrap/models.py | 63 ++++++++++------------------------ 1 file changed, 18 insertions(+), 45 deletions(-) diff --git a/src/server/bootstrap/models.py b/src/server/bootstrap/models.py index 6f92b9d2..bd69f66f 100644 --- a/src/server/bootstrap/models.py +++ b/src/server/bootstrap/models.py @@ -22,6 +22,20 @@ def main() -> list[Model]: """Define example Model Support""" logger.debug("*** Bootstrapping Models - Start") + def update_env_var(model: dict, provider: str, model_key: str, env_var: str): + if model.get("provider") != provider: + return + + new_value = os.environ.get(env_var) + if not new_value: + return + + old_value = model.get(model_key) + if old_value != new_value: + logger.debug("Overriding '%s' for model '%s' with %s environment variable", model_key, model["id"], env_var) + model[model_key] = new_value + logger.debug("Model '%s' updated via environment variable overrides.", model["id"]) + models_list = [ { "id": "command-r", @@ -201,51 +215,10 @@ def values_differ(a, b): # Override with OS env vars (by API type) for model in models_list: - provider = model.get("provider", "") - model_id = model.get("id", "") - overridden = False - - if provider == "cohere" and os.getenv("COHERE_API_KEY"): - old_api_key = model.get("api_key", "") - new_api_key = os.environ["COHERE_API_KEY"] - if old_api_key != new_api_key: - # Exposes key if in DEBUG - logger.debug("Overriding 'api_key' for model '%s' with COHERE_API_KEY environment variable", model_id) - model["api_key"] = new_api_key - overridden = True - model["enabled"] = True - - elif provider == "oci" and os.getenv("OCI_GENAI_SERVICE_ENDPOINT"): - old_url = model.get("api_base", "") - new_url = os.environ["OCI_GENAI_SERVICE_ENDPOINT"] - if old_url != new_url: - logger.info( - "Overriding 'url' for model '%s' with OCI_GENAI_SERVICE_ENDPOINT environment variable", model_id - ) - model["api_base"] = new_url - overridden = True - model["enabled"] = True - - elif provider == "ollama" and os.getenv("ON_PREM_OLLAMA_URL"): - old_url = model.get("api_base", "") - new_url = os.environ["ON_PREM_OLLAMA_URL"] - if old_url != new_url: - logger.info("Overriding 'url' for model '%s' with ON_PREM_OLLAMA_URL environment variable", model_id) - model["api_base"] = new_url - overridden = True - model["enabled"] = True - - elif provider == 
"huggingface" and os.getenv("ON_PREM_HF_URL"): - old_url = model.get("api_base", "") - new_url = os.environ["ON_PREM_HF_URL"] - if old_url != new_url: - logger.info("Overriding 'url' for model '%s' with ON_PREM_HF_URL environment variable", model_id) - model["api_base"] = new_url - overridden = True - model["enabled"] = True - - if overridden: - logger.debug("Model '%s' updated via environment variable overrides.", model_id) + update_env_var(model, "cohere", "api_key", "COHERE_API_KEY") + update_env_var(model, "oci", "api_base", "OCI_GENAI_SERVICE_ENDPOINT") + update_env_var(model, "ollama", "api_base", "ON_PREM_OLLAMA_URL") + update_env_var(model, "huggingface", "api_base", "ON_PREM_HF_URL") # Check URL accessible for enabled models and disable if not: url_access_cache = {} From 0dd23926efda6b2226507ebf2b3995e23add6def Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 29 Aug 2025 11:00:20 +0100 Subject: [PATCH 06/31] Switched to LiteLLM --- src/client/content/testbed.py | 4 +- src/common/schema.py | 50 +------------- src/server/api/core/models.py | 6 -- src/server/api/utils/models.py | 120 +++++++++++---------------------- src/server/api/utils/oci.py | 5 +- src/server/api/v1/chat.py | 7 +- src/server/bootstrap/models.py | 15 +---- 7 files changed, 51 insertions(+), 156 deletions(-) diff --git a/src/client/content/testbed.py b/src/client/content/testbed.py index 8e0c5d66..15f73aaf 100644 --- a/src/client/content/testbed.py +++ b/src/client/content/testbed.py @@ -248,7 +248,7 @@ def main() -> None: # If there is no eligible (OpenAI Compat.) LL Models; then disable ALL functionality ll_models_enabled = st_common.enabled_models_lookup("ll") - available_ll_models = [key for key, value in ll_models_enabled.items() if value.get("openai_compat")] + available_ll_models = [key for key, value in ll_models_enabled.items()] if not available_ll_models: st.error( "No OpenAI compatible language models are configured and/or enabled. Disabling Testing Framework.", @@ -261,7 +261,7 @@ def main() -> None: # If there is no eligible (OpenAI Compat.) Embedding Model; disable Generate Test Set gen_testset_disabled = False embed_models_enabled = st_common.enabled_models_lookup("embed") - available_embed_models = [key for key, value in embed_models_enabled.items() if value.get("openai_compat")] + available_embed_models = [key for key, value in embed_models_enabled.items()] if not available_embed_models: st.warning( "No OpenAI compatible embedding models are configured and/or enabled. 
Disabling Test Set Generation.", diff --git a/src/common/schema.py b/src/common/schema.py index 9fac3584..e5b062d7 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -6,7 +6,7 @@ # spell-checker:ignore deepseek groq huggingface mistralai ocigenai vertexai import time -from typing import Optional, Literal, Union, get_args, Any +from typing import Optional, Literal, get_args, Any from pydantic import BaseModel, Field, PrivateAttr, model_validator from langchain_core.messages import ChatMessage @@ -152,7 +152,6 @@ class Model(ModelAccess, LanguageModelParameters, EmbeddingModelParameters): ) type: Literal["ll", "embed", "re-rank"] = Field(..., description="Type of Model.") provider: str = Field(..., min_length=1, description="Model Provider.", examples=["openai", "anthropic", "ollama"]) - openai_compat: bool = Field(default=True, description="Is the API OpenAI compatible?") @model_validator(mode="after") def check_provider(self): @@ -362,53 +361,6 @@ def recursive_dump_excluding_marked(cls, obj: Any, incl_sensitive: bool, incl_re ##################################################### # Completions ##################################################### -class ChatLogprobs(BaseModel): - """Log probability information for the choice.""" - - content: Optional[dict[str, Union[str, int, dict]]] = Field( - default=None, description="A list of message content tokens with log probability information." - ) - refusal: Optional[dict[str, Union[str, int, dict]]] = Field( - default=None, description="A list of message refusal tokens with log probability information." - ) - - -class ChatChoices(BaseModel): - """A list of chat completion choices.""" - - index: int = Field(description="The index of the choice in the list of choices.") - message: ChatMessage = Field(descriptions="A chat completion message generated by the model.") - finish_reason: Literal["stop", "length", "content_filter", "tool_calls"] = Field( - description=( - "The reason the model stopped generating tokens. " - "This will be stop if the model hit a natural stop point or a provided stop sequence, " - "length if the maximum number of tokens specified in the request was reached, " - "content_filter if content was omitted due to a flag from our content filters, " - "tool_calls if the model called a tool." 
- ) - ) - logprobs: Optional[ChatLogprobs] = Field(default=None, description="Log probability information for the choice.") - - -class ChatUsage(BaseModel): - """Usage statistics for the completion request.""" - - prompt_tokens: int = Field(description="Number of tokens in the prompt.") - completion_tokens: int = Field(description="Number of tokens in the generated completion.") - total_tokens: int = Field(description="Total number of tokens used in the request (prompt + completion).") - - -class ChatResponse(BaseModel): - """Represents a chat completion response returned by model, based on the provided input.""" - - id: str = Field(description="A unique identifier for the chat completion.") - choices: list[ChatChoices] = Field(description="A list of chat completion choices.") - created: int = Field(description="The Unix timestamp (in seconds) of when the chat completion was created.") - model: str = Field(description="The model used for the chat completion.") - object: str = Field(default="chat.completion", description="The model used for the chat completion.") - usage: Optional[ChatUsage] = Field(default=None, description="Usage statistics for the completion request.") - - class ChatRequest(LanguageModelParameters): """ Request Body (inherits LanguageModelParameters) diff --git a/src/server/api/core/models.py b/src/server/api/core/models.py index b9cdda89..77c0820c 100644 --- a/src/server/api/core/models.py +++ b/src/server/api/core/models.py @@ -73,12 +73,6 @@ def create_model(model: Model, check_url: bool = True) -> Model: if any(d.id == model.id for d in model_objects): raise ExistsModelError(f"Model: {model.id} already exists.") - if not model.openai_compat: - openai_compat = next( - (model_config.openai_compat for model_config in model_objects if model_config.provider == model.provider), - False, - ) - model.openai_compat = openai_compat if check_url and model.api_base and not is_url_accessible(model.api_base)[0]: model.enabled = False diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py index 670ac319..a6e40909 100644 --- a/src/server/api/utils/models.py +++ b/src/server/api/utils/models.py @@ -4,6 +4,8 @@ """ # spell-checker:ignore ollama pplx huggingface genai giskard litellm ocigenai +from urllib.parse import urlparse + from litellm import get_supported_openai_params from openai import OpenAI @@ -55,6 +57,11 @@ def create_genai_models(config: schema.OracleCloudSettings) -> list[schema.Model genai_models = [] for model in region_models: + if model["vendor"] == "cohere": + # Note that we can enable this if the GenAI endpoint supports OpenAI compat + # https://docs.cohere.com/docs/compatibility-api + logger.info("Skipping %s; no support for OCI GenAI cohere models", model["model_name"]) + continue model_dict = {} model_dict["provider"] = "oci" if "CHAT" in model["capabilities"]: @@ -69,8 +76,6 @@ def create_genai_models(config: schema.OracleCloudSettings) -> list[schema.Model model_dict["id"] = model["model_name"] model_dict["enabled"] = True model_dict["api_base"] = f"https://inference.generativeai.{config.genai_region}.oci.oraclecloud.com" - # if model["vendor"] == "cohere": - model_dict["openai_compat"] = False # Create the Model try: new_model = schema.Model(**model_dict) @@ -80,15 +85,16 @@ def create_genai_models(config: schema.OracleCloudSettings) -> list[schema.Model return genai_models + def get_litellm_client( model_config: dict, oci_config: schema.OracleCloudSettings = None, giskard: bool = False -) -> BaseChatModel: +) -> dict: """Establish 
client""" logger.debug("Model Client: %s; OCI Config: %s; Giskard: %s", model_config, oci_config, giskard) try: defined_model = core_models.get_model( - model_id=model_config["id"], + model_id=model_config["model"], include_disabled=False, ).model_dump() except core_models.UnknownModelError: @@ -96,85 +102,39 @@ def get_litellm_client( # Merge configurations, skipping None values full_model_config = {**defined_model, **{k: v for k, v in model_config.items() if v is not None}} - print(f"*********** {full_model_config}") - + # Determine provider and model name - provider = "openai" if model_config["provider"] == "openai_compatible" else model_config["provider"] - model_name = f"{provider}/{model_config['id']}" + provider = "openai" if full_model_config["provider"] == "openai_compatible" else full_model_config["provider"] + model_name = f"{provider}/{full_model_config['id']}" # Get supported parameters and initialize config supported_params = get_supported_openai_params(model=model_name) - litellm_config = {k: full_model_config[k] for k in supported_params if k in full_model_config and full_model_config[k] is not None} - litellm_config["model"] = model_name - litellm_config["api_base"] = full_model_config["api_base"] - print(f"*********** {litellm_config}") - -def get_client(model_config: dict, oci_config: schema.OracleCloudSettings, giskard: bool = False) -> BaseChatModel: - """Retrieve model configuration""" - logger.debug("Model Client: %s; OCI Config: %s; Giskard: %s", model_config, oci_config, giskard) - try: - defined_model = core_models.get_model( - model_id=model_config["model"], - include_disabled=False, - ).model_dump() - except core_models.UnknownModelError: - return None - - full_model_config = {**defined_model, **{k: v for k, v in model_config.items() if v is not None}} - client = None - provider = full_model_config["provider"] - if full_model_config["type"] == "ll" and not giskard: - common_params = { - k: full_model_config.get(k) for k in ["frequency_penalty", "presence_penalty", "top_p", "streaming"] - } - if provider != "oci": - kwargs = { - "model_provider": "openai" if provider == "openai_compatible" else provider, - "model": full_model_config["id"], - "base_url": full_model_config["api_base"], - "temperature": full_model_config["temperature"], - "max_tokens": full_model_config["max_completion_tokens"], - **common_params, + litellm_config = { + k: full_model_config[k] + for k in supported_params + if k in full_model_config and full_model_config[k] is not None + } + if "cohere" in model_name: + # Ensure we use the OpenAI compatible endpoint + parsed = urlparse(full_model_config.get("api_base")) + scheme = parsed.scheme or "https" + netloc = "api.cohere.ai" + # Always force the path + path = "/compatibility/v1" + full_model_config["api_base"] = f"{scheme}://{netloc}{path}" + + litellm_config.update({"model": model_name, "api_base": full_model_config.get("api_base")}) + + if provider == "oci": + litellm_config.update( + { + "oci_user": oci_config.user, + "oci_fingerprint": oci_config.fingerprint, + "oci_tenancy": oci_config.tenancy, + "oci_region": oci_config.genai_region, + "oci_key_file": oci_config.key_file, + "oci_compartment_id": oci_config.genai_compartment_id, } + ) - if full_model_config.get("api_key"): # only add if present - kwargs["api_key"] = full_model_config["api_key"] - - client = init_chat_model(**kwargs) - else: - client = ChatOCIGenAI( - model_id=full_model_config["id"], - client=util_oci.init_genai_client(oci_config), - 
compartment_id=oci_config.genai_compartment_id, - model_kwargs={ - (k if k != "max_completion_tokens" else "max_tokens"): v - for k, v in common_params.items() - if k not in {"streaming"} - }, - ) - - if full_model_config["type"] == "embed" and not giskard: - if provider != "oci": - kwargs = { - "provider": "openai" if provider == "openai_compatible" else provider, - "model": full_model_config["id"], - "base_url": full_model_config["api_base"], - } - if full_model_config.get("api_key"): # only add if set - kwargs["api_key"] = full_model_config["api_key"] - client = init_embeddings(**kwargs) - else: - client = OCIGenAIEmbeddings( - model_id=full_model_config["id"], - client=util_oci.init_genai_client(oci_config), - compartment_id=oci_config.genai_compartment_id, - ) - - if giskard: - logger.debug("Creating Giskard Client") - giskard_key = full_model_config["api_key"] or "giskard" - _client = OpenAI(api_key=giskard_key, base_url=full_model_config["api_base"]) - client = OpenAIClient(model=full_model_config["id"], client=_client) - - logger.debug("Configured Client: %s", vars(client)) - return client + return litellm_config diff --git a/src/server/api/utils/oci.py b/src/server/api/utils/oci.py index 5d042b48..76f18b2d 100644 --- a/src/server/api/utils/oci.py +++ b/src/server/api/utils/oci.py @@ -9,7 +9,6 @@ import urllib3.exceptions import oci - from server.api.core.oci import OciException from common.schema import OracleCloudSettings @@ -66,8 +65,8 @@ def init_client( with open(config_json["security_token_file"], "r", encoding="utf-8") as f: token = f.read() private_key = oci.signer.load_private_key_from_file(config_json["key_file"]) - signer = oci.auth.signers.SecurityTokenSigner(token, private_key) - client = client_type(config={"region": config_json["region"]}, signer=signer, **client_kwargs) + sec_token_signer = oci.auth.signers.SecurityTokenSigner(token, private_key) + client = client_type(config={"region": config_json["region"]}, signer=sec_token_signer, **client_kwargs) else: logger.info("OCI Authentication as Standard") client = client_type(config_json, **client_kwargs) diff --git a/src/server/api/v1/chat.py b/src/server/api/v1/chat.py index 33b476a9..47bc5320 100644 --- a/src/server/api/v1/chat.py +++ b/src/server/api/v1/chat.py @@ -2,10 +2,11 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
""" -# spell-checker:ignore selectai +# spell-checker:ignore selectai litellm from fastapi import APIRouter, Header from fastapi.responses import StreamingResponse +from litellm import ModelResponse from langchain_core.messages import ( AnyMessage, @@ -31,11 +32,11 @@ @auth.post( "/completions", description="Submit a message for full completion.", - response_model=schema.ChatResponse, + response_model=ModelResponse, ) async def chat_post( request: schema.ChatRequest, client: schema.ClientIdType = Header(default="server") -) -> schema.ChatResponse: +) -> ModelResponse: """Full Completion Requests""" last_message = None async for chunk in chat.completion_generator(client, request, "completions"): diff --git a/src/server/bootstrap/models.py b/src/server/bootstrap/models.py index bd69f66f..5ed65389 100644 --- a/src/server/bootstrap/models.py +++ b/src/server/bootstrap/models.py @@ -43,8 +43,7 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "type": "ll", "provider": "cohere", "api_key": os.environ.get("COHERE_API_KEY", default=""), - "openai_compat": False, - "api_base": "https://api.cohere.ai", + "api_base": "https://api.cohere.ai/compatibility/v1", "context_length": 127072, "temperature": 0.3, "max_completion_tokens": 4096, @@ -56,7 +55,6 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "type": "ll", "provider": "openai", "api_key": os.environ.get("OPENAI_API_KEY", default=""), - "openai_compat": True, "api_base": "https://api.openai.com/v1", "context_length": 127072, "temperature": 1.0, @@ -69,7 +67,6 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "type": "ll", "provider": "perplexity", "api_key": os.environ.get("PPLX_API_KEY", default=""), - "openai_compat": True, "api_base": "https://api.perplexity.ai", "context_length": 127072, "temperature": 0.2, @@ -82,7 +79,6 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "type": "ll", "provider": "openai_compatible", "api_key": "", - "openai_compat": True, "api_base": "http://localhost:1234/v1", "context_length": 131072, "temperature": 1.0, @@ -95,7 +91,6 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "type": "ll", "provider": "ollama", "api_key": "", - "openai_compat": True, "api_base": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"), "context_length": 131072, "temperature": 1.0, @@ -109,7 +104,6 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "type": "ll", "provider": "ollama", "api_key": "", - "openai_compat": True, "api_base": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"), "context_length": 131072, "temperature": 1.0, @@ -123,7 +117,6 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "provider": "huggingface", "api_base": os.environ.get("ON_PREM_HF_URL", default="http://127.0.0.1:8080"), "api_key": "", - "openai_compat": True, "max_chunk_size": 512, }, { @@ -133,7 +126,6 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "provider": "openai_compatible", "api_base": "https://api.openai.com/v1", "api_key": os.environ.get("OPENAI_API_KEY", default=""), - "openai_compat": True, "max_chunk_size": 8191, }, { @@ -141,9 +133,8 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "enabled": os.getenv("COHERE_API_KEY") is not None, "type": "embed", "provider": "cohere", - "api_base": "https://api.cohere.ai", + 
"api_base": "https://api.cohere.ai/compatibility/v1", "api_key": os.environ.get("COHERE_API_KEY", default=""), - "openai_compat": False, "max_chunk_size": 512, }, { @@ -153,7 +144,6 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "provider": "openai_compatible", "api_base": "http://localhost:1234/v1", "api_key": "", - "openai_compat": True, "max_chunk_size": 8192, }, { @@ -164,7 +154,6 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "provider": "ollama", "api_base": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"), "api_key": "", - "openai_compat": True, "max_chunk_size": 8192, }, ] From d1358c6d7553e3c08cb0c3c392e03c515cb68abd Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 29 Aug 2025 11:19:28 +0100 Subject: [PATCH 07/31] history and streaming --- src/server/agents/chatbot.py | 438 +++++------------------------------ 1 file changed, 54 insertions(+), 384 deletions(-) diff --git a/src/server/agents/chatbot.py b/src/server/agents/chatbot.py index 5ffe5644..dc7c73b6 100644 --- a/src/server/agents/chatbot.py +++ b/src/server/agents/chatbot.py @@ -2,425 +2,95 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ -# spell-checker:ignore langgraph, oraclevs, checkpointer, ainvoke -# spell-checker:ignore vectorstore, vectorstores, oraclevs, mult, selectai - -from datetime import datetime, timezone -from typing import Literal -import json -import copy -import decimal - -from langchain_core.documents.base import Document -from langchain_core.messages import SystemMessage, ToolMessage -from langchain_core.output_parsers import StrOutputParser, PydanticOutputParser -from langchain_core.prompts import PromptTemplate -from langchain_core.runnables import RunnableConfig -from langchain_community.vectorstores.oraclevs import OracleVS +# spell-checker:ignore litellm checkpointer acompletion astream from langgraph.checkpoint.memory import MemorySaver -from langgraph.graph import MessagesState, StateGraph, START, END +from langgraph.config import get_stream_writer +from langgraph.graph import StateGraph, START, END, MessagesState +from langchain_core.messages import AIMessage -from pydantic import BaseModel, Field +from langchain_core.runnables import RunnableConfig +from litellm import acompletion -from server.api.core.databases import execute_sql -from common.schema import ChatResponse, ChatUsage, ChatChoices, ChatMessage from common import logging_config logger = logging_config.logging.getLogger("server.agents.chatbot") -############################################################################# -# AGENT STATE -############################################################################# -class AgentState(MessagesState): +class OptimizerState(MessagesState): """Establish our Agent State Machine""" - logger.info("Establishing Agent State") - final_response: ChatResponse # OpenAI Response - cleaned_messages: list # Messages w/o VS Results - context_input: str # Contextualized User Input - documents: dict # VectorStore documents + final_response: dict # OpenAI Response ############################################################################# # Functions ############################################################################# -def get_messages(state: AgentState, config: RunnableConfig) -> list: +def get_messages(state: OptimizerState, config: RunnableConfig) -> list: """Return a list of messages 
that will be passed to the model for completion Filter out old VS documents to avoid blowing-out the context window Leave the state as is for GUI functionality""" use_history = config["metadata"]["use_history"] - # If user decided for no history, only take the last message - state_messages = state["messages"] if use_history else state["messages"][-1:] - - messages = [] - for msg in state_messages: - if isinstance(msg, SystemMessage): - continue - if isinstance(msg, ToolMessage): - if messages: # Check if there are any messages in the list - messages.pop() # Remove the last appended message - continue - messages.append(msg) + state_messages = state.get("messages", []) + if state_messages: + # If user decided for no history, only take the last message + state_messages = state_messages if use_history else state_messages[-1:] - # insert the system prompt; remaining messages cleaned - if config["metadata"]["sys_prompt"].prompt: - messages.insert(0, SystemMessage(content=config["metadata"]["sys_prompt"].prompt)) + prompt_messages = [{"role": "user", "content": m.content} for m in state_messages] - return messages - - -def document_formatter(rag_context) -> str: - """Extract the Vector Search Documents and format into a string""" - logger.info("Extracting chunks from Vector Search Retrieval") - logger.debug("Vector Search Context: %s", rag_context) - chunks = "\n\n".join([doc["page_content"] for doc in rag_context]) - return chunks - - -class DecimalEncoder(json.JSONEncoder): - """Used with json.dumps to encode decimals""" - - def default(self, o): - if isinstance(o, decimal.Decimal): - return str(o) - return super().default(o) + return prompt_messages ############################################################################# # NODES and EDGES ############################################################################# -def respond(state: AgentState, config: RunnableConfig) -> ChatResponse: - """Respond in OpenAI Compatible return""" - ai_message = state["messages"][-1] - logger.debug("Formatting Response to OpenAI compatible message: %s", repr(ai_message)) - model_id = config["metadata"]["model_id"] - if "model_id" in ai_message.response_metadata: - ai_metadata = ai_message - else: - ai_metadata = state["messages"][1] - logger.debug("Using Metadata from: %s", repr(ai_metadata)) - - finish_reason = ai_metadata.response_metadata.get("finish_reason", "stop") - if finish_reason == "COMPLETE": - finish_reason = "stop" - elif finish_reason == "MAX_TOKENS": - finish_reason = "length" - - openai_response = ChatResponse( - id=ai_message.id, - created=int(datetime.now(timezone.utc).timestamp()), - model=model_id, - usage=ChatUsage( - prompt_tokens=ai_metadata.response_metadata.get("token_usage", {}).get("prompt_tokens", -1), - completion_tokens=ai_metadata.response_metadata.get("token_usage", {}).get("completion_tokens", -1), - total_tokens=ai_metadata.response_metadata.get("token_usage", {}).get("total_tokens", -1), - ), - choices=[ - ChatChoices( - index=0, - message=ChatMessage( - role="ai", - content=ai_message.content, - additional_kwargs=ai_metadata.additional_kwargs, - response_metadata=ai_metadata.response_metadata, - ), - finish_reason=finish_reason, - logprobs=None, - ) - ], - ) - return {"final_response": openai_response} - - -def vs_retrieve(state: AgentState, config: RunnableConfig) -> AgentState: - """Search and return information using Vector Search""" - ## Note that this should be a tool call; but some models (Perplexity/OCI GenAI) - ## have limited or no tools support. 
Instead we'll call as part of the pipeline - ## and fake a tools call. This can be later reverted to a tool without much code change. - logger.info("Perform Vector Search") - # Take our contextualization prompt and reword the question - # before doing the vector search; do only if history is turned on - history = copy.deepcopy(state["cleaned_messages"]) - retrieve_question = history.pop().content - if config["metadata"]["use_history"] and config["metadata"]["ctx_prompt"].prompt and len(history) > 1: - model = config["configurable"].get("ll_client", None) - ctx_template = """ - {ctx_prompt} - Here is the context and history: - ------- - {history} - ------- - Here is the user input: - ------- - {question} - ------- - Return ONLY the rephrased query without any explanation or additional text. - """ - rephrase = PromptTemplate( - template=ctx_template, - input_variables=["ctx_prompt", "history", "question"], - ) - chain = rephrase | model - logger.info("Retrieving Rephrased Input for VS") - result = chain.invoke( - { - "ctx_prompt": config["metadata"]["ctx_prompt"].prompt, - "history": history, - "question": retrieve_question, - } - ) - if result.content != retrieve_question: - logger.info("**** Replacing User Question: %s with contextual one: %s", retrieve_question, result.content) - retrieve_question = result.content - try: - logger.info("Connecting to VectorStore") - db_conn = config["configurable"]["db_conn"] - embed_client = config["configurable"]["embed_client"] - vector_search = config["metadata"]["vector_search"] - logger.info("Initializing Vector Store: %s", vector_search.vector_store) - try: - vectorstore = OracleVS(db_conn, embed_client, vector_search.vector_store, vector_search.distance_metric) - except Exception as ex: - logger.exception("Failed to initialize the Vector Store") - raise ex - - try: - search_type = vector_search.search_type - search_kwargs = {"k": vector_search.top_k} - - if search_type == "Similarity": - retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs=search_kwargs) - elif search_type == "Similarity Score Threshold": - search_kwargs["score_threshold"] = vector_search.score_threshold - retriever = vectorstore.as_retriever( - search_type="similarity_score_threshold", search_kwargs=search_kwargs - ) - elif search_type == "Maximal Marginal Relevance": - search_kwargs.update( - { - "fetch_k": vector_search.fetch_k, - "lambda_mult": vector_search.lambda_mult, - } - ) - retriever = vectorstore.as_retriever(search_type="mmr", search_kwargs=search_kwargs) - else: - raise ValueError(f"Unsupported search_type: {search_type}") - logger.info("Invoking retriever on: %s", retrieve_question) - documents = retriever.invoke(retrieve_question) - except Exception as ex: - logger.exception("Failed to perform Oracle Vector Store retrieval") - raise ex - except (AttributeError, KeyError, TypeError) as ex: - documents = Document( - id="DocumentException", page_content="I'm sorry, I think you found a bug!", metadata={"source": f"{ex}"} - ) - documents_dict = [vars(doc) for doc in documents] - logger.info("Found Documents: %i", len(documents_dict)) - return {"context_input": retrieve_question, "documents": documents_dict} - - -def grade_documents(state: AgentState, config: RunnableConfig) -> Literal["generate_response", "vs_generate"]: - """Determines whether the retrieved documents are relevant to the question.""" - logger.info("Grading Vector Search Response using %i retrieved documents", len(state["documents"])) - - # Data model - class Grade(BaseModel): - 
"""Binary score for relevance check.""" - - binary_score: str = Field(description="Relevance score 'yes' or 'no'") - - if config["metadata"]["vector_search"].grading: - # LLM (Bound to Tool) - model = config["configurable"].get("ll_client", None) - try: - llm_with_grader = model.with_structured_output(Grade) - except NotImplementedError: - logger.error("Model does not support structured output") - parser = PydanticOutputParser(pydantic_object=Grade) - llm_with_grader = model | parser - - # Prompt - grade_template = """ - You are a Grader assessing the relevance of retrieved text to the user's input. - You MUST respond with a only a binary score of 'yes' or 'no'. - If you DO find ANY relevant retrieved text to the user's input, return 'yes' immediately and stop grading. - If you DO NOT find relevant retrieved text to the user's input, return 'no'. - Here is the user input: - ------- - {question} - ------- - Here is the retrieved text: - ------- - {context} - """ - grader = PromptTemplate( - template=grade_template, - input_variables=["context", "question"], - ) - documents = document_formatter(state["documents"]) - question = state["context_input"] - logger.debug("Grading %s against Documents: %s", question, documents) - chain = grader | llm_with_grader - try: - scored_result = chain.invoke({"question": question, "context": documents}) - logger.info("Grading completed.") - score = scored_result.binary_score - except Exception: - logger.error("LLM is not returning binary score in grader; marking all results relevant.") - score = "yes" - else: - logger.info("Vector Search Grading disabled; marking all results relevant.") - score = "yes" - - logger.info("Grading Decision: Vector Search Relevant: %s", score) - if score == "yes": - # This is where we fake a tools response before the completion. 
- logger.debug("Creating ToolsMessage Documents: %s", state["documents"]) - logger.debug("Creating ToolsMessage ContextQ: %s", state["context_input"]) +async def stream_completion(state: OptimizerState, config: RunnableConfig | None = None): + """LiteLLM streaming wrapper""" + writer = get_stream_writer() + full_response = [] + collected_content = [] - state["messages"].append( - ToolMessage( - content=json.dumps([state["documents"], state["context_input"]], cls=DecimalEncoder), - name="oraclevs_tool", - tool_call_id="tool_placeholder", - ) - ) - logger.debug("ToolsMessage Created") - return "vs_generate" - else: - return "generate_response" - - -async def vs_generate(state: AgentState, config: RunnableConfig) -> None: - """Generate answer when Vector Search enabled; modify state with response""" - logger.info("Generating Vector Search Response") - - # Generate prompt with Vector Search context - generate_template = "SystemMessage(content='{sys_prompt}\n {context}'), HumanMessage(content='{question}')" - prompt_template = PromptTemplate( - template=generate_template, - input_variables=["sys_prompt", "context", "question"], - ) - - # Chain and Run - llm = config["configurable"].get("ll_client", None) - generate_chain = prompt_template | llm | StrOutputParser() - documents = document_formatter(state["documents"]) - logger.debug("Completing: '%s' against relevant VectorStore documents", state["context_input"]) - chain = { - "sys_prompt": config["metadata"]["sys_prompt"].prompt, - "question": state["context_input"], - "context": documents, - } - - response = await generate_chain.ainvoke(chain) - return {"messages": ("assistant", response)} - - -async def selectai_generate(state: AgentState, config: RunnableConfig) -> None: - """Generate answer when SelectAI enabled; modify state with response""" - history = copy.deepcopy(state["cleaned_messages"]) - selectai_prompt = history.pop().content - - logger.info("Generating SelectAI Response on %s", selectai_prompt) - sql = """ - SELECT DBMS_CLOUD_AI.GENERATE( - prompt => :query, - profile_name => :profile, - action => :action) - FROM dual - """ - binds = { - "query": selectai_prompt, - "profile": config["metadata"]["selectai"].profile, - "action": config["metadata"]["selectai"].action, - } - # Execute the SQL using the connection - db_conn = config["configurable"]["db_conn"] - try: - completion = execute_sql(db_conn, sql, binds) - except Exception as ex: - logger.error("SelectAI has hit an issue: %s", ex) - completion = [{sql: "I'm sorry, I have no information related to your query."}] - # Response will be [{sql:, completion}]; return the completion - logger.debug("SelectAI Responded: %s", completion) - response = list(completion[0].values())[0] - - return {"messages": ("assistant", response)} - - -async def agent(state: AgentState, config: RunnableConfig) -> AgentState: - """Invokes the chatbot with messages to be used""" - logger.debug("Initializing Agent") - messages = get_messages(state, config) - return {"cleaned_messages": messages} - - -def use_tool(_, config: RunnableConfig) -> Literal["selectai_generate", "vs_retrieve", "generate_response"]: - """Conditional edge to determine if using SelectAI, Vector Search or not""" - selectai_enabled = config["metadata"]["selectai"].enabled - if selectai_enabled: - logger.info("Invoking Chatbot with SelectAI: %s", selectai_enabled) - return "selectai_generate" - - enabled = config["metadata"]["vector_search"].enabled - if enabled: - logger.info("Invoking Chatbot with Vector Search: %s", enabled) - return 
"vs_retrieve" - - return "generate_response" - - -async def generate_response(state: AgentState, config: RunnableConfig) -> AgentState: - """Invoke the model""" - model = config["configurable"].get("ll_client", None) - logger.debug("Invoking on: %s", state["cleaned_messages"]) try: - response = await model.ainvoke(state["cleaned_messages"]) + # Await the asynchronous completion with streaming enabled + logger.info("Streaming completion...") + prompt_messages = get_messages(state, config) + + # ll_raw holds either a dict(litellm) or an object(client) + ll_raw = config["configurable"].get("ll_config", {}) + response = await acompletion(messages=prompt_messages, stream=True, **ll_raw) + async for chunk in response: + content = chunk.choices[0].delta.content + if content is not None: + writer({"stream": content}) + collected_content.append(content) + full_response.append(chunk) + + # After loop: update last chunk to a full completion with usage details + if full_response: + last_chunk = full_response[-1] + full_text = "".join(collected_content) + last_chunk.object = "chat.completion" + last_chunk.choices[0].message = {"role": "assistant", "content": full_text} + delattr(last_chunk.choices[0], "delta") + last_chunk.choices[0].finish_reason = "stop" + final_response = last_chunk.model_dump() + + writer({"completion": final_response}) except Exception as ex: - if hasattr(ex, "message"): - response = ("assistant", f"I'm sorry: {ex.message}") - else: - raise - return {"messages": [response]} + logger.error(ex) + full_text = f"I'm sorry, a completion problem occurred: {str(ex).split('Traceback', 1)[0]}" + return {"messages": [AIMessage(content=full_text)]} -############################################################################# -# GRAPH -############################################################################# -workflow = StateGraph(AgentState) - -# Define the nodes -workflow.add_node("agent", agent) -workflow.add_node("vs_retrieve", vs_retrieve) -workflow.add_node("vs_generate", vs_generate) -workflow.add_node("selectai_generate", selectai_generate) -workflow.add_node("generate_response", generate_response) -workflow.add_node("respond", respond) - -# Start the agent with clean messages -workflow.add_edge(START, "agent") -# Branch to either "selectai_generate", "vs_retrieve", or "generate_response" -workflow.add_conditional_edges("agent", use_tool) -workflow.add_edge("generate_response", "respond") +# Build the state graph +workflow = StateGraph(OptimizerState) +workflow.add_node("stream_completion", stream_completion) -# If selectAI -workflow.add_edge("selectai_generate", "respond") +workflow.add_edge(START, "stream_completion") +workflow.add_edge("stream_completion", END) -# If retrieving, grade the documents returned and either generate (not relevant) or vs_generate (relevant) -workflow.add_conditional_edges("vs_retrieve", grade_documents) -workflow.add_edge("vs_generate", "respond") - -# Finish with OpenAI Compatible Response -workflow.add_edge("respond", END) - -# Compile +# Compile the graph memory = MemorySaver() chatbot_graph = workflow.compile(checkpointer=memory) - -## This will output the Graph in ascii; don't deliver uncommented -# chatbot_graph.get_graph(xray=True).print_ascii() From eba62b0c814831529c86593c3a189f03fb4c102e Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 29 Aug 2025 11:19:47 +0100 Subject: [PATCH 08/31] litellm_config --- src/server/api/utils/models.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git 
a/src/server/api/utils/models.py b/src/server/api/utils/models.py index a6e40909..1039b5ea 100644 --- a/src/server/api/utils/models.py +++ b/src/server/api/utils/models.py @@ -7,16 +7,6 @@ from urllib.parse import urlparse from litellm import get_supported_openai_params -from openai import OpenAI - -from langchain_core.language_models.chat_models import BaseChatModel -from langchain.chat_models import init_chat_model -from langchain.embeddings import init_embeddings - -from langchain_community.chat_models.oci_generative_ai import ChatOCIGenAI -from langchain_community.embeddings.oci_generative_ai import OCIGenAIEmbeddings - -from giskard.llm.client.openai import OpenAIClient import server.api.utils.oci as util_oci import server.api.core.models as core_models @@ -86,7 +76,7 @@ def create_genai_models(config: schema.OracleCloudSettings) -> list[schema.Model return genai_models -def get_litellm_client( +def get_litellm_config( model_config: dict, oci_config: schema.OracleCloudSettings = None, giskard: bool = False ) -> dict: """Establish client""" From 62166f16335099340247bb198967b0fd40ae1d58 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Fri, 29 Aug 2025 11:43:04 +0100 Subject: [PATCH 09/31] Prompts --- src/server/agents/chatbot.py | 26 +++-- src/server/api/utils/chat.py | 182 ++++++++++++++++++----------------- 2 files changed, 113 insertions(+), 95 deletions(-) diff --git a/src/server/agents/chatbot.py b/src/server/agents/chatbot.py index dc7c73b6..b3615855 100644 --- a/src/server/agents/chatbot.py +++ b/src/server/agents/chatbot.py @@ -7,7 +7,7 @@ from langgraph.checkpoint.memory import MemorySaver from langgraph.config import get_stream_writer from langgraph.graph import StateGraph, START, END, MessagesState -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, SystemMessage from langchain_core.runnables import RunnableConfig from litellm import acompletion @@ -20,6 +20,7 @@ class OptimizerState(MessagesState): """Establish our Agent State Machine""" + cleaned_messages: list # Messages w/o VS Results final_response: dict # OpenAI Response @@ -37,6 +38,12 @@ def get_messages(state: OptimizerState, config: RunnableConfig) -> list: # If user decided for no history, only take the last message state_messages = state_messages if use_history else state_messages[-1:] + # Remove SystemMessage (prompts) + state_messages = [msg for msg in state_messages if not isinstance(msg, SystemMessage)] + + # Add our new prompt + state_messages.insert(0, SystemMessage(content=config["metadata"]["sys_prompt"].prompt)) + prompt_messages = [{"role": "user", "content": m.content} for m in state_messages] return prompt_messages @@ -45,6 +52,13 @@ def get_messages(state: OptimizerState, config: RunnableConfig) -> list: ############################################################################# # NODES and EDGES ############################################################################# +async def initialise(state: OptimizerState, config: RunnableConfig) -> OptimizerState: + """Initialise our chatbot""" + logger.debug("Initializing Chatbot") + messages = get_messages(state, config) + return {"cleaned_messages": messages} + + async def stream_completion(state: OptimizerState, config: RunnableConfig | None = None): """LiteLLM streaming wrapper""" writer = get_stream_writer() @@ -54,11 +68,8 @@ async def stream_completion(state: OptimizerState, config: RunnableConfig | None try: # Await the asynchronous completion with streaming enabled logger.info("Streaming completion...") 
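The streaming body below follows LiteLLM's usual async pattern: await acompletion(..., stream=True), then iterate the chunks. A minimal self-contained sketch (the model name is an assumption):

    import asyncio
    from litellm import acompletion

    async def demo() -> str:
        collected = []
        response = await acompletion(
            model="openai/gpt-4o-mini",  # assumed model for illustration
            messages=[{"role": "user", "content": "Say hello"}],
            stream=True,
        )
        async for chunk in response:  # each chunk carries a delta, not a full message
            content = chunk.choices[0].delta.content
            if content is not None:
                collected.append(content)
        return "".join(collected)

    print(asyncio.run(demo()))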
- prompt_messages = get_messages(state, config) - - # ll_raw holds either a dict(litellm) or an object(client) ll_raw = config["configurable"].get("ll_config", {}) - response = await acompletion(messages=prompt_messages, stream=True, **ll_raw) + response = await acompletion(messages=state["cleaned_messages"], stream=True, **ll_raw) async for chunk in response: content = chunk.choices[0].delta.content if content is not None: @@ -86,9 +97,12 @@ async def stream_completion(state: OptimizerState, config: RunnableConfig | None # Build the state graph workflow = StateGraph(OptimizerState) +workflow.add_node("initialise", initialise) workflow.add_node("stream_completion", stream_completion) -workflow.add_edge(START, "stream_completion") +# Start the chatbot with clean messages +workflow.add_edge(START, "initialise") +workflow.add_edge("initialise", "stream_completion") workflow.add_edge("stream_completion", END) # Compile the graph diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py index 4a747aec..e9a23194 100644 --- a/src/server/api/utils/chat.py +++ b/src/server/api/utils/chat.py @@ -9,15 +9,13 @@ from langchain_core.messages import HumanMessage from langchain_core.runnables import RunnableConfig -from langgraph.graph.state import CompiledStateGraph - import server.api.core.settings as core_settings import server.api.core.oci as core_oci import server.api.core.prompts as core_prompts import server.api.utils.models as util_models import server.api.utils.databases as util_databases -import server.agents.chatbot as chatbot +from server.agents.chatbot import chatbot_graph import server.api.utils.selectai as util_selectai import common.schema as schema @@ -42,96 +40,102 @@ async def completion_generator( oci_config = core_oci.get_oci(client=client) # Setup Client Model - ll_client = util_models.get_client(model, oci_config) - if not ll_client: - error_response = { - "id": "error", - "choices": [ - { - "message": { - "role": "assistant", - "content": "I'm unable to initialise the Language Model. 
Please refresh the application.", - }, - "index": 0, - "finish_reason": "stop", - } - ], - "created": int(time.time()), - "model": model.get("model", "unknown"), - "object": "chat.completion", - } - yield error_response - return - - # Get Prompts - try: - user_sys_prompt = getattr(client_settings.prompts, "sys", "Basic Example") - sys_prompt = core_prompts.get_prompts(category="sys", name=user_sys_prompt) - except AttributeError as ex: - # schema.Settings not on server-side - logger.error("A settings exception occurred: %s", ex) - raise - - db_conn = None - # Setup selectai - if client_settings.selectai.enabled: - db_conn = util_databases.get_client_db(client).connection - util_selectai.set_profile(db_conn, client_settings.selectai.profile, "temperature", model["temperature"]) - util_selectai.set_profile( - db_conn, client_settings.selectai.profile, "max_tokens", model["max_completion_tokens"] - ) - - # Setup vector_search - embed_client, ctx_prompt = None, None - if client_settings.vector_search.enabled: - db_conn = util_databases.get_client_db(client).connection - embed_client = util_models.get_client(client_settings.vector_search.model_dump(), oci_config) - - user_ctx_prompt = getattr(client_settings.prompts, "ctx", "Basic Example") - ctx_prompt = core_prompts.get_prompts(category="ctx", name=user_ctx_prompt) + ll_config = util_models.get_litellm_config(model, oci_config) + # Start to establish our LangGraph Args kwargs = { + "stream_mode": "custom", "input": {"messages": [HumanMessage(content=request.messages[0].content)]}, "config": RunnableConfig( - configurable={ - "thread_id": client, - "ll_client": ll_client, - "embed_client": embed_client, - "db_conn": db_conn, - }, - metadata={ - "model_id": model["model"], - "use_history": client_settings.ll_model.chat_history, - "vector_search": client_settings.vector_search, - "selectai": client_settings.selectai, - "sys_prompt": sys_prompt, - "ctx_prompt": ctx_prompt, - }, + configurable={"thread_id": client, "ll_config": ll_config}, + metadata={"use_history": client_settings.ll_model.chat_history}, ), } + + # Get System Prompt + user_sys_prompt = getattr(client_settings.prompts, "sys", "Basic Example") + kwargs["config"]["metadata"]["sys_prompt"] = core_prompts.get_prompts(category="sys", name=user_sys_prompt) + + # db_conn = None + # # Setup selectai + # if client_settings.selectai.enabled: + # db_conn = util_databases.get_client_db(client).connection + # util_selectai.set_profile(db_conn, client_settings.selectai.profile, "temperature", model["temperature"]) + # util_selectai.set_profile( + # db_conn, client_settings.selectai.profile, "max_tokens", model["max_completion_tokens"] + # ) + + # # Setup vector_search + # embed_client, ctx_prompt = None, None + # if client_settings.vector_search.enabled: + # db_conn = util_databases.get_client_db(client).connection + # embed_client = util_models.get_client(client_settings.vector_search.model_dump(), oci_config) + + # user_ctx_prompt = getattr(client_settings.prompts, "ctx", "Basic Example") + # ctx_prompt = core_prompts.get_prompts(category="ctx", name=user_ctx_prompt) + + # kwargs = { + # "stream_mode": "custom", + # "input": {"messages": [HumanMessage(content=request.messages[0].content)]}, + # "config": RunnableConfig( + # configurable={ + # "thread_id": client, + # "ll_config": ll_client, + # "embed_client": embed_client, + # "db_conn": db_conn, + # }, + # metadata={ + # "model_id": model["model"], + # "use_history": client_settings.ll_model.chat_history, + # "vector_search": 
client_settings.vector_search, + # "selectai": client_settings.selectai, + # "sys_prompt": sys_prompt, + # "ctx_prompt": ctx_prompt, + # }, + # ), + # } logger.debug("Completion Kwargs: %s", kwargs) - agent: CompiledStateGraph = chatbot.chatbot_graph - try: - async for chunk in agent.astream_events(**kwargs, version="v2"): - # The below will produce A LOT of output; uncomment when desperate - # logger.debug("Streamed Chunk: %s", chunk) - if chunk["event"] == "on_chat_model_stream": - if "tools_condition" in str(chunk["metadata"]["langgraph_triggers"]): - continue # Skip Tool Call messages - if "vs_retrieve" in str(chunk["metadata"]["langgraph_node"]): - continue # Skip Fake-Tool Call messages - content = chunk["data"]["chunk"].content - if content != "" and call == "streams": - yield content.encode("utf-8") - last_response = chunk["data"] - if call == "streams": - yield "[stream_finished]" # This will break the Chatbot loop - elif call == "completions": - final_response = last_response["output"]["final_response"] - yield final_response # This will be captured for ChatResponse - except Exception as ex: - logger.error("An invoke exception occurred: %s", ex) - # yield f"I'm sorry; {ex}" - # TODO(gotsysdba) - If a message is returned; - # format and return (this should be done in the agent) - raise + final_response = None + async for output in chatbot_graph.astream(**kwargs): + if "stream" in output: + yield output["stream"].encode("utf-8") + if "completion" in output: + final_response = output["completion"] + if call == "streams": + yield "[stream_finished]" # This will break the Chatbot loop + if call == "completions" and final_response is not None: + yield final_response # This will be captured for ChatResponse + + # print(f"********** output: {output["stream"][]}") + # for chunk in output["llm"]["messages"]: + # print(f"********** chunk: {chunk}") + # print(f"********** yield: {chunk.content}") + # yield chunk.content.encode("utf-8") + + # result = await graph.ainvoke(**kwargs) + # print("\n\nFinal result:", result) + + # try: + # async for chunk in agent.astream_events(**kwargs, version="v2", stream=True): + # # The below will produce A LOT of output; uncomment when desperate + # # logger.debug("Streamed Chunk: %s", chunk) + # if chunk["event"] == "on_chat_model_stream": + # if "tools_condition" in str(chunk["metadata"]["langgraph_triggers"]): + # continue # Skip Tool Call messages + # if "vs_retrieve" in str(chunk["metadata"]["langgraph_node"]): + # continue # Skip Fake-Tool Call messages + # content = chunk["data"]["chunk"].content + # if content != "" and call == "streams": + # yield content.encode("utf-8") + # last_response = chunk["data"] + # if call == "streams": + # yield "[stream_finished]" # This will break the Chatbot loop + # elif call == "completions": + # final_response = last_response["output"]["final_response"] + # yield final_response # This will be captured for ChatResponse + # except Exception as ex: + # logger.error("An invoke exception occurred: %s", ex) + # # yield f"I'm sorry; {ex}" + # # TODO(gotsysdba) - If a message is returned; + # # format and return (this should be done in the agent) + # raise From 41498a8dbbd29dd8c5959ce38214bd93465d5505 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Sun, 31 Aug 2025 07:35:29 +0100 Subject: [PATCH 10/31] All but SelectAI --- src/server/agents/chatbot.py | 263 +++++++++++++++++++++++++++--- src/server/api/utils/chat.py | 16 +- src/server/api/utils/databases.py | 4 +- src/server/api/utils/models.py | 45 ++++- 4 files changed, 
299 insertions(+), 29 deletions(-) diff --git a/src/server/agents/chatbot.py b/src/server/agents/chatbot.py index b3615855..4773678d 100644 --- a/src/server/agents/chatbot.py +++ b/src/server/agents/chatbot.py @@ -2,51 +2,139 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ -# spell-checker:ignore litellm checkpointer acompletion astream +# spell-checker:ignore acompletion checkpointer litellm mult oraclevs vectorstores selectai + +import copy +import decimal +import json +from typing import Literal from langgraph.checkpoint.memory import MemorySaver from langgraph.config import get_stream_writer from langgraph.graph import StateGraph, START, END, MessagesState -from langchain_core.messages import AIMessage, SystemMessage +from langchain_core.documents.base import Document +from langchain_core.messages import AIMessage, SystemMessage, ToolMessage +from langchain_core.messages.utils import convert_to_openai_messages +from langchain_core.prompts import PromptTemplate from langchain_core.runnables import RunnableConfig -from litellm import acompletion + +from langchain_community.vectorstores.oraclevs import OracleVS + +from litellm import acompletion, completion +from litellm.exceptions import APIConnectionError from common import logging_config logger = logging_config.logging.getLogger("server.agents.chatbot") +class DecimalEncoder(json.JSONEncoder): + """Used with json.dumps to encode decimals""" + + def default(self, o): + if isinstance(o, decimal.Decimal): + return str(o) + return super().default(o) + + class OptimizerState(MessagesState): """Establish our Agent State Machine""" cleaned_messages: list # Messages w/o VS Results + context_input: str # Contextualized User Input (for VS) + documents: dict # VectorStore documents final_response: dict # OpenAI Response ############################################################################# # Functions ############################################################################# -def get_messages(state: OptimizerState, config: RunnableConfig) -> list: +def clean_messages(state: OptimizerState, config: RunnableConfig) -> list: """Return a list of messages that will be passed to the model for completion Filter out old VS documents to avoid blowing-out the context window - Leave the state as is for GUI functionality""" + Leave the state as is (deepcopy) for GUI functionality""" + use_history = config["metadata"]["use_history"] - state_messages = state.get("messages", []) + state_messages = copy.deepcopy(state.get("messages", [])) if state_messages: # If user decided for no history, only take the last message state_messages = state_messages if use_history else state_messages[-1:] - # Remove SystemMessage (prompts) - state_messages = [msg for msg in state_messages if not isinstance(msg, SystemMessage)] + # Remove System Prompt from top + if isinstance(state_messages[0], SystemMessage): + state_messages.pop(0) + + # Remove ToolCalls + state_messages = [msg for msg in state_messages if not isinstance(msg, ToolMessage)] + + return state_messages + + +def use_tool(_, config: RunnableConfig) -> Literal["vs_retrieve", "stream_completion"]: + """Conditional edge to determine if using SelectAI, Vector Search or not""" + # selectai_enabled = config["metadata"]["selectai"].enabled + # if selectai_enabled: + # logger.info("Invoking Chatbot with SelectAI: %s", selectai_enabled) + # return "selectai" + + enabled = 
config["metadata"]["vector_search"].enabled + if enabled: + logger.info("Invoking Chatbot with Vector Search: %s", enabled) + return "vs_retrieve" + + return "stream_completion" + + +def rephrase(state: OptimizerState, config: RunnableConfig) -> str: + """Take our contextualization prompt and reword the last user prompt""" + ctx_prompt = config.get("metadata", {}).get("ctx_prompt") + retrieve_question = state["messages"][-1].content + + if config["metadata"]["use_history"] and ctx_prompt and len(state["messages"]) > 2: + ctx_template = """ + {prompt} + Here is the context and history: + ------- + {history} + ------- + Here is the user input: + ------- + {question} + ------- + Return ONLY the rephrased query without any explanation or additional text. + """ + rephrase_template = PromptTemplate( + template=ctx_template, + input_variables=["ctx_prompt", "history", "question"], + ) + formatted_prompt = rephrase_template.format( + prompt=ctx_prompt.prompt, history=state["messages"], question=retrieve_question + ) + ll_raw = config["configurable"]["ll_config"] + try: + response = completion(messages=[{"role": "system", "content": formatted_prompt}], stream=False, **ll_raw) + print(f"************ {response}") + + context_question = response.choices[0].message.content + except APIConnectionError as ex: + logger.error("Failed to rephrase: %s", str(ex)) - # Add our new prompt - state_messages.insert(0, SystemMessage(content=config["metadata"]["sys_prompt"].prompt)) + if context_question != retrieve_question: + logger.info( + "**** Replacing User Question: %s with contextual one: %s", retrieve_question, context_question + ) + retrieve_question = context_question - prompt_messages = [{"role": "user", "content": m.content} for m in state_messages] + return retrieve_question - return prompt_messages + +def document_formatter(rag_context) -> str: + """Extract the Vector Search Documents and format into a string""" + logger.info("Extracting chunks from Vector Search Retrieval") + chunks = "\n\n".join([doc["page_content"] for doc in rag_context]) + return chunks ############################################################################# @@ -55,21 +143,149 @@ def get_messages(state: OptimizerState, config: RunnableConfig) -> list: async def initialise(state: OptimizerState, config: RunnableConfig) -> OptimizerState: """Initialise our chatbot""" logger.debug("Initializing Chatbot") - messages = get_messages(state, config) - return {"cleaned_messages": messages} + cleaned_messages = clean_messages(state, config) + return {"cleaned_messages": cleaned_messages} + + +async def vs_grade(state: OptimizerState, config: RunnableConfig) -> OptimizerState: + """Determines whether the retrieved documents are relevant to the question.""" + logger.info("Grading Vector Search Response using %i retrieved documents", len(state["documents"])) + # Initialise documents as relevant + relevant = "yes" + if config["metadata"]["vector_search"].grading and state.get("documents"): + grade_template = """ + You are a Grader assessing the relevance of retrieved text to the user's input. + You MUST respond with a only a binary score of 'yes' or 'no'. + If you DO find ANY relevant retrieved text to the user's input, return 'yes' immediately and stop grading. + If you DO NOT find relevant retrieved text to the user's input, return 'no'. 
+async def vs_grade(state: OptimizerState, config: RunnableConfig) -> OptimizerState:
+    """Determines whether the retrieved documents are relevant to the question."""
+    logger.info("Grading Vector Search Response using %i retrieved documents", len(state["documents"]))
+    # Initialise as relevant; pre-format documents once for grading and the final prompt
+    relevant = "yes"
+    documents_dict = document_formatter(state.get("documents", []))
+    if config["metadata"]["vector_search"].grading and state.get("documents"):
+        grade_template = """
+        You are a Grader assessing the relevance of retrieved text to the user's input.
+        You MUST respond with only a binary score of 'yes' or 'no'.
+        If you DO find ANY relevant retrieved text to the user's input, return 'yes' immediately and stop grading.
+        If you DO NOT find relevant retrieved text to the user's input, return 'no'.
+        Here is the user input:
+        -------
+        {question}
+        -------
+        Here is the retrieved text:
+        -------
+        {documents}
+        """
+        grade_template = PromptTemplate(
+            template=grade_template,
+            input_variables=["question", "documents"],
+        )
+        question = state["context_input"]
+        formatted_prompt = grade_template.format(question=question, documents=documents_dict)
+        logger.debug("Grading Prompt: %s", formatted_prompt)
+        ll_raw = config["configurable"]["ll_config"]
+
+        # Grade
+        try:
+            response = await acompletion(
+                messages=[{"role": "system", "content": formatted_prompt}], stream=False, **ll_raw
+            )
+            relevant = response["choices"][0]["message"]["content"].strip().lower()
+            logger.info("Grading completed. Relevant: %s", relevant)
+            if relevant not in ("yes", "no"):
+                logger.error("LLM did not return a binary score in grader; assuming all results relevant.")
+                relevant = "yes"
+        except APIConnectionError as ex:
+            logger.error("Failed to grade; marking all results relevant: %s", str(ex))
+    else:
+        logger.info("Vector Search Grading disabled; assuming all results relevant.")
+
+    if relevant == "yes":
+        # This is where we fake a tools response before the completion.
+        logger.debug("Creating ToolMessage Documents: %s", state["documents"])
+        logger.debug("Creating ToolMessage ContextQ: %s", state["context_input"])
+
+        state["messages"].append(
+            ToolMessage(
+                content=json.dumps([state["documents"], state["context_input"]], cls=DecimalEncoder),
+                name="oraclevs_tool",
+                tool_call_id="tool_placeholder",
+            )
+        )
+        logger.debug("ToolMessage Created")
+        return {"documents": documents_dict}
+    else:
+        return {"documents": dict()}
+
+
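Condensed, the grading step above boils down to one non-streaming completion that must answer 'yes' or 'no'; roughly (assuming a LiteLLM-style config dict that includes the model name):

    from litellm import completion

    def is_relevant(question: str, documents: str, ll_config: dict) -> bool:
        prompt = (
            "You are a Grader. Respond ONLY with 'yes' or 'no'.\n"
            f"User input:\n{question}\n\nRetrieved text:\n{documents}"
        )
        response = completion(messages=[{"role": "system", "content": prompt}], stream=False, **ll_config)
        return response["choices"][0]["message"]["content"].strip().lower() == "yes"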
+async def vs_retrieve(state: OptimizerState, config: RunnableConfig) -> OptimizerState:
+    """Search and return information using Vector Search"""
+    ## Note that this should be a tool call; but some models (Perplexity/OCI GenAI)
+    ## have limited or no tools support. Instead we'll call as part of the pipeline
+    ## and fake a tools call. This can be later reverted to a tool without much code change.
+    retrieve_question = rephrase(state, config)
+    logger.info("Perform Vector Search with: %s", retrieve_question)
+
+    try:
+        logger.info("Connecting to VectorStore")
+        db_conn = config["configurable"]["db_conn"]
+        embed_client = config["configurable"]["embed_client"]
+        vector_search = config["metadata"]["vector_search"]
+        logger.info("Initializing Vector Store: %s", vector_search.vector_store)
+        try:
+            vectorstores = OracleVS(db_conn, embed_client, vector_search.vector_store, vector_search.distance_metric)
+        except Exception as ex:
+            logger.exception("Failed to initialize the Vector Store")
+            raise ex
+
+        try:
+            search_type = vector_search.search_type
+            search_kwargs = {"k": vector_search.top_k}
+
+            if search_type == "Similarity":
+                retriever = vectorstores.as_retriever(search_type="similarity", search_kwargs=search_kwargs)
+            elif search_type == "Similarity Score Threshold":
+                search_kwargs["score_threshold"] = vector_search.score_threshold
+                retriever = vectorstores.as_retriever(
+                    search_type="similarity_score_threshold", search_kwargs=search_kwargs
+                )
+            elif search_type == "Maximal Marginal Relevance":
+                search_kwargs.update(
+                    {
+                        "fetch_k": vector_search.fetch_k,
+                        "lambda_mult": vector_search.lambda_mult,
+                    }
+                )
+                retriever = vectorstores.as_retriever(search_type="mmr", search_kwargs=search_kwargs)
+            else:
+                raise ValueError(f"Unsupported search_type: {search_type}")
+            logger.info("Invoking retriever on: %s", retrieve_question)
+            documents = retriever.invoke(retrieve_question)
+        except Exception as ex:
+            logger.exception("Failed to perform Oracle Vector Store retrieval")
+            raise ex
+    except (AttributeError, KeyError, TypeError) as ex:
+        documents = [
+            Document(
+                id="DocumentException", page_content="I'm sorry, I think you found a bug!", metadata={"source": f"{ex}"}
+            )
+        ]
+    documents_dict = [vars(doc) for doc in documents]
+    logger.info("Found Documents: %i", len(documents_dict))
+    return {"context_input": retrieve_question, "documents": documents_dict}
+
+
-async def stream_completion(state: OptimizerState, config: RunnableConfig | None = None):
+async def stream_completion(state: OptimizerState, config: RunnableConfig) -> OptimizerState:
     """LiteLLM streaming wrapper"""
     writer = get_stream_writer()
     full_response = []
     collected_content = []
+    messages = state["cleaned_messages"]
 
     try:
+        # Get our Prompt
+        sys_prompt = config.get("metadata", {}).get("sys_prompt")
+        if state.get("context_input") and state.get("documents"):
+            documents = state["documents"]
+            new_prompt = SystemMessage(content=f"{sys_prompt.prompt}\n {documents}")
+        else:
+            new_prompt = SystemMessage(content=f"{sys_prompt.prompt}")
+
+        # Insert Prompt into cleaned_messages
+        messages.insert(0, new_prompt)
         # Await the asynchronous completion with streaming enabled
         logger.info("Streaming completion...")
-        ll_raw = config["configurable"].get("ll_config", {})
-        response = await acompletion(messages=state["cleaned_messages"], stream=True, **ll_raw)
+        ll_raw = config["configurable"]["ll_config"]
+        response = await acompletion(messages=convert_to_openai_messages(messages), stream=True, **ll_raw)
         async for chunk in response:
             content = chunk.choices[0].delta.content
@@ -91,18 +307,27 @@ async def stream_completion(state: OptimizerState, config: RunnableConfig | None
     except Exception as ex:
         logger.error(ex)
         full_text = f"I'm sorry, a completion problem occurred: {str(ex).split('Traceback', 1)[0]}"
-        return {"messages": [AIMessage(content=full_text)]}
 
 
 # Build the state graph
 workflow = StateGraph(OptimizerState)
workflow.add_node("initialise", initialise) +workflow.add_node("rephrase", rephrase) +workflow.add_node("vs_retrieve", vs_retrieve) +workflow.add_node("vs_grade", vs_grade) workflow.add_node("stream_completion", stream_completion) # Start the chatbot with clean messages workflow.add_edge(START, "initialise") -workflow.add_edge("initialise", "stream_completion") + +# Branch to either "selectai", "vs_retrieve", or "generate_response" +workflow.add_conditional_edges("initialise", use_tool) +# workflow.add_edge("selectai", "stream_completion") +workflow.add_edge("vs_retrieve", "vs_grade") +workflow.add_edge("vs_grade", "stream_completion") + +# End the workflow workflow.add_edge("stream_completion", END) # Compile the graph diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py index e9a23194..711145df 100644 --- a/src/server/api/utils/chat.py +++ b/src/server/api/utils/chat.py @@ -48,7 +48,11 @@ async def completion_generator( "input": {"messages": [HumanMessage(content=request.messages[0].content)]}, "config": RunnableConfig( configurable={"thread_id": client, "ll_config": ll_config}, - metadata={"use_history": client_settings.ll_model.chat_history}, + metadata={ + "use_history": client_settings.ll_model.chat_history, + "vector_search": client_settings.vector_search, + "selectai": client_settings.selectai, + }, ), } @@ -56,6 +60,16 @@ async def completion_generator( user_sys_prompt = getattr(client_settings.prompts, "sys", "Basic Example") kwargs["config"]["metadata"]["sys_prompt"] = core_prompts.get_prompts(category="sys", name=user_sys_prompt) + # Setup Vector Search + if client_settings.vector_search.enabled: + kwargs["config"]["configurable"]["db_conn"] = util_databases.get_client_db(client, False).connection + kwargs["config"]["configurable"]["embed_client"] = util_models.get_embed_client( + client_settings.vector_search.model_dump(), oci_config + ) + # Get Context Prompt + user_ctx_prompt = getattr(client_settings.prompts, "ctx", "Basic Example") + kwargs["config"]["metadata"]["ctx_prompt"] = core_prompts.get_prompts(category="ctx", name=user_ctx_prompt) + # db_conn = None # # Setup selectai # if client_settings.selectai.enabled: diff --git a/src/server/api/utils/databases.py b/src/server/api/utils/databases.py index 8b43598e..d65ff7d8 100644 --- a/src/server/api/utils/databases.py +++ b/src/server/api/utils/databases.py @@ -35,7 +35,7 @@ def drop_vs(conn: oracledb.Connection, vs: schema.VectorStoreTableType) -> None: LangchainVS.drop_table_purge(conn, vs) -def get_client_db(client: schema.ClientIdType) -> schema.Database: +def get_client_db(client: schema.ClientIdType, validate: bool = True) -> schema.Database: """Return a Database Object based on client settings""" client_settings = core_settings.get_client_settings(client) @@ -47,7 +47,7 @@ def get_client_db(client: schema.ClientIdType) -> schema.Database: db_name = getattr(client_settings.vector_search, "database", "DEFAULT") # Return the Database Object - db = core_databases.get_databases(db_name) + db = core_databases.get_databases(name=db_name, validate=validate) # Ping the Database test(db) diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py index 1039b5ea..085391b0 100644 --- a/src/server/api/utils/models.py +++ b/src/server/api/utils/models.py @@ -8,6 +8,11 @@ from litellm import get_supported_openai_params +from langchain.embeddings import init_embeddings +from langchain_litellm import ChatLiteLLM +from langchain_community.embeddings.oci_generative_ai import OCIGenAIEmbeddings +from 
langchain_core.embeddings.embeddings import Embeddings + import server.api.utils.oci as util_oci import server.api.core.models as core_models @@ -76,11 +81,8 @@ def create_genai_models(config: schema.OracleCloudSettings) -> list[schema.Model return genai_models -def get_litellm_config( - model_config: dict, oci_config: schema.OracleCloudSettings = None, giskard: bool = False -) -> dict: - """Establish client""" - logger.debug("Model Client: %s; OCI Config: %s; Giskard: %s", model_config, oci_config, giskard) +def _get_full_config(model_config: dict, oci_config: schema.OracleCloudSettings = None) -> tuple[dict, str]: + logger.debug("Model Client: %s; OCI Config: %s", model_config, oci_config) try: defined_model = core_models.get_model( @@ -92,9 +94,15 @@ def get_litellm_config( # Merge configurations, skipping None values full_model_config = {**defined_model, **{k: v for k, v in model_config.items() if v is not None}} + provider = full_model_config.pop("provider") + provider = "openai" if provider == "openai_compatible" else provider + + return full_model_config, provider + - # Determine provider and model name - provider = "openai" if full_model_config["provider"] == "openai_compatible" else full_model_config["provider"] +def get_litellm_config(model_config: dict, oci_config: schema.OracleCloudSettings = None) -> dict: + """Establish client""" + full_model_config, provider = _get_full_config(model_config, oci_config) model_name = f"{provider}/{full_model_config['id']}" # Get supported parameters and initialize config @@ -128,3 +136,26 @@ def get_litellm_config( ) return litellm_config + + +def get_embed_client(model_config: dict, oci_config: schema.OracleCloudSettings) -> Embeddings: + """Retrieve embedding model client""" + full_model_config, provider = _get_full_config(model_config, oci_config) + client = None + + if provider != "oci": + kwargs = { + "provider": "openai" if provider == "openai_compatible" else provider, + "model": full_model_config["id"], + "base_url": full_model_config.get("api_base"), + } + if full_model_config.get("api_key"): # only add if set + kwargs["api_key"] = full_model_config["api_key"] + client = init_embeddings(**kwargs) + else: + client = OCIGenAIEmbeddings( + model_id=full_model_config["id"], + client=util_oci.init_genai_client(oci_config), + compartment_id=oci_config.genai_compartment_id, + ) + return client From 94764de4dc6417fb21528106467c21bbd4f224ef Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Sun, 31 Aug 2025 08:25:47 +0100 Subject: [PATCH 11/31] Workaround for XAI --- src/client/utils/st_common.py | 48 ++++++++++++++++++---------------- src/common/schema.py | 4 ++- src/pyproject.toml | 20 +++++++------- src/server/api/utils/models.py | 7 +++-- src/server/api/utils/oci.py | 24 ++++++++--------- 5 files changed, 55 insertions(+), 48 deletions(-) diff --git a/src/client/utils/st_common.py b/src/client/utils/st_common.py index 983b740d..cede5159 100644 --- a/src/client/utils/st_common.py +++ b/src/client/utils/st_common.py @@ -55,6 +55,8 @@ def enabled_models_lookup(model_type: str) -> dict[str, dict[str, Any]]: def bool_to_emoji(value): "Return an Emoji for Bools" return "✅" if value else "âšĒ" + + def local_file_payload(uploaded_files: Union[BytesIO, list[BytesIO]]) -> list: """Upload Single file from Streamlit to the Server""" # If it's a single file, convert it to a list for consistent processing @@ -207,8 +209,8 @@ def ll_sidebar() -> None: on_change=update_client_settings("ll_model"), ) - # Top P if not 
state.client_settings["selectai"]["enabled"]: + # Top P st.sidebar.slider( "Top P (Default: 1.0):", help=help_text.help_dict["top_p"], @@ -220,28 +222,29 @@ def ll_sidebar() -> None: ) # Frequency Penalty - frequency_penalty = ll_models_enabled[selected_model]["frequency_penalty"] - user_frequency_penalty = state.client_settings["ll_model"]["frequency_penalty"] - st.sidebar.slider( - f"Frequency penalty (Default: {frequency_penalty}):", - help=help_text.help_dict["frequency_penalty"], - value=user_frequency_penalty if user_frequency_penalty is not None else frequency_penalty, - min_value=-2.0, - max_value=2.0, - key="selected_ll_model_frequency_penalty", - on_change=update_client_settings("ll_model"), - ) + if "xai" not in state.client_settings["ll_model"]["model"]: + frequency_penalty = ll_models_enabled[selected_model]["frequency_penalty"] + user_frequency_penalty = state.client_settings["ll_model"]["frequency_penalty"] + st.sidebar.slider( + f"Frequency penalty (Default: {frequency_penalty}):", + help=help_text.help_dict["frequency_penalty"], + value=user_frequency_penalty if user_frequency_penalty is not None else frequency_penalty, + min_value=-2.0, + max_value=2.0, + key="selected_ll_model_frequency_penalty", + on_change=update_client_settings("ll_model"), + ) - # Presence Penalty - st.sidebar.slider( - "Presence penalty (Default: 0.0):", - help=help_text.help_dict["presence_penalty"], - value=state.client_settings["ll_model"]["presence_penalty"], - min_value=-2.0, - max_value=2.0, - key="selected_ll_model_presence_penalty", - on_change=update_client_settings("ll_model"), - ) + # Presence Penalty + st.sidebar.slider( + "Presence penalty (Default: 0.0):", + help=help_text.help_dict["presence_penalty"], + value=state.client_settings["ll_model"]["presence_penalty"], + min_value=-2.0, + max_value=2.0, + key="selected_ll_model_presence_penalty", + on_change=update_client_settings("ll_model"), + ) ##################################################### @@ -431,6 +434,7 @@ def vector_search_sidebar() -> None: database_lookup = state_configs_lookup("database_configs", "name") vs_df = pd.DataFrame(database_lookup[db_alias].get("vector_stores")) + def vs_reset() -> None: """Reset Vector Store Selections""" for key in state.client_settings["vector_search"]: diff --git a/src/common/schema.py b/src/common/schema.py index e5b062d7..e2296f03 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -113,7 +113,9 @@ class LanguageModelParameters(BaseModel): context_length: Optional[int] = Field(default=None, description="The context window for Language Model.") frequency_penalty: Optional[float] = Field(description=help_text.help_dict["frequency_penalty"], default=0.00) - max_completion_tokens: Optional[int] = Field(description=help_text.help_dict["max_completion_tokens"], default=256) + max_completion_tokens: Optional[int] = Field( + description=help_text.help_dict["max_completion_tokens"], default=4096 + ) presence_penalty: Optional[float] = Field(description=help_text.help_dict["presence_penalty"], default=0.00) temperature: Optional[float] = Field(description=help_text.help_dict["temperature"], default=1.00) top_p: Optional[float] = Field(description=help_text.help_dict["top_p"], default=1.00) diff --git a/src/pyproject.toml b/src/pyproject.toml index 3b40af3d..14b14d77 100644 --- a/src/pyproject.toml +++ b/src/pyproject.toml @@ -14,7 +14,7 @@ authors = [ # Common dependencies that are always needed dependencies = [ - "langchain-core==0.3.74", + "langchain-core==0.3.75", "httpx==0.28.1", 
"oracledb~=3.1", "plotly==6.2.0", @@ -23,28 +23,28 @@ dependencies = [ [project.optional-dependencies] # Server component dependencies server = [ - "bokeh==3.7.3", + "bokeh==3.8.0", "evaluate==0.4.5", "fastapi==0.116.1", - "faiss-cpu==1.11.0.post1", - "giskard==2.17.0", + "faiss-cpu==1.12.0", + "giskard==2.18.0", "langchain-anthropic==0.3.19", "langchain-azure-ai==0.1.5", "langchain-aws==0.2.31", "langchain-cohere==0.4.5", - "langchain-community==0.3.27", + "langchain-community==0.3.29", "langchain-deepseek==0.1.4", - "langchain-google-genai==2.1.9", + "langchain-google-genai==2.1.10", "langchain-google-vertexai==2.0.28", "langchain-groq==0.3.7", "langchain-huggingface==0.3.1", "langchain-mistralai==0.2.11", - "langchain-ollama==0.3.6", - "langchain-openai==0.3.29", + "langchain-ollama==0.3.7", + "langchain-openai==0.3.32", "langchain-perplexity==0.1.2", "langchain-xai==0.2.5", "langgraph==0.6.4", - "litellm==1.75.3", + "litellm==1.76.1", "llama-index==0.13.1", "lxml==6.0.0", "matplotlib==3.10.5", @@ -58,7 +58,7 @@ server = [ # GUI component dependencies client = [ - "streamlit==1.48.0", + "streamlit==1.49.1", ] # Test dependencies diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py index 085391b0..ad69c082 100644 --- a/src/server/api/utils/models.py +++ b/src/server/api/utils/models.py @@ -9,7 +9,6 @@ from litellm import get_supported_openai_params from langchain.embeddings import init_embeddings -from langchain_litellm import ChatLiteLLM from langchain_community.embeddings.oci_generative_ai import OCIGenAIEmbeddings from langchain_core.embeddings.embeddings import Embeddings @@ -120,8 +119,11 @@ def get_litellm_config(model_config: dict, oci_config: schema.OracleCloudSetting # Always force the path path = "/compatibility/v1" full_model_config["api_base"] = f"{scheme}://{netloc}{path}" + if "xai" in model_name: + litellm_config.pop("presence_penalty", None) + litellm_config.pop("frequency_penalty", None) - litellm_config.update({"model": model_name, "api_base": full_model_config.get("api_base")}) + litellm_config.update({"model": model_name, "api_base": full_model_config.get("api_base"), "drop_params": True}) if provider == "oci": litellm_config.update( @@ -134,6 +136,7 @@ def get_litellm_config(model_config: dict, oci_config: schema.OracleCloudSetting "oci_compartment_id": oci_config.genai_compartment_id, } ) + logger.debug("LiteLLM Config: %s", litellm_config) return litellm_config diff --git a/src/server/api/utils/oci.py b/src/server/api/utils/oci.py index 76f18b2d..6fe9ec47 100644 --- a/src/server/api/utils/oci.py +++ b/src/server/api/utils/oci.py @@ -91,7 +91,7 @@ def get_namespace(config: OracleCloudSettings = None) -> str: namespace = client.get_namespace().data logger.info("OCI: Namespace = %s", namespace) except oci.exceptions.InvalidConfig as ex: - raise OciException(status_code=400, detail=f"Invalid Config") from ex + raise OciException(status_code=400, detail="Invalid Config") from ex except oci.exceptions.ServiceError as ex: raise OciException(status_code=401, detail="AuthN Error") from ex except FileNotFoundError as ex: @@ -165,18 +165,16 @@ def get_genai_models(config: OracleCloudSettings, regional: bool = False) -> lis # Build our list of models for model in response.data.items: - # note that langchain_community.llms.oci_generative_ai only supports meta/cohere models - if model.display_name not in excluded_display_names and model.vendor in ["meta", "cohere"]: - genai_models.append( - { - "region": region["region_name"], - "compartment_id": 
config.genai_compartment_id,
+                    "model_name": model.display_name,
+                    "capabilities": model.capabilities,
+                    "vendor": model.vendor,
+                    "id": model.id,
+                }
+            )
         except oci.exceptions.ServiceError:
             logger.info("Region: %s has no GenAI services", region["region_name"])
         except (oci.exceptions.RequestException, urllib3.exceptions.MaxRetryError):

From 6ffe0a7fe8b36a975c14e7ac1f495ed68841e270 Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Sun, 31 Aug 2025 08:52:28 +0100
Subject: [PATCH 12/31] SelectAI

---
 src/server/agents/chatbot.py | 48 ++++++++++++++++++---
 src/server/api/utils/chat.py | 84 +++++-------------------------------
 2 files changed, 52 insertions(+), 80 deletions(-)

diff --git a/src/server/agents/chatbot.py b/src/server/agents/chatbot.py
index 4773678d..fa67816e 100644
--- a/src/server/agents/chatbot.py
+++ b/src/server/agents/chatbot.py
@@ -24,6 +24,8 @@
 from litellm import acompletion, completion
 from litellm.exceptions import APIConnectionError
 
+from server.api.core.databases import execute_sql
+
 from common import logging_config
 
 logger = logging_config.logging.getLogger("server.agents.chatbot")
@@ -72,12 +74,12 @@ def clean_messages(state: OptimizerState, config: RunnableConfig) -> list:
     return state_messages
 
 
-def use_tool(_, config: RunnableConfig) -> Literal["vs_retrieve", "stream_completion"]:
+def use_tool(_, config: RunnableConfig) -> Literal["vs_retrieve", "selectai_completion", "stream_completion"]:
     """Conditional edge to determine if using SelectAI, Vector Search or not"""
-    # selectai_enabled = config["metadata"]["selectai"].enabled
-    # if selectai_enabled:
-    #     logger.info("Invoking Chatbot with SelectAI: %s", selectai_enabled)
-    #     return "selectai"
+    selectai_enabled = config["metadata"]["selectai"].enabled
+    if selectai_enabled:
+        logger.info("Invoking Chatbot with SelectAI: %s", selectai_enabled)
+        return "selectai_completion"
 
     enabled = config["metadata"]["vector_search"].enabled
     if enabled:
@@ -264,6 +266,37 @@ async def vs_retrieve(state: OptimizerState, config: RunnableConfig) -> Optimize
     return {"context_input": retrieve_question, "documents": documents_dict}
 
 
+async def selectai_completion(state: OptimizerState, config: RunnableConfig) -> OptimizerState:
+    """Generate answer when SelectAI enabled; modify state with response"""
+    selectai_prompt = state["cleaned_messages"][-1].content
+
+    logger.info("Generating SelectAI Response on %s", selectai_prompt)
+    sql = """
+        SELECT DBMS_CLOUD_AI.GENERATE(
+            prompt => :query,
+            profile_name => :profile,
+            action => :action)
+        FROM dual
+    """
+    binds = {
+        "query": selectai_prompt,
+        "profile": config["metadata"]["selectai"].profile,
+        "action": config["metadata"]["selectai"].action,
+    }
+    # Execute the SQL using the connection
+    db_conn = config["configurable"]["db_conn"]
+    try:
+        response = execute_sql(db_conn, sql, binds)
+    except Exception as ex:
+        logger.error("SelectAI has hit an issue: %s", ex)
+        response = [{sql: f"I'm sorry, I ran into an error: {str(ex)}"}]
+    # Response will be [{sql:, completion}]; return the completion
+    logger.debug("SelectAI Responded: %s", response)
+    response = list(response[0].values())[0]
+
+    return {"messages": [AIMessage(content=response)]}
+
+
 async def stream_completion(state: OptimizerState, config: RunnableConfig) -> OptimizerState:
     """LiteLLM
streaming wrapper""" writer = get_stream_writer() @@ -316,16 +349,17 @@ async def stream_completion(state: OptimizerState, config: RunnableConfig) -> Op workflow.add_node("rephrase", rephrase) workflow.add_node("vs_retrieve", vs_retrieve) workflow.add_node("vs_grade", vs_grade) +workflow.add_node("selectai_completion", selectai_completion) workflow.add_node("stream_completion", stream_completion) # Start the chatbot with clean messages workflow.add_edge(START, "initialise") -# Branch to either "selectai", "vs_retrieve", or "generate_response" +# Branch to either "selectai_completion", "vs_retrieve", or "stream_completion" workflow.add_conditional_edges("initialise", use_tool) -# workflow.add_edge("selectai", "stream_completion") workflow.add_edge("vs_retrieve", "vs_grade") workflow.add_edge("vs_grade", "stream_completion") +workflow.add_edge("selectai_completion", END) # End the workflow workflow.add_edge("stream_completion", END) diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py index 711145df..4a10a9ef 100644 --- a/src/server/api/utils/chat.py +++ b/src/server/api/utils/chat.py @@ -60,9 +60,13 @@ async def completion_generator( user_sys_prompt = getattr(client_settings.prompts, "sys", "Basic Example") kwargs["config"]["metadata"]["sys_prompt"] = core_prompts.get_prompts(category="sys", name=user_sys_prompt) + # Add DB Conn to KWargs when needed + if client_settings.vector_search.enabled or client_settings.selectai.enabled: + db_conn = util_databases.get_client_db(client, False).connection + kwargs["config"]["configurable"]["db_conn"] = db_conn + # Setup Vector Search if client_settings.vector_search.enabled: - kwargs["config"]["configurable"]["db_conn"] = util_databases.get_client_db(client, False).connection kwargs["config"]["configurable"]["embed_client"] = util_models.get_embed_client( client_settings.vector_search.model_dump(), oci_config ) @@ -70,44 +74,12 @@ async def completion_generator( user_ctx_prompt = getattr(client_settings.prompts, "ctx", "Basic Example") kwargs["config"]["metadata"]["ctx_prompt"] = core_prompts.get_prompts(category="ctx", name=user_ctx_prompt) - # db_conn = None - # # Setup selectai - # if client_settings.selectai.enabled: - # db_conn = util_databases.get_client_db(client).connection - # util_selectai.set_profile(db_conn, client_settings.selectai.profile, "temperature", model["temperature"]) - # util_selectai.set_profile( - # db_conn, client_settings.selectai.profile, "max_tokens", model["max_completion_tokens"] - # ) - - # # Setup vector_search - # embed_client, ctx_prompt = None, None - # if client_settings.vector_search.enabled: - # db_conn = util_databases.get_client_db(client).connection - # embed_client = util_models.get_client(client_settings.vector_search.model_dump(), oci_config) - - # user_ctx_prompt = getattr(client_settings.prompts, "ctx", "Basic Example") - # ctx_prompt = core_prompts.get_prompts(category="ctx", name=user_ctx_prompt) - - # kwargs = { - # "stream_mode": "custom", - # "input": {"messages": [HumanMessage(content=request.messages[0].content)]}, - # "config": RunnableConfig( - # configurable={ - # "thread_id": client, - # "ll_config": ll_client, - # "embed_client": embed_client, - # "db_conn": db_conn, - # }, - # metadata={ - # "model_id": model["model"], - # "use_history": client_settings.ll_model.chat_history, - # "vector_search": client_settings.vector_search, - # "selectai": client_settings.selectai, - # "sys_prompt": sys_prompt, - # "ctx_prompt": ctx_prompt, - # }, - # ), - # } + if 
client_settings.selectai.enabled: + util_selectai.set_profile(db_conn, client_settings.selectai.profile, "temperature", model["temperature"]) + util_selectai.set_profile( + db_conn, client_settings.selectai.profile, "max_tokens", model["max_completion_tokens"] + ) + logger.debug("Completion Kwargs: %s", kwargs) final_response = None async for output in chatbot_graph.astream(**kwargs): @@ -119,37 +91,3 @@ async def completion_generator( yield "[stream_finished]" # This will break the Chatbot loop if call == "completions" and final_response is not None: yield final_response # This will be captured for ChatResponse - - # print(f"********** output: {output["stream"][]}") - # for chunk in output["llm"]["messages"]: - # print(f"********** chunk: {chunk}") - # print(f"********** yield: {chunk.content}") - # yield chunk.content.encode("utf-8") - - # result = await graph.ainvoke(**kwargs) - # print("\n\nFinal result:", result) - - # try: - # async for chunk in agent.astream_events(**kwargs, version="v2", stream=True): - # # The below will produce A LOT of output; uncomment when desperate - # # logger.debug("Streamed Chunk: %s", chunk) - # if chunk["event"] == "on_chat_model_stream": - # if "tools_condition" in str(chunk["metadata"]["langgraph_triggers"]): - # continue # Skip Tool Call messages - # if "vs_retrieve" in str(chunk["metadata"]["langgraph_node"]): - # continue # Skip Fake-Tool Call messages - # content = chunk["data"]["chunk"].content - # if content != "" and call == "streams": - # yield content.encode("utf-8") - # last_response = chunk["data"] - # if call == "streams": - # yield "[stream_finished]" # This will break the Chatbot loop - # elif call == "completions": - # final_response = last_response["output"]["final_response"] - # yield final_response # This will be captured for ChatResponse - # except Exception as ex: - # logger.error("An invoke exception occurred: %s", ex) - # # yield f"I'm sorry; {ex}" - # # TODO(gotsysdba) - If a message is returned; - # # format and return (this should be done in the agent) - # raise From f52e669d5a5f295731b179d8b35fd2436dfd7f93 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Mon, 1 Sep 2025 00:39:21 +0100 Subject: [PATCH 13/31] Linting --- .pylintrc | 647 ++++++++++++++++++ src/client/content/chatbot.py | 9 +- src/client/content/config/tabs/databases.py | 2 +- src/client/content/config/tabs/models.py | 47 +- src/client/content/testbed.py | 23 +- src/client/content/tools/tabs/split_embed.py | 2 +- src/client/utils/st_common.py | 4 +- src/common/schema.py | 4 +- src/launch_server.py | 4 - src/server/agents/chatbot.py | 7 +- src/server/api/core/models.py | 16 +- src/server/api/utils/chat.py | 20 +- src/server/api/utils/embed.py | 6 +- src/server/api/utils/models.py | 69 +- src/server/api/utils/oci.py | 2 + src/server/api/utils/testbed.py | 44 +- src/server/api/v1/chat.py | 3 +- src/server/api/v1/embed.py | 24 +- src/server/api/v1/models.py | 31 +- src/server/api/v1/oci.py | 24 +- src/server/api/v1/selectai.py | 14 +- src/server/api/v1/testbed.py | 71 +- src/server/bootstrap/models.py | 40 +- src/server/bootstrap/oci.py | 8 +- .../server/test_endpoints_embed.py | 18 +- .../server/test_endpoints_models.py | 6 +- tests/unit/server/api/utils/models.py | 0 27 files changed, 893 insertions(+), 252 deletions(-) create mode 100644 .pylintrc delete mode 100644 tests/unit/server/api/utils/models.py diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 00000000..a0a638d1 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,647 @@ +[MAIN] + +# Analyse import fallback 
blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint +# in a server-like mode. +clear-cache-post-run=no + +# Load and enable all available extensions. Use --list-extensions to see a list +# all available extensions. +#enable-all-extensions= + +# In error mode, messages with a category besides ERROR or FATAL are +# suppressed, and no reports are done by default. Error mode is compatible with +# disabling specific errors. +#errors-only= + +# Always return a 0 (non-error) status code, even if lint errors are found. +# This is primarily useful in continuous integration scripts. +#exit-zero= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-allow-list= + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +extension-pkg-whitelist= + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +fail-on= + +# Specify a score threshold under which the program will exit with error. +fail-under=10 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +#from-stdin= + +# Files or directories to be skipped. They should be base names, not paths. +ignore=CVS,.venv + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, +# it can't be used as an escape character. +ignore-paths= + +# Files or directories matching the regular expression patterns are skipped. +# The regex matches against base names, not paths. The default value ignores +# Emacs file locks +ignore-patterns=^\.# + +# List of module names for which member attributes should not be checked and +# will not be imported (useful for modules/projects where namespaces are +# manipulated during runtime and thus existing member attributes cannot be +# deduced by static analysis). It supports qualified module names, as well as +# Unix pattern matching. +ignored-modules= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Resolve imports to .pyi stubs if available. 
May reduce no-member messages and +# increase not-an-iterable messages. +prefer-stubs=no + +# Minimum Python version to use for version dependent checks. Will default to +# the version used to run pylint. +py-version=3.11 + +# Discover python modules and packages in the file system subtree. +recursive=no + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +source-roots= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + +# In verbose mode, extra non-checker-related info will be displayed. +#verbose= + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. If left empty, argument names will be checked with the set +# naming style. +#argument-rgx= + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +#class-attribute-rgx= + +# Naming style matching correct class constant names. +class-const-naming-style=UPPER_CASE + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +#class-const-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. If left empty, class names will be checked with the set naming style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. If left empty, function names will be checked with the set +# naming style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. 
If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Regular expression matching correct type alias names. If left empty, type +# alias names will be checked with the set naming style. +#typealias-rgx= + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +#typevar-rgx= + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. If left empty, variable names will be checked with the set +# naming style. +#variable-rgx= + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + asyncSetUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make,os._exit + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +exclude-too-few-public-methods= + +# List of qualified class names to ignore when counting class parents (see +# R0901) +ignored-parents= + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of positional arguments for function / method. 
+max-positional-arguments=5 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when caught. +overgeneral-exceptions=builtins.BaseException,builtins.Exception + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow explicit reexports by alias from a package __init__. +allow-reexport-from-package=no + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules= + +# Output a graph (.gv or any supported image format) of external dependencies +# to the given file (report RP0402 must not be disabled). +ext-import-graph= + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be +# disabled). +import-graph= + +# Output a graph (.gv or any supported image format) of internal dependencies +# to the given file (report RP0402 must not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, +# UNDEFINED. +confidence=HIGH, + CONTROL_FLOW, + INFERENCE, + INFERENCE_FAILURE, + UNDEFINED + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then re-enable specific checks. 
For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + use-implicit-booleaness-not-comparison-to-string, + use-implicit-booleaness-not-comparison-to-zero + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable= + + +[METHOD_ARGS] + +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.api.options,requests.api.patch,requests.api.post,requests.api.put,requests.api.request + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +notes-rgx= + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit,argparse.parse_error + +# Let 'consider-using-join' be raised when the separator to join on would be +# non-empty (resulting in expected fixes of the type: ``"- " + " - +# ".join(items)``) +suggest-join-with-non-empty-separator=yes + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each +# category, as well as 'statement' which is the total number of statements +# analyzed. This score is used by the global evaluation report (RP0004). +evaluation=max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +msg-template= + +# Set the output format. Available formats are: 'text', 'parseable', +# 'colorized', 'json2' (improved json format), 'json' (old json format), msvs +# (visual studio) and 'github' (GitHub actions). You can also give a reporter +# class, e.g. mypackage.mymodule.MyReporterClass. +#output-format= + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[SIMILARITIES] + +# Comments are removed from the similarity computation +ignore-comments=yes + +# Docstrings are removed from the similarity computation +ignore-docstrings=yes + +# Imports are removed from the similarity computation +ignore-imports=yes + +# Signatures are removed from the similarity computation +ignore-signatures=yes + +# Minimum lines number of a similarity. 
+min-similarity-lines=4 + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. No available dictionaries : You need to install +# both the python package and the system dependency for enchant to work. +spelling-dict= + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives=fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy: + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins=no-member, + not-async-context-manager, + not-context-manager, + attribute-defined-outside-init + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,argparse.Namespace + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The maximum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx=.*[Mm]ixin + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. 
Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of names allowed to shadow builtins +allowed-redefined-builtins= + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io diff --git a/src/client/content/chatbot.py b/src/client/content/chatbot.py index ec276133..d495af2d 100644 --- a/src/client/content/chatbot.py +++ b/src/client/content/chatbot.py @@ -11,6 +11,7 @@ import inspect import json import base64 +from httpx import RemoteProtocolError import streamlit as st from streamlit import session_state as state @@ -143,13 +144,7 @@ async def main() -> None: # Stream until we hit the end then refresh to replace with history st.rerun() except Exception: - logger.error("Exception:", exc_info=1) - st.chat_message("ai").write( - """ - I'm sorry, something's gone wrong. Please try again. - If the problem persists, please raise an issue. - """ - ) + message_placeholder.markdown("An unexpected error occurred, please retry your request.") if st.button("Retry", key="reload_chatbot"): st_common.clear_state_key("user_client") st.rerun() diff --git a/src/client/content/config/tabs/databases.py b/src/client/content/config/tabs/databases.py index 2ce5ce70..cdff157a 100644 --- a/src/client/content/config/tabs/databases.py +++ b/src/client/content/config/tabs/databases.py @@ -221,7 +221,7 @@ def display_databases() -> None: column_config={ "enabled": st.column_config.CheckboxColumn(label="Enabled", help="Toggle to enable or disable") }, - use_container_width=True, + width="stretch", hide_index=True, ) if st.button("Apply SelectAI Changes", type="secondary"): diff --git a/src/client/content/config/tabs/models.py b/src/client/content/config/tabs/models.py index 97fa76f4..643efa4f 100644 --- a/src/client/content/config/tabs/models.py +++ b/src/client/content/config/tabs/models.py @@ -28,7 +28,7 @@ ################################### # Functions ################################### -def clear_client_models(model_id: str) -> None: +def clear_client_models(model_provider: str, model_id: str) -> None: """Clear selected models from client settings if modified""" model_keys = [ ("ll_model", "model"), @@ -37,7 +37,7 @@ def clear_client_models(model_id: str) -> None: ("testbed", "qa_embed_model"), ] for section, key in model_keys: - if state.client_settings[section][key] == model_id: + if state.client_settings[section][key] == f"{model_provider}/{model_id}": state.client_settings[section][key] = None @@ -61,35 +61,37 @@ def get_model_providers() -> list: def create_model(model: dict) -> None: """Add either Language Model or Embed Model""" - _ = api_call.post(endpoint="v1/models", params={"id": model["id"]}, payload={"json": model}) - st.success(f"Model created: {model['id']}") + _ = api_call.post(endpoint="v1/models", 
payload={"json": model})
+    st.success(f"Model created: {model['provider']}/{model['id']}")


 def patch_model(model: dict) -> None:
     """Update Model Configuration for either Language Models or Embed Models"""
-    _ = api_call.patch(endpoint=f"v1/models/{model['id']}", payload={"json": model})
-    st.success(f"Model updated: {model['id']}")
+    _ = api_call.patch(endpoint=f"v1/models/{model['provider']}/{model['id']}", payload={"json": model})
+    st.success(f"Model updated: {model['provider']}/{model['id']}")
     # If updated model is the set model and not enabled: unset the user settings
     if not model["enabled"]:
-        clear_client_models(model["id"])
+        clear_client_models(model["provider"], model["id"])


-def delete_model(model_id: str) -> None:
-    """Update Model Configuration for either Language Models or Embed Models"""
-    api_call.delete(endpoint=f"v1/models/{model_id}")
-    st.success(f"Model deleted: {model_id}")
+def delete_model(model_provider: str, model_id: str) -> None:
+    """Delete a Model configuration"""
+    api_call.delete(endpoint=f"v1/models/{model_provider}/{model_id}")
+    st.success(f"Model deleted: {model_provider}/{model_id}")
     sleep(1)
     # If deleted model is the set model; unset the user settings
-    clear_client_models(model_id)
+    clear_client_models(model_provider, model_id)


 @st.dialog("Model Configuration", width="large")
-def edit_model(model_type: str, action: Literal["add", "edit"], model_id: str = None) -> None:
+def edit_model(
+    model_type: str, action: Literal["add", "edit"], model_id: str = None, model_provider: str = None
+) -> None:
     """Model Edit Dialog Box"""
     # Initialize our model request
     if action == "edit":
         model_id = urllib.parse.quote(model_id, safe="")
-        model = api_call.get(endpoint=f"v1/models/{model_id}")
+        model = api_call.get(endpoint=f"v1/models/{model_provider}/{model_id}")
     else:
         model = {"id": "unset", "type": model_type, "provider": "unset", "status": "CUSTOM"}
     with st.form("edit_model"):
@@ -120,8 +122,8 @@ def edit_model(model_type: str, action: Literal["add", "edit"], model_id: str =
             "Provider URL:",
             help=help_text.help_dict["model_url"],
             key="add_model_url",
-            value=model.get("url", ""),
-            disabled=disable_for_oci
+            value=model.get("api_base", ""),
+            disabled=disable_for_oci,
         )
         model["api_key"] = st.text_input(
             "API Key:",
@@ -174,20 +176,14 @@ def edit_model(model_type: str, action: Literal["add", "edit"], model_id: str =
         button_col_format = st.columns([1.2, 1.4, 6, 1.4])
         action_button, delete_button, _, cancel_button = button_col_format
         try:
-            if action == "add" and action_button.form_submit_button(
-                label="Add", type="primary", use_container_width=True
-            ):
+            if action == "add" and action_button.form_submit_button(label="Add", type="primary", width="stretch"):
                 create_model(model=model)
                 submit = True
-            if action == "edit" and action_button.form_submit_button(
-                label="Save", type="primary", use_container_width=True
-            ):
+            if action == "edit" and action_button.form_submit_button(label="Save", type="primary", width="stretch"):
                 patch_model(model=model)
                 submit = True
-            if action != "add" and delete_button.form_submit_button(
-                label="Delete", type="secondary", use_container_width=True
-            ):
-                delete_model(model_id=model["id"])
+            if action != "add" and delete_button.form_submit_button(label="Delete", type="secondary", width="stretch"):
+                delete_model(model_provider=model["provider"], model_id=model["id"])
                 submit = True
             if submit:
                 sleep(1)
@@ -212,6 +208,7 @@ def render_model_rows(model_type: str) -> None:
     col5.markdown("​")
     for model in [m for m in state.model_configs if m.get("type") == model_type]:
         model_id = model["id"]
+        model_provider = model["provider"]
col1.text_input( "Enabled", value=st_common.bool_to_emoji(model["enabled"]), @@ -243,7 +240,7 @@ def render_model_rows(model_type: str) -> None: "Edit", on_click=edit_model, key=f"{model_type}_{model_id}_edit", - kwargs=dict(model_type=model_type, action="edit", model_id=model_id), + kwargs=dict(model_type=model_type, action="edit", model_id=model_id, model_provider=model_provider), ) if st.button(label="Add", type="primary", key=f"add_{model_type}_model"): diff --git a/src/client/content/testbed.py b/src/client/content/testbed.py index 15f73aaf..728124aa 100644 --- a/src/client/content/testbed.py +++ b/src/client/content/testbed.py @@ -195,14 +195,14 @@ def qa_update_gui(qa_testset: list) -> None: prev_col.button( "← Previous", disabled=prev_disabled, - use_container_width=True, + width="stretch", on_click=update_record, kwargs={"direction": -1}, ) next_col.button( "Next →", disabled=next_disabled, - use_container_width=True, + width="stretch", on_click=update_record, kwargs={"direction": 1}, ) @@ -210,7 +210,7 @@ def qa_update_gui(qa_testset: list) -> None: "⚠ Delete Q&A", type="tertiary", disabled=delete_disabled, - use_container_width=True, + width="stretch", on_click=delete_record, ) st.text_area( @@ -261,7 +261,12 @@ def main() -> None: # If there is no eligible (OpenAI Compat.) Embedding Model; disable Generate Test Set gen_testset_disabled = False embed_models_enabled = st_common.enabled_models_lookup("embed") - available_embed_models = [key for key, value in embed_models_enabled.items()] + # Remove oci/cohere* models as not supported by LiteLLM + available_embed_models = [ + key + for key, value in embed_models_enabled.items() + if not (value.get("provider") == "oci" and "cohere" in value.get("id", "")) + ] if not available_embed_models: st.warning( "No OpenAI compatible embedding models are configured and/or enabled. Disabling Test Set Generation.", @@ -404,7 +409,7 @@ def main() -> None: state.running = True # Load TestSets (and Evaluations if from DB) - if col_left.button(button_text, key="load_tests", use_container_width=True, disabled=state.running): + if col_left.button(button_text, key="load_tests", width="stretch", disabled=state.running): with st.spinner("Processing Q&A... 
please be patient.", show_time=True): if testset_source != "Database": api_params["name"] = (state.testbed["testset_name"],) @@ -454,7 +459,7 @@ def main() -> None: "Reset", key="reset_test_framework", type="primary", - use_container_width=True, + width="stretch", on_click=reset_testset, kwargs={"cache": True}, ) @@ -462,7 +467,7 @@ def main() -> None: "⚠ Delete Test Set", key="delete_test_set", type="tertiary", - use_container_width=True, + width="stretch", disabled=not state.testbed["testset_id"], on_click=qa_delete, ) @@ -515,7 +520,7 @@ def main() -> None: view.button( "View", type="primary", - use_container_width=True, + width="stretch", on_click=evaluation_report, kwargs={"eid": evaluation_eid}, disabled=evaluation_eid is None, @@ -528,7 +533,7 @@ def main() -> None: st_common.selectai_sidebar() st_common.vector_search_sidebar() st.write("Choose a model to judge the correctness of the chatbot answer, then start evaluation.") - col_left, col_center, _ = st.columns([3, 3, 4]) + col_left, col_center, _ = st.columns([4, 3, 3]) if state.client_settings["testbed"].get("judge_model") is None: state.client_settings["testbed"]["judge_model"] = available_ll_models[0] selected_judge = state.client_settings["testbed"]["judge_model"] diff --git a/src/client/content/tools/tabs/split_embed.py b/src/client/content/tools/tabs/split_embed.py index 8704e0d9..4d647144 100644 --- a/src/client/content/tools/tabs/split_embed.py +++ b/src/client/content/tools/tabs/split_embed.py @@ -75,7 +75,7 @@ def files_data_editor(files, key): return st.data_editor( files, key=key, - use_container_width=True, + width="stretch", column_config={ "to process": st.column_config.CheckboxColumn( "in", diff --git a/src/client/utils/st_common.py b/src/client/utils/st_common.py index cede5159..3a9e8db1 100644 --- a/src/client/utils/st_common.py +++ b/src/client/utils/st_common.py @@ -42,7 +42,7 @@ def enabled_models_lookup(model_type: str) -> dict[str, dict[str, Any]]: """Create a lookup of enabled `type` models""" all_models = state_configs_lookup("model_configs", "id") enabled_models = { - id: config + f"{config.get('provider')}/{id}": config for id, config in all_models.items() if config.get("type") == model_type and config.get("enabled") is True } @@ -136,7 +136,7 @@ def history_sidebar() -> None: key="selected_ll_model_chat_history", on_change=update_client_settings("ll_model"), ) - if button_col.button("Clear", disabled=not chat_history_enable, use_container_width=True): + if button_col.button("Clear", disabled=not chat_history_enable, width="stretch"): # Clean out history try: api_call.patch(endpoint="v1/chat/history") diff --git a/src/common/schema.py b/src/common/schema.py index e2296f03..04571669 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -37,7 +37,6 @@ "mistralai", "ollama", "openai", - "openai_compatible", "perplexity", "xai", ] @@ -377,7 +376,7 @@ class ChatRequest(LanguageModelParameters): "json_schema_extra": { "examples": [ { - "model": "gpt-4o-mini", + "model": "openai/gpt-4o-mini", "messages": [{"role": "user", "content": "Hello, how are you?"}], "response_format": {"type": "text"}, "temperature": 1, @@ -433,6 +432,7 @@ class EvaluationReport(Evaluation): DatabaseNameType = Database.__annotations__["name"] VectorStoreTableType = DatabaseVectorStorage.__annotations__["vector_store"] ModelIdType = Model.__annotations__["id"] +ModelProviderType = Model.__annotations__["provider"] ModelTypeType = Model.__annotations__["type"] ModelEnabledType = ModelAccess.__annotations__["enabled"] 
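The ModelProviderType alias added above formalizes the composite model identity used throughout this series: clients key their lookups by the string "{provider}/{id}" (see enabled_models_lookup() in st_common.py, and the "openai/gpt-4o-mini" example in ChatRequest), and the server splits the key back apart. A minimal sketch of the round-trip, assuming that convention; the helper names here are illustrative and not part of the patch:

    # Sketch only: mirrors the f"{provider}/{id}" keys built client-side and
    # the model_config["model"].split("/", 1) parsing done server-side.
    def make_model_key(provider: str, model_id: str) -> str:
        return f"{provider}/{model_id}"

    def split_model_key(model_key: str) -> tuple[str, str]:
        # maxsplit=1 keeps any "/" inside the model id itself, matching the
        # {model_id:path} converter used by the v1 routes.
        provider, model_id = model_key.split("/", 1)
        return provider, model_id

    assert split_model_key(make_model_key("openai", "gpt-4o-mini")) == ("openai", "gpt-4o-mini")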
OCIProfileType = OracleCloudSettings.__annotations__["auth_profile"] diff --git a/src/launch_server.py b/src/launch_server.py index b9c02194..bc733be4 100644 --- a/src/launch_server.py +++ b/src/launch_server.py @@ -10,10 +10,6 @@ # Set OS Environment (Don't move their position to reflect on imports) os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" -os.environ["LITELLM_DISABLE_SPEND_LOGS"] = "True" -os.environ["LITELLM_DISABLE_SPEND_UPDATES"] = "True" -os.environ["LITELLM_DISABLE_END_USER_COST_TRACKING"] = "True" -os.environ["LITELLM_DROP_PARAMS"] = "True" os.environ["GSK_DISABLE_SENTRY"] = "true" os.environ["GSK_DISABLE_ANALYTICS"] = "true" os.environ["USER_AGENT"] = "ai-optimizer" diff --git a/src/server/agents/chatbot.py b/src/server/agents/chatbot.py index fa67816e..ee7128ef 100644 --- a/src/server/agents/chatbot.py +++ b/src/server/agents/chatbot.py @@ -154,6 +154,7 @@ async def vs_grade(state: OptimizerState, config: RunnableConfig) -> OptimizerSt logger.info("Grading Vector Search Response using %i retrieved documents", len(state["documents"])) # Initialise documents as relevant relevant = "yes" + documents_dict = document_formatter(state["documents"]) if config["metadata"]["vector_search"].grading and state.get("documents"): grade_template = """ You are a Grader assessing the relevance of retrieved text to the user's input. @@ -172,7 +173,6 @@ async def vs_grade(state: OptimizerState, config: RunnableConfig) -> OptimizerSt template=grade_template, input_variables=["question", "documents"], ) - documents_dict = document_formatter(state["documents"]) question = state["context_input"] formatted_prompt = grade_template.format(question=question, documents=documents_dict) logger.debug("Grading Prompt: %s", formatted_prompt) @@ -337,9 +337,12 @@ async def stream_completion(state: OptimizerState, config: RunnableConfig) -> Op final_response = last_chunk.model_dump() writer({"completion": final_response}) + except APIConnectionError as ex: + logger.error(ex) + full_text = "I'm not able to contact the model API; please validate its configuration/availability." 
except Exception as ex: logger.error(ex) - full_text = f"I'm sorry, a completion problem occurred: {str(ex).split('Traceback', 1)[0]}" + full_text = f"I'm sorry, an unknown completion problem occurred: {str(ex).split('Traceback', 1)[0]}" return {"messages": [AIMessage(content=full_text)]} diff --git a/src/server/api/core/models.py b/src/server/api/core/models.py index 77c0820c..f24efcbb 100644 --- a/src/server/api/core/models.py +++ b/src/server/api/core/models.py @@ -7,7 +7,7 @@ from server.api.core import bootstrap -from common.schema import Model, ModelIdType, ModelTypeType +from common.schema import Model, ModelIdType, ModelProviderType, ModelTypeType from common.functions import is_url_accessible import common.logging_config as logging_config @@ -37,6 +37,7 @@ class UnknownModelError(ValueError): # Functions ##################################################### def get_model( + model_provider: Optional[ModelProviderType] = None, model_id: Optional[ModelIdType] = None, model_type: Optional[ModelTypeType] = None, include_disabled: bool = True, @@ -51,6 +52,7 @@ def get_model( for model in model_objects if (model_id is None or model.id == model_id) and (model_type is None or model.type == model_type) + and (model_provider is None or model.provider == model_provider) and (include_disabled or model.enabled) ] logger.debug("%i models after filtering", len(model_filtered)) @@ -70,18 +72,20 @@ def create_model(model: Model, check_url: bool = True) -> Model: """Create a new Model definition""" model_objects = bootstrap.MODEL_OBJECTS - if any(d.id == model.id for d in model_objects): + try: + _ = get_model(model_id=model.id, model_provider=model.provider, model_type=model.type) raise ExistsModelError(f"Model: {model.id} already exists.") + except UnknownModelError: + pass if check_url and model.api_base and not is_url_accessible(model.api_base)[0]: model.enabled = False model_objects.append(model) + return get_model(model_id=model.id, model_provider=model.provider, model_type=model.type) - return get_model(model_id=model.id, model_type=model.type) - -def delete_model(model_id: ModelIdType) -> None: +def delete_model(model_provider: ModelProviderType, model_id: ModelIdType) -> None: """Remove model from model objects""" model_objects = bootstrap.MODEL_OBJECTS - bootstrap.MODEL_OBJECTS = [model for model in model_objects if model.id != model_id] + bootstrap.MODEL_OBJECTS = [m for m in model_objects if (m.id, m.provider) != (model_id, model_provider)] diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py index 4a10a9ef..9aedea1e 100644 --- a/src/server/api/utils/chat.py +++ b/src/server/api/utils/chat.py @@ -2,9 +2,8 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
""" -# spell-checker:ignore astream selectai +# spell-checker:ignore astream selectai litellm -import time from typing import Literal, AsyncGenerator from langchain_core.messages import HumanMessage @@ -13,10 +12,10 @@ import server.api.core.settings as core_settings import server.api.core.oci as core_oci import server.api.core.prompts as core_prompts -import server.api.utils.models as util_models -import server.api.utils.databases as util_databases +import server.api.utils.models as utils_models +import server.api.utils.databases as utils_databases from server.agents.chatbot import chatbot_graph -import server.api.utils.selectai as util_selectai +import server.api.utils.selectai as utils_selectai import common.schema as schema import common.logging_config as logging_config @@ -28,6 +27,7 @@ async def completion_generator( client: schema.ClientIdType, request: schema.ChatRequest, call: Literal["completions", "streams"] ) -> AsyncGenerator[str, None]: """Generate a completion from agent, stream the results""" + client_settings = core_settings.get_client_settings(client) model = request.model_dump() logger.debug("Settings: %s", client_settings) @@ -40,7 +40,7 @@ async def completion_generator( oci_config = core_oci.get_oci(client=client) # Setup Client Model - ll_config = util_models.get_litellm_config(model, oci_config) + ll_config = utils_models.get_litellm_config(model, oci_config) # Start to establish our LangGraph Args kwargs = { @@ -62,12 +62,12 @@ async def completion_generator( # Add DB Conn to KWargs when needed if client_settings.vector_search.enabled or client_settings.selectai.enabled: - db_conn = util_databases.get_client_db(client, False).connection + db_conn = utils_databases.get_client_db(client, False).connection kwargs["config"]["configurable"]["db_conn"] = db_conn # Setup Vector Search if client_settings.vector_search.enabled: - kwargs["config"]["configurable"]["embed_client"] = util_models.get_embed_client( + kwargs["config"]["configurable"]["embed_client"] = utils_models.get_client_embed( client_settings.vector_search.model_dump(), oci_config ) # Get Context Prompt @@ -75,8 +75,8 @@ async def completion_generator( kwargs["config"]["metadata"]["ctx_prompt"] = core_prompts.get_prompts(category="ctx", name=user_ctx_prompt) if client_settings.selectai.enabled: - util_selectai.set_profile(db_conn, client_settings.selectai.profile, "temperature", model["temperature"]) - util_selectai.set_profile( + utils_selectai.set_profile(db_conn, client_settings.selectai.profile, "temperature", model["temperature"]) + utils_selectai.set_profile( db_conn, client_settings.selectai.profile, "max_tokens", model["max_completion_tokens"] ) diff --git a/src/server/api/utils/embed.py b/src/server/api/utils/embed.py index 09f8315d..c84dded1 100644 --- a/src/server/api/utils/embed.py +++ b/src/server/api/utils/embed.py @@ -25,7 +25,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_text_splitters import HTMLHeaderTextSplitter, CharacterTextSplitter -import server.api.utils.databases as util_databases +import server.api.utils.databases as utils_databases import server.api.core.databases as core_databases import common.functions @@ -303,7 +303,7 @@ def json_to_doc(file: str): # Establish a dedicated connection to the database db_conn = core_databases.connect(db_details) # This is to allow re-using an existing VS; will merge this over later - util_databases.drop_vs(db_conn, vector_store_tmp.vector_store) + utils_databases.drop_vs(db_conn, 
vector_store_tmp.vector_store)
     logger.info("Establishing initial vector store")
     logger.debug("Embed Client: %s", embed_client)
     vs_tmp = OracleVS(
@@ -353,7 +353,7 @@ def json_to_doc(file: str):
     """
     logger.info("Merging %s into %s", vector_store_tmp.vector_store, vector_store.vector_store)
     core_databases.execute_sql(db_conn, merge_sql)
-    util_databases.drop_vs(db_conn, vector_store_tmp.vector_store)
+    utils_databases.drop_vs(db_conn, vector_store_tmp.vector_store)

     # Build the Index
     logger.info("Creating index on: %s", vector_store.vector_store)
diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py
index ad69c082..68cf4173 100644
--- a/src/server/api/utils/models.py
+++ b/src/server/api/utils/models.py
@@ -12,7 +12,7 @@
 from langchain_community.embeddings.oci_generative_ai import OCIGenAIEmbeddings
 from langchain_core.embeddings.embeddings import Embeddings

-import server.api.utils.oci as util_oci
+import server.api.utils.oci as utils_oci
 import server.api.core.models as core_models

 from common.functions import is_url_accessible
@@ -22,10 +22,10 @@
 logger = logging_config.logging.getLogger("api.utils.models")


-def update_model(model_id: schema.ModelIdType, payload: schema.Model) -> schema.Model:
+def update(payload: schema.Model) -> schema.Model:
     """Update an existing Model definition"""

-    model_upd = core_models.get_model(model_id=model_id)
+    model_upd = core_models.get_model(model_provider=payload.provider, model_id=payload.id)
     if payload.enabled and not is_url_accessible(model_upd.api_base)[0]:
         model_upd.enabled = False
         raise core_models.URLUnreachableError("Model: Unable to update. API URL is inaccessible.")
@@ -39,23 +39,18 @@ def update_model(model_id: schema.ModelIdType, payload: schema.Model) -> schema.
     return model_upd


-def create_genai_models(config: schema.OracleCloudSettings) -> list[schema.Model]:
+def create_genai(config: schema.OracleCloudSettings) -> list[schema.Model]:
     """Create and enable all GenAI models in the configured region"""
-    region_models = util_oci.get_genai_models(config, regional=True)
+    region_models = utils_oci.get_genai_models(config, regional=True)

     if region_models:
         # Delete previously configured GenAI Models
         all_models = core_models.get_model()
         for model in all_models:
             if model.provider == "oci":
-                core_models.delete_model(model.id)
+                core_models.delete_model(model_provider=model.provider, model_id=model.id)

     genai_models = []
     for model in region_models:
-        if model["vendor"] == "cohere":
-            # Note that we can enable this if the GenAI endpoint supports OpenAI compat
-            # https://docs.cohere.com/docs/compatibility-api
-            logger.info("Skipping %s; no support for OCI GenAI cohere models", model["model_name"])
-            continue
         model_dict = {}
         model_dict["provider"] = "oci"
         if "CHAT" in model["capabilities"]:
@@ -80,38 +75,40 @@ def create_genai_models(config: schema.OracleCloudSettings) -> list[schema.Model
     return genai_models


 def _get_full_config(model_config: dict, oci_config: schema.OracleCloudSettings = None) -> tuple[dict, str]:
     logger.debug("Model Client: %s; OCI Config: %s", model_config, oci_config)
-    try:
-        defined_model = core_models.get_model(
-            model_id=model_config["model"],
-            include_disabled=False,
-        ).model_dump()
-    except core_models.UnknownModelError:
-        return None
+    model_provider, model_id = model_config["model"].split("/", 1)
+    defined_model = core_models.get_model(
+        model_provider=model_provider,
+        model_id=model_id,
+        include_disabled=False,
+    ).model_dump()

     # 
Merge configurations, skipping None values full_model_config = {**defined_model, **{k: v for k, v in model_config.items() if v is not None}} provider = full_model_config.pop("provider") - provider = "openai" if provider == "openai_compatible" else provider return full_model_config, provider -def get_litellm_config(model_config: dict, oci_config: schema.OracleCloudSettings = None) -> dict: - """Establish client""" +def get_litellm_config( + model_config: dict, oci_config: schema.OracleCloudSettings = None, giskard: bool = False +) -> dict: + """Establish LiteLLM client""" full_model_config, provider = _get_full_config(model_config, oci_config) - model_name = f"{provider}/{full_model_config['id']}" # Get supported parameters and initialize config - supported_params = get_supported_openai_params(model=model_name) + supported_params = get_supported_openai_params(model=model_config["model"]) litellm_config = { k: full_model_config[k] for k in supported_params if k in full_model_config and full_model_config[k] is not None } - if "cohere" in model_name: + if "cohere" in model_config["model"]: # Ensure we use the OpenAI compatible endpoint parsed = urlparse(full_model_config.get("api_base")) scheme = parsed.scheme or "https" @@ -119,11 +116,13 @@ def get_litellm_config(model_config: dict, oci_config: schema.OracleCloudSetting # Always force the path path = "/compatibility/v1" full_model_config["api_base"] = f"{scheme}://{netloc}{path}" - if "xai" in model_name: + if "xai" in model_config["model"]: litellm_config.pop("presence_penalty", None) litellm_config.pop("frequency_penalty", None) - litellm_config.update({"model": model_name, "api_base": full_model_config.get("api_base"), "drop_params": True}) + litellm_config.update( + {"model": model_config["model"], "api_base": full_model_config.get("api_base"), "drop_params": True} + ) if provider == "oci": litellm_config.update( @@ -136,29 +135,35 @@ def get_litellm_config(model_config: dict, oci_config: schema.OracleCloudSetting "oci_compartment_id": oci_config.genai_compartment_id, } ) + + if giskard: + litellm_config.pop("model", None) + litellm_config.pop("temperature", None) + logger.debug("LiteLLM Config: %s", litellm_config) return litellm_config -def get_embed_client(model_config: dict, oci_config: schema.OracleCloudSettings) -> Embeddings: +def get_client_embed(model_config: dict, oci_config: schema.OracleCloudSettings) -> Embeddings: """Retrieve embedding model client""" full_model_config, provider = _get_full_config(model_config, oci_config) client = None - if provider != "oci": + if provider == "oci": + client = OCIGenAIEmbeddings( + model_id=full_model_config["id"], + client=utils_oci.init_genai_client(oci_config), + compartment_id=oci_config.genai_compartment_id, + ) + else: kwargs = { - "provider": "openai" if provider == "openai_compatible" else provider, + "provider": provider, "model": full_model_config["id"], "base_url": full_model_config.get("api_base"), } if full_model_config.get("api_key"): # only add if set kwargs["api_key"] = full_model_config["api_key"] client = init_embeddings(**kwargs) - else: - client = OCIGenAIEmbeddings( - model_id=full_model_config["id"], - client=util_oci.init_genai_client(oci_config), - compartment_id=oci_config.genai_compartment_id, - ) + return client diff --git a/src/server/api/utils/oci.py b/src/server/api/utils/oci.py index 6fe9ec47..d49d8d96 100644 --- a/src/server/api/utils/oci.py +++ b/src/server/api/utils/oci.py @@ -165,6 +165,8 @@ def get_genai_models(config: OracleCloudSettings, regional: bool 
= False) -> lis
 
     # Build our list of models
     for model in response.data.items:
+        if model.vendor == "cohere" and "TEXT_EMBEDDINGS" not in model.capabilities:
+            continue
         genai_models.append(
             {
                 "region": region["region_name"],
diff --git a/src/server/api/utils/testbed.py b/src/server/api/utils/testbed.py
index 539da87c..9e814244 100644
--- a/src/server/api/utils/testbed.py
+++ b/src/server/api/utils/testbed.py
@@ -19,6 +19,7 @@
 from giskard.rag.question_generators import simple_questions, complex_questions
 
 import server.api.core.databases as core_databases
+import server.api.utils.models as utils_models
 
 import common.schema as schema
 import common.logging_config as logging_config
@@ -235,40 +236,19 @@ def load_and_split(eval_file, chunk_size=2048):
 
 
 def build_knowledge_base(
-    text_nodes: str, questions: int, ll_model: schema.Model, embed_model: schema.Model
+    text_nodes: list, questions: int, ll_model: str, embed_model: str, oci_config: schema.OracleCloudSettings
 ) -> QATestset:
     """Establish a temporary Knowledge Base"""
-
-    def configure_and_set_model(client_model):
-        """Configure and set Model for TestSet Generation (uses litellm)"""
-        model_id, disable_structured_output, params = None, False, None
-        if client_model.provider == "openai_compatible":
-            model_id, params = (
-                f"openai/{client_model.id}",
-                {"api_base": client_model.url, "api_key": client_model.api_key or "api_compat"},
-            )
-        elif client_model.provider == "ollama":
-            model_id, disable_structured_output, params = (
-                f"ollama/{client_model.id}",
-                True,
-                {"api_base": client_model.url},
-            )
-        elif client_model.provider == "perplexity":
-            model_id, params = f"perplexity/{client_model.id}", {"api_key": client_model.api_key}
-        else:
-            model_id, params = f"openai/{client_model.id}", {"api_key": client_model.api_key}
-
-        if client_model.type == "ll":
-            logger.debug("KnowledgeBase LL: %s (%s)", model_id, params)
-            set_llm_model(model_id, disable_structured_output, **params)
-        else:
-            logger.debug("KnowledgeBase Embed: %s (%s)", model_id, params)
-            set_embedding_model(model_id, **params)
-
     logger.info("KnowledgeBase creation starting...")
     logger.info("LL Model: %s; Embedding: %s", ll_model, embed_model)
-    configure_and_set_model(ll_model)
-    configure_and_set_model(embed_model)
+
+    # Set up models (uses LiteLLM)
+    ll_model_config = utils_models.get_litellm_config(
+        model_config={"model": ll_model}, oci_config=oci_config, giskard=True
+    )
+    set_llm_model(llm_model=ll_model, **ll_model_config)
+    embed_model_config = utils_models.get_litellm_config(model_config={"model": embed_model}, giskard=True)
+    set_embedding_model(model=embed_model, **embed_model_config)
 
     knowledge_base_df = pd.DataFrame([node.text for node in text_nodes], columns=["text"])
     knowledge_base = KnowledgeBase(data=knowledge_base_df)
@@ -334,8 +314,8 @@ def clean(orig_html):
         "report": full_report.to_dict(),
         "correct_by_topic": by_topic.to_dict(),
         "failures": failures.to_dict(),
-        #"html_report": clean(html_report), #CDB
-        "html_report": ''
+        # "html_report": clean(html_report), #CDB
+        "html_report": "",
     }
     logger.debug("Evaluation Results: %s", evaluation_results)
     evaluation = schema.EvaluationReport(**evaluation_results)
diff --git a/src/server/api/v1/chat.py b/src/server/api/v1/chat.py
index 47bc5320..9fc6c246 100644
--- a/src/server/api/v1/chat.py
+++ b/src/server/api/v1/chat.py
@@ -4,7 +4,7 @@
 """
 # spell-checker:ignore selectai litellm
 
-from fastapi import APIRouter, Header
+from fastapi import APIRouter, Header, HTTPException
 from fastapi.responses import StreamingResponse
 
 from 
litellm import ModelResponse @@ -59,7 +59,6 @@ async def chat_stream( media_type="application/octet-stream", ) - @auth.patch( "/history", description="Delete Chat History", diff --git a/src/server/api/v1/embed.py b/src/server/api/v1/embed.py index e3a240ed..2398f9b0 100644 --- a/src/server/api/v1/embed.py +++ b/src/server/api/v1/embed.py @@ -17,9 +17,9 @@ import server.api.core.databases as core_databases import server.api.core.oci as core_oci -import server.api.utils.databases as util_databases -import server.api.utils.embed as util_embed -import server.api.utils.models as util_models +import server.api.utils.databases as utils_databases +import server.api.utils.embed as utils_embed +import server.api.utils.models as utils_models import common.functions as functions import common.schema as schema @@ -41,9 +41,9 @@ async def embed_drop_vs( """Drop Vector Storage""" logger.debug("Received %s embed_drop_vs: %s", client, vs) try: - client_db = util_databases.get_client_db(client) + client_db = utils_databases.get_client_db(client) db_conn = core_databases.connect(client_db) - util_databases.drop_vs(db_conn, vs) + utils_databases.drop_vs(db_conn, vs) except core_databases.DbException as ex: raise HTTPException(status_code=400, detail=f"Embed: {str(ex)}.") from ex return JSONResponse(status_code=200, content={"message": f"Vector Store: {vs} dropped."}) @@ -59,7 +59,7 @@ async def store_web_file( ) -> Response: """Store contents from a web URL""" logger.debug("Received store_web_file - request: %s", request) - temp_directory = util_embed.get_temp_directory(client, "embedding") + temp_directory = utils_embed.get_temp_directory(client, "embedding") # Save the file temporarily for url in request: @@ -96,7 +96,7 @@ async def store_local_file( ) -> Response: """Store contents from a local file uploaded to streamlit""" logger.debug("Received store_local_file - files: %s", files) - temp_directory = util_embed.get_temp_directory(client, "embedding") + temp_directory = utils_embed.get_temp_directory(client, "embedding") for file in files: filename = temp_directory / file.filename file_content = await file.read() @@ -119,7 +119,7 @@ async def split_embed( """Perform Split and Embed""" logger.debug("Received split_embed - rate_limit: %i; request: %s", rate_limit, request) oci_config = core_oci.get_oci(client=client) - temp_directory = util_embed.get_temp_directory(client, "embedding") + temp_directory = utils_embed.get_temp_directory(client, "embedding") try: files = [f for f in temp_directory.iterdir() if f.is_file()] @@ -135,7 +135,7 @@ async def split_embed( detail=f"Embed: Client {client} no files found in folder.", ) try: - split_docos, _ = util_embed.load_and_split_documents( + split_docos, _ = utils_embed.load_and_split_documents( files, request.model, request.chunk_size, @@ -144,14 +144,14 @@ async def split_embed( output_dir=None, ) - embed_client = util_models.get_client({"model": request.model, "enabled": True}, oci_config) + embed_client = utils_models.get_client_embed({"model": request.model, "enabled": True}, oci_config) # Calculate and set the vector_store name using get_vs_table request.vector_store, _ = functions.get_vs_table(**request.model_dump(exclude={"database", "vector_store"})) - util_embed.populate_vs( + utils_embed.populate_vs( vector_store=request, - db_details=util_databases.get_client_db(client), + db_details=utils_databases.get_client_db(client), embed_client=embed_client, input_data=split_docos, rate_limit=rate_limit, diff --git a/src/server/api/v1/models.py 
b/src/server/api/v1/models.py index fa6fd77f..9288fd84 100644 --- a/src/server/api/v1/models.py +++ b/src/server/api/v1/models.py @@ -9,7 +9,7 @@ from fastapi.responses import JSONResponse import server.api.core.models as core_models -import server.api.utils.models as util_models +import server.api.utils.models as utils_models import common.schema as schema import common.logging_config as logging_config @@ -47,18 +47,19 @@ async def models_list( @auth.get( - "/{model_id:path}", - description="Get a single model", + "/{model_provider}/{model_id:path}", + description="Get a single model (provider/name)", response_model=schema.Model, ) async def models_get( + model_provider: schema.ModelProviderType, model_id: schema.ModelIdType, ) -> schema.Model: """List a specific model""" - logger.debug("Received models_get - model_id: %s", model_id) + logger.debug("Received models_get - model: %s/%s", model_provider, model_id) try: - models_ret = core_models.get_model(model_id=model_id) + models_ret = core_models.get_model(model_provider=model_provider, model_id=model_id) except core_models.UnknownModelError as ex: raise HTTPException(status_code=404, detail=str(ex)) from ex @@ -66,18 +67,15 @@ async def models_get( @auth.patch( - "/{model_id:path}", + "/{model_provider}/{model_id:path}", description="Update a model", response_model=schema.Model, ) -async def models_update( - model_id: schema.ModelIdType, - payload: schema.Model, -) -> schema.Model: +async def models_update(payload: schema.Model) -> schema.Model: """Update a model""" - logger.debug("Received models_update - model_id: %s; payload: %s", model_id, payload) + logger.debug("Received models_update - payload: %s", payload) try: - return util_models.update_model(model_id=model_id, payload=payload) + return utils_models.update(payload=payload) except core_models.UnknownModelError as ex: raise HTTPException(status_code=404, detail=str(ex)) from ex except core_models.URLUnreachableError as ex: @@ -98,13 +96,14 @@ async def models_create( @auth.delete( - "/{model_id:path}", + "/{model_provider}/{model_id:path}", description="Delete a model", ) async def models_delete( + model_provider: schema.ModelProviderType, model_id: schema.ModelIdType, ) -> JSONResponse: """Delete a model""" - logger.debug("Received models_delete - model_id: %s", model_id) - core_models.delete_model(model_id) - return JSONResponse(status_code=200, content={"message": f"Model: {model_id} deleted."}) + logger.debug("Received models_delete - model: %s/%s", model_provider, model_id) + core_models.delete_model(model_provider=model_provider, model_id=model_id) + return JSONResponse(status_code=200, content={"message": f"Model: {model_provider}/{model_id} deleted."}) diff --git a/src/server/api/v1/oci.py b/src/server/api/v1/oci.py index 41b7e706..e7f8d91a 100644 --- a/src/server/api/v1/oci.py +++ b/src/server/api/v1/oci.py @@ -8,9 +8,9 @@ from fastapi.responses import JSONResponse import server.api.core.oci as core_oci -import server.api.utils.embed as util_embed -import server.api.utils.oci as util_oci -import server.api.utils.models as util_models +import server.api.utils.embed as utils_embed +import server.api.utils.oci as utils_oci +import server.api.utils.models as utils_models import common.schema as schema import common.logging_config as logging_config @@ -62,7 +62,7 @@ async def oci_list_regions( logger.debug("Received oci_list_regions - auth_profile: %s", auth_profile) try: oci_config = await oci_get(auth_profile=auth_profile) - regions = util_oci.get_regions(oci_config) + 
regions = utils_oci.get_regions(oci_config) return regions except core_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex @@ -80,7 +80,7 @@ async def oci_list_genai( logger.debug("Received oci_list_regions - auth_profile: %s", auth_profile) try: oci_config = await oci_get(auth_profile=auth_profile) - all_models = util_oci.get_genai_models(oci_config, regional=False) + all_models = utils_oci.get_genai_models(oci_config, regional=False) return all_models except core_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex @@ -98,7 +98,7 @@ async def oci_list_compartments( logger.debug("Received oci_list_compartments - auth_profile: %s", auth_profile) try: oci_config = await oci_get(auth_profile=auth_profile) - compartments = util_oci.get_compartments(oci_config) + compartments = utils_oci.get_compartments(oci_config) return compartments except core_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex @@ -118,7 +118,7 @@ async def oci_list_buckets( try: compartment_obj = schema.OracleResource(ocid=compartment_ocid) oci_config = await oci_get(auth_profile=auth_profile) - buckets = util_oci.get_buckets(compartment_obj.ocid, oci_config) + buckets = utils_oci.get_buckets(compartment_obj.ocid, oci_config) return buckets except core_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex @@ -137,7 +137,7 @@ async def oci_list_bucket_objects( logger.debug("Received oci_list_bucket_objects - auth_profile: %s; bucket_name: %s", auth_profile, bucket_name) try: oci_config = await oci_get(auth_profile=auth_profile) - objects = util_oci.get_bucket_objects(bucket_name, oci_config) + objects = utils_oci.get_bucket_objects(bucket_name, oci_config) return objects except core_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex @@ -158,7 +158,7 @@ async def oci_profile_update( oci_config = await oci_get(auth_profile=auth_profile) try: - namespace = util_oci.get_namespace(payload) + namespace = utils_oci.get_namespace(payload) oci_config.namespace = namespace for key, value in payload.model_dump().items(): if value not in ("", None): @@ -192,9 +192,9 @@ async def oci_download_objects( ) oci_config = await oci_get(auth_profile=auth_profile) # Files should be placed in the embedding folder - temp_directory = util_embed.get_temp_directory(client, "embedding") + temp_directory = utils_embed.get_temp_directory(client, "embedding") for object_name in request: - util_oci.get_object(temp_directory, object_name, bucket_name, oci_config) + utils_oci.get_object(temp_directory, object_name, bucket_name, oci_config) downloaded_files = [f.name for f in temp_directory.iterdir() if f.is_file()] return JSONResponse(status_code=200, content=downloaded_files) @@ -212,7 +212,7 @@ async def oci_create_genai_models( logger.debug("Received oci_create_genai_models - auth_profile: %s", auth_profile) try: oci_config = await oci_get(auth_profile=auth_profile) - enabled_models = util_models.create_genai_models(oci_config) + enabled_models = utils_models.create_genai(oci_config) return enabled_models except core_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex diff --git a/src/server/api/v1/selectai.py b/src/server/api/v1/selectai.py index fe3c06d6..b06b4432 100644 --- a/src/server/api/v1/selectai.py +++ 
b/src/server/api/v1/selectai.py @@ -9,8 +9,8 @@ from fastapi import APIRouter, Header import server.api.core.settings as core_settings -import server.api.utils.databases as util_databases -import server.api.utils.selectai as util_selectai +import server.api.utils.databases as utils_databases +import server.api.utils.selectai as utils_selectai import common.schema as schema import common.logging_config as logging_config @@ -30,8 +30,8 @@ async def selectai_get_objects( ) -> list[schema.DatabaseSelectAIObjects]: """Get DatabaseSelectAIObjects""" client_settings = core_settings.get_client_settings(client) - db_conn = util_databases.get_client_db(client).connection - select_ai_objects = util_selectai.get_objects(db_conn, client_settings.selectai.profile) + db_conn = utils_databases.get_client_db(client).connection + select_ai_objects = utils_selectai.get_objects(db_conn, client_settings.selectai.profile) return select_ai_objects @@ -48,6 +48,6 @@ async def selectai_update_objects( logger.debug("Received selectai_update - payload: %s", payload) client_settings = core_settings.get_client_settings(client) object_list = json.dumps([obj.model_dump(include={"owner", "name"}) for obj in payload]) - db_conn = util_databases.get_client_db(client).connection - util_selectai.set_profile(db_conn, client_settings.selectai.profile, "object_list", object_list) - return util_selectai.get_objects(db_conn, client_settings.selectai.profile) + db_conn = utils_databases.get_client_db(client).connection + utils_selectai.set_profile(db_conn, client_settings.selectai.profile, "object_list", object_list) + return utils_selectai.get_objects(db_conn, client_settings.selectai.profile) diff --git a/src/server/api/v1/testbed.py b/src/server/api/v1/testbed.py index 0c8dcba6..78ab53ef 100644 --- a/src/server/api/v1/testbed.py +++ b/src/server/api/v1/testbed.py @@ -12,18 +12,18 @@ import json from typing import Optional from giskard.rag import evaluate, QATestset +from giskard.llm import set_llm_model from fastapi import APIRouter, HTTPException, Header, UploadFile from fastapi.responses import JSONResponse import litellm from langchain_core.messages import ChatMessage -import server.api.core.models as core_models import server.api.core.settings as core_settings import server.api.core.oci as core_oci -import server.api.utils.embed as util_embed -import server.api.utils.testbed as util_testbed -import server.api.utils.databases as util_databases -import server.api.utils.models as util_models +import server.api.utils.embed as utils_embed +import server.api.utils.testbed as utils_testbed +import server.api.utils.databases as utils_databases +import server.api.utils.models as utils_models from server.api.v1 import chat @@ -44,7 +44,7 @@ async def testbed_testsets( client: schema.ClientIdType = Header(default="server"), ) -> list[schema.TestSets]: """Get a list of stored TestSets, create TestSet objects if they don't exist""" - testsets = util_testbed.get_testsets(db_conn=util_databases.get_client_db(client).connection) + testsets = utils_testbed.get_testsets(db_conn=utils_databases.get_client_db(client).connection) return testsets @@ -58,8 +58,8 @@ async def testbed_evaluations( client: schema.ClientIdType = Header(default="server"), ) -> list[schema.Evaluation]: """Get Evaluations""" - evaluations = util_testbed.get_evaluations( - db_conn=util_databases.get_client_db(client).connection, tid=tid.upper() + evaluations = utils_testbed.get_evaluations( + db_conn=utils_databases.get_client_db(client).connection, tid=tid.upper() ) 
return evaluations @@ -74,7 +74,9 @@ async def testbed_evaluation( client: schema.ClientIdType = Header(default="server"), ) -> schema.EvaluationReport: """Get Evaluations""" - evaluation = util_testbed.process_report(db_conn=util_databases.get_client_db(client).connection, eid=eid.upper()) + evaluation = utils_testbed.process_report( + db_conn=utils_databases.get_client_db(client).connection, eid=eid.upper() + ) return evaluation @@ -88,7 +90,7 @@ async def testbed_testset_qa( client: schema.ClientIdType = Header(default="server"), ) -> schema.TestSetQA: """Get TestSet Q&A""" - return util_testbed.get_testset_qa(db_conn=util_databases.get_client_db(client).connection, tid=tid.upper()) + return utils_testbed.get_testset_qa(db_conn=utils_databases.get_client_db(client).connection, tid=tid.upper()) @auth.delete( @@ -100,7 +102,7 @@ async def testbed_delete_testset( client: schema.ClientIdType = Header(default="server"), ) -> JSONResponse: """Delete TestSet""" - util_testbed.delete_qa(util_databases.get_client_db(client).connection, tid.upper()) + utils_testbed.delete_qa(utils_databases.get_client_db(client).connection, tid.upper()) return JSONResponse(status_code=200, content={"message": f"TestSet: {tid} deleted."}) @@ -117,12 +119,12 @@ async def testbed_upsert_testsets( ) -> schema.TestSetQA: """Update stored TestSet data""" created = datetime.now().isoformat() - db_conn = util_databases.get_client_db(client).connection + db_conn = utils_databases.get_client_db(client).connection try: for file in files: file_content = await file.read() - content = util_testbed.jsonl_to_json_content(file_content) - db_id = util_testbed.upsert_qa(db_conn, name, created, content, tid) + content = utils_testbed.jsonl_to_json_content(file_content) + db_id = utils_testbed.upsert_qa(db_conn, name, created, content, tid) db_conn.commit() except Exception as ex: logger.error("An exception occurred: %s", ex) @@ -140,16 +142,19 @@ async def testbed_upsert_testsets( async def testbed_generate_qa( files: list[UploadFile], name: schema.TestSetsNameType, - ll_model: schema.ModelIdType = None, - embed_model: schema.ModelIdType = None, + ll_model: str, + embed_model: str, questions: int = 2, client: schema.ClientIdType = Header(default="server"), ) -> schema.TestSetQA: """Retrieve contents from a local file uploaded and generate Q&A""" - # Setup Models - giskard_ll_model = core_models.get_model(model_id=ll_model, model_type="ll") - giskard_embed_model = core_models.get_model(model_id=embed_model, model_type="embed") - temp_directory = util_embed.get_temp_directory(client, "testbed") + # Get the Model Configuration + try: + oci_config = core_oci.get_oci(client) + except ValueError as ex: + raise HTTPException(status_code=400, detail=str(ex)) from ex + + temp_directory = utils_embed.get_temp_directory(client, "testbed") full_testsets = temp_directory / "all_testsets.jsonl" for file in files: @@ -162,8 +167,8 @@ async def testbed_generate_qa( file.write(file_content) # Process file for knowledge base - text_nodes = util_testbed.load_and_split(filename) - test_set = util_testbed.build_knowledge_base(text_nodes, questions, giskard_ll_model, giskard_embed_model) + text_nodes = utils_testbed.load_and_split(filename) + test_set = utils_testbed.build_knowledge_base(text_nodes, questions, ll_model, embed_model, oci_config) # Save test set test_set_filename = temp_directory / f"{name}.jsonl" test_set.save(test_set_filename) @@ -195,9 +200,9 @@ async def testbed_generate_qa( description="Evaluate Q&A Test Set.", 
response_model=schema.EvaluationReport, ) -def testbed_evaluate_qa( +def testbed_evaluate( tid: schema.TestSetsIdType, - judge: schema.ModelIdType, + judge: str, client: schema.ClientIdType = Header(default="server"), ) -> schema.EvaluationReport: """Run evaluate against a testset""" @@ -208,7 +213,7 @@ def get_answer(question: str): messages=[ChatMessage(role="human", content=question)], ) ai_response = asyncio.run(chat.chat_post(client=client, request=request)) - return ai_response.choices[0].message.content + return ai_response["choices"][0]["message"]["content"] evaluated = datetime.now().isoformat() client_settings = core_settings.get_client_settings(client) @@ -217,10 +222,10 @@ def get_answer(question: str): # Change Grade vector_search client_settings.vector_search.grading = False - db_conn = util_databases.get_client_db(client).connection - testset = util_testbed.get_testset_qa(db_conn=db_conn, tid=tid.upper()) + db_conn = utils_databases.get_client_db(client).connection + testset = utils_testbed.get_testset_qa(db_conn=db_conn, tid=tid.upper()) qa_test = "\n".join(json.dumps(item) for item in testset.qa_data) - temp_directory = util_embed.get_temp_directory(client, "testbed") + temp_directory = utils_embed.get_temp_directory(client, "testbed") with open(temp_directory / f"{tid}_output.txt", "w", encoding="utf-8") as file: file.write(qa_test) @@ -229,10 +234,12 @@ def get_answer(question: str): # Setup Judge Model logger.debug("Starting evaluation with Judge: %s", judge) oci_config = core_oci.get_oci(client) - judge_client = util_models.get_client({"model": judge}, oci_config, True) + + judge_config = utils_models.get_litellm_config(model_config={"model": judge}, oci_config=oci_config, giskard=True) + set_llm_model(llm_model=judge, **judge_config) try: # report = evaluate(get_answer, testset=loaded_testset, llm_client=judge_client, metrics=[correctness_metric]) #CDB - report = evaluate(get_answer, testset=loaded_testset, llm_client=judge_client, metrics=None) # CDB + report = evaluate(get_answer, testset=loaded_testset, metrics=None) # CDB except KeyError as ex: if str(ex) == "'correctness'": @@ -240,7 +247,7 @@ def get_answer(question: str): logger.debug("Ending evaluation with Judge: %s", judge) - eid = util_testbed.insert_evaluation( + eid = utils_testbed.insert_evaluation( db_conn=db_conn, tid=tid, evaluated=evaluated, @@ -251,4 +258,4 @@ def get_answer(question: str): db_conn.commit() shutil.rmtree(temp_directory) - return util_testbed.process_report(db_conn=db_conn, eid=eid) + return utils_testbed.process_report(db_conn=db_conn, eid=eid) diff --git a/src/server/bootstrap/models.py b/src/server/bootstrap/models.py index 5ed65389..76b65463 100644 --- a/src/server/bootstrap/models.py +++ b/src/server/bootstrap/models.py @@ -77,7 +77,7 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "id": "phi-4", "enabled": False, "type": "ll", - "provider": "openai_compatible", + "provider": "huggingface", "api_key": "", "api_base": "http://localhost:1234/v1", "context_length": 131072, @@ -123,7 +123,7 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "id": "text-embedding-3-small", "enabled": os.getenv("OPENAI_API_KEY") is not None, "type": "embed", - "provider": "openai_compatible", + "provider": "openai", "api_base": "https://api.openai.com/v1", "api_key": os.environ.get("OPENAI_API_KEY", default=""), "max_chunk_size": 8191, @@ -141,7 +141,7 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): 
"id": "text-embedding-nomic-embed-text-v1.5", "enabled": False, "type": "embed", - "provider": "openai_compatible", + "provider": "huggingface", "api_base": "http://localhost:1234/v1", "api_key": "", "max_chunk_size": 8192, @@ -161,16 +161,19 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): # Check for duplicates unique_entries = set() for model in models_list: - if model["id"] in unique_entries: - raise ValueError(f"Model '{model['id']}' already exists.") - unique_entries.add(model["id"]) + key = (model["provider"], model["id"]) + if key in unique_entries: + raise ValueError(f"Model '{model['provider']}/{model['id']}' already exists.") + unique_entries.add(key) # Merge with configuration if available configuration = ConfigStore.get() if configuration and configuration.model_configs: logger.debug("Merging model configs from ConfigStore") - config_model_map = {m.id: m.model_dump() for m in configuration.model_configs} - existing = {m["id"]: m for m in models_list} + + # Use (provider, id) tuple as key + config_model_map = {(m.provider, m.id): m.model_dump() for m in configuration.model_configs} + existing = {(m["provider"], m["id"]): m for m in models_list} def values_differ(a, b): if isinstance(a, bool) or isinstance(b, bool): @@ -181,24 +184,25 @@ def values_differ(a, b): return a.strip() != b.strip() return a != b - for model_id, override in config_model_map.items(): - if model_id in existing: + for key, override in config_model_map.items(): + if key in existing: for k, v in override.items(): - if k not in existing[model_id]: + if k not in existing[key]: continue - if values_differ(existing[model_id][k], v): + if values_differ(existing[key][k], v): log_func = logger.debug if k == "api_key" else logger.info log_func( - "Overriding field '%s' for model '%s' (was: %r → now: %r)", + "Overriding field '%s' for model '%s/%s' (was: %r → now: %r)", k, - model_id, - existing[model_id][k], + key[0], # provider + key[1], # id + existing[key][k], v, ) - existing[model_id][k] = v + existing[key][k] = v else: - logger.info("Adding new model from ConfigStore: %s", model_id) - existing[model_id] = override + logger.info("Adding new model from ConfigStore: %s/%s", key[0], key[1]) + existing[key] = override models_list = list(existing.values()) diff --git a/src/server/bootstrap/oci.py b/src/server/bootstrap/oci.py index 2e5dfc3d..522f30c9 100644 --- a/src/server/bootstrap/oci.py +++ b/src/server/bootstrap/oci.py @@ -9,8 +9,8 @@ import oci from server.bootstrap.configfile import ConfigStore -import server.api.utils.oci as util_oci -import server.api.utils.models as util_models +import server.api.utils.oci as utils_oci +import server.api.utils.models as utils_models import common.logging_config as logging_config from common.schema import OracleCloudSettings @@ -112,7 +112,7 @@ def override(profile: dict, key: str, env_key: str, env: dict, overrides: dict, if oci_config.auth_profile == oci.config.DEFAULT_PROFILE: try: - oci_config.namespace = util_oci.get_namespace(oci_config) + oci_config.namespace = utils_oci.get_namespace(oci_config) except Exception: logger.warning("Failed to get namespace for DEFAULT OCI profile") continue @@ -121,7 +121,7 @@ def override(profile: dict, key: str, env_key: str, env: dict, overrides: dict, try: oci_config = [o for o in oci_objects if o.auth_profile == "DEFAULT"] if oci_config: - util_models.create_genai_models(oci_config[0]) + utils_models.create_genai(oci_config[0]) except Exception as ex: logger.info("Unable to bootstrap OCI GenAI 
Models: %s", str(ex)) diff --git a/tests/integration/server/test_endpoints_embed.py b/tests/integration/server/test_endpoints_embed.py index 5ec59a7b..202d975f 100644 --- a/tests/integration/server/test_endpoints_embed.py +++ b/tests/integration/server/test_endpoints_embed.py @@ -109,13 +109,13 @@ def embed_strings(self, texts): return self.embed_documents(texts) def setup_mock_embeddings(self, mock_embedding_model): - """Create mock embeddings and get_client function""" + """Create mock embeddings and get_client_embed function""" mock_embeddings = self.MockEmbeddings(mock_embedding_model) - def mock_get_client(model_config=None, oci_config=None, giskard=False): + def mock_get_client_embed(model_config=None, oci_config=None, giskard=False): return mock_embeddings - return mock_get_client + return mock_get_client_embed def create_embed_params(self, alias): """Create embedding parameters with the given alias""" @@ -364,12 +364,12 @@ def test_split_embed_with_different_file_types(self, client, auth_headers, db_co ) # Setup mock embeddings - mock_get_client = self.setup_mock_embeddings(mock_embedding_model) + mock_get_client_embed = self.setup_mock_embeddings(mock_embedding_model) # Test data test_data = self.create_embed_params("test_mixed_files") - with patch("server.api.utils.models.get_client", side_effect=mock_get_client): + with patch("server.api.utils.models.get_client_embed", side_effect=mock_get_client_embed): # Make request to the split_embed endpoint response = client.post("/v1/embed", headers=auth_headers["valid_auth"], json=test_data) @@ -392,7 +392,7 @@ def test_vector_store_creation_and_deletion(self, client, auth_headers, db_conta self.create_test_file() # Setup mock embeddings - mock_get_client = self.setup_mock_embeddings(mock_embedding_model) + mock_get_client_embed = self.setup_mock_embeddings(mock_embedding_model) # Test data for embedding alias = "test_lifecycle" @@ -401,7 +401,7 @@ def test_vector_store_creation_and_deletion(self, client, auth_headers, db_conta # Calculate the expected vector store name expected_vector_store_name = self.get_vector_store_name(alias) - with patch("server.api.utils.models.get_client", side_effect=mock_get_client): + with patch("server.api.utils.models.get_client_embed", side_effect=mock_get_client_embed): # Step 1: Create the vector store by embedding documents response = client.post("/v1/embed", headers=auth_headers["valid_auth"], json=test_data) assert response.status_code == 200 @@ -428,12 +428,12 @@ def test_multiple_vector_stores(self, client, auth_headers, db_container, mock_e aliases = ["test_vs_1", "test_vs_2", "test_vs_3"] # Setup mock embeddings - mock_get_client = self.setup_mock_embeddings(mock_embedding_model) + mock_get_client_embed = self.setup_mock_embeddings(mock_embedding_model) # Calculate expected vector store names expected_vector_store_names = [self.get_vector_store_name(alias) for alias in aliases] - with patch("server.api.utils.models.get_client", side_effect=mock_get_client): + with patch("server.api.utils.models.get_client_embed", side_effect=mock_get_client_embed): # Create multiple vector stores with different aliases for alias in aliases: # Create a test file for each request (since previous ones were cleaned up) diff --git a/tests/integration/server/test_endpoints_models.py b/tests/integration/server/test_endpoints_models.py index 809c9b6d..ece97647 100644 --- a/tests/integration/server/test_endpoints_models.py +++ b/tests/integration/server/test_endpoints_models.py @@ -102,13 +102,12 @@ def 
test_models_add_dupl(self, client, auth_headers): test_cases = [ pytest.param( { - "id": "valid_ll_model", + "id": "gpt-3.5-turbo", "enabled": True, "type": "ll", "provider": "openai", "api_key": "test-key", "api_base": "https://api.openai.com/v1", - "openai_compat": True, "context_length": 127072, "temperature": 1.0, "max_completion_tokens": 4096, @@ -135,7 +134,6 @@ def test_models_add_dupl(self, client, auth_headers): "provider": "huggingface", "api_base": "http://127.0.0.1:8080", "api_key": "", - "openai_compat": True, "max_chunk_size": 512, }, 201, @@ -150,7 +148,6 @@ def test_models_add_dupl(self, client, auth_headers): "provider": "huggingface", "api_base": "http://127.0.0.1:112233", "api_key": "", - "openai_compat": True, "max_chunk_size": 512, }, 201, @@ -168,6 +165,7 @@ def test_model_create(self, client, auth_headers, payload, add_status_code, _, r if request.node.callspec.id == "unreachable_api_base_model": assert response.json()["enabled"] is False else: + print(response.json()) assert all(item in response.json().items() for item in payload.items()) # Model was added, should get 200 back response = client.get(f"/v1/models/{payload['id']}", headers=auth_headers["valid_auth"]) diff --git a/tests/unit/server/api/utils/models.py b/tests/unit/server/api/utils/models.py deleted file mode 100644 index e69de29b..00000000 From 9c56976d6deb197ec9f3a114e8a9db0ee457fc89 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Mon, 1 Sep 2025 00:39:33 +0100 Subject: [PATCH 14/31] Linting --- .pylintrc | 2 +- src/client/content/api_server.py | 5 ++- src/client/content/chatbot.py | 9 ++---- src/client/content/config/tabs/databases.py | 6 ++-- src/client/content/config/tabs/models.py | 32 +++++++------------ src/client/content/config/tabs/oci.py | 6 ++-- src/client/content/config/tabs/settings.py | 10 +++--- src/client/content/testbed.py | 5 ++- src/client/content/tools/tabs/prompt_eng.py | 6 ++-- src/client/content/tools/tabs/split_embed.py | 7 ++-- src/client/utils/api_call.py | 4 +-- src/client/utils/client.py | 2 +- src/client/utils/st_common.py | 10 +++--- src/common/functions.py | 2 +- src/common/schema.py | 9 +++--- src/launch_client.py | 2 +- src/launch_server.py | 4 +-- src/server/agents/chatbot.py | 4 +-- src/server/agents/tools/oraclevs_retriever.py | 2 +- src/server/agents/tools/selectai.py | 2 +- src/server/api/core/bootstrap.py | 2 +- src/server/api/core/databases.py | 5 ++- src/server/api/core/models.py | 2 +- src/server/api/core/oci.py | 2 +- src/server/api/core/prompts.py | 2 +- src/server/api/core/settings.py | 2 +- src/server/api/utils/chat.py | 4 +-- src/server/api/utils/databases.py | 4 +-- src/server/api/utils/embed.py | 11 +++---- src/server/api/utils/models.py | 4 +-- src/server/api/utils/oci.py | 2 +- src/server/api/utils/selectai.py | 2 +- src/server/api/utils/testbed.py | 4 +-- src/server/api/v1/chat.py | 9 +++--- src/server/api/v1/databases.py | 4 +-- src/server/api/v1/embed.py | 4 +-- src/server/api/v1/models.py | 3 +- src/server/api/v1/oci.py | 4 +-- src/server/api/v1/prompts.py | 4 +-- src/server/api/v1/selectai.py | 3 +- src/server/api/v1/settings.py | 4 +-- src/server/api/v1/testbed.py | 8 ++--- src/server/bootstrap/configfile.py | 2 +- src/server/bootstrap/databases.py | 2 +- src/server/bootstrap/models.py | 4 +-- src/server/bootstrap/oci.py | 2 +- src/server/bootstrap/prompts.py | 2 +- src/server/bootstrap/settings.py | 2 +- src/server/patches/litellm_patch.py | 2 +- src/server/wip/settings.py | 2 +- 50 files changed, 99 insertions(+), 137 deletions(-) diff --git 
a/.pylintrc b/.pylintrc index a0a638d1..bc4ecdb4 100644 --- a/.pylintrc +++ b/.pylintrc @@ -52,7 +52,7 @@ ignore=CVS,.venv # ignore-list. The regex matches against paths and can be in Posix or Windows # format. Because '\\' represents the directory delimiter on Windows systems, # it can't be used as an escape character. -ignore-paths= +ignore-paths=.*[/\\]wip[/\\].*,src/client/mcp # Files or directories matching the regular expression patterns are skipped. # The regex matches against base names, not paths. The default value ignores diff --git a/src/client/content/api_server.py b/src/client/content/api_server.py index 8b7d720c..de48c91e 100644 --- a/src/client/content/api_server.py +++ b/src/client/content/api_server.py @@ -14,9 +14,8 @@ import streamlit as st from streamlit import session_state as state -import client.utils.client as client -import client.utils.api_call as api_call -import common.logging_config as logging_config +from client.utils import client, api_call +from common import logging_config logger = logging_config.logging.getLogger("client.content.api_server") diff --git a/src/client/content/chatbot.py b/src/client/content/chatbot.py index d495af2d..d1f7f6b1 100644 --- a/src/client/content/chatbot.py +++ b/src/client/content/chatbot.py @@ -11,19 +11,14 @@ import inspect import json import base64 -from httpx import RemoteProtocolError import streamlit as st from streamlit import session_state as state from client.content.config.tabs.models import get_models - -import client.utils.st_common as st_common -import client.utils.api_call as api_call - +from client.utils import st_common, api_call, client from client.utils.st_footer import render_chat_footer -import client.utils.client as client -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("client.content.chatbot") diff --git a/src/client/content/config/tabs/databases.py b/src/client/content/config/tabs/databases.py index cdff157a..06b1bd79 100644 --- a/src/client/content/config/tabs/databases.py +++ b/src/client/content/config/tabs/databases.py @@ -13,10 +13,8 @@ import streamlit as st from streamlit import session_state as state -import client.utils.api_call as api_call -import client.utils.st_common as st_common - -import common.logging_config as logging_config +from client.utils import api_call, st_common +from common import logging_config logger = logging_config.logging.getLogger("client.content.config.tabs.database") diff --git a/src/client/content/config/tabs/models.py b/src/client/content/config/tabs/models.py index 643efa4f..54faf956 100644 --- a/src/client/content/config/tabs/models.py +++ b/src/client/content/config/tabs/models.py @@ -16,11 +16,8 @@ import streamlit as st from streamlit import session_state as state -import client.utils.api_call as api_call -import client.utils.st_common as st_common - -import common.help_text as help_text -import common.logging_config as logging_config +from client.utils import api_call, st_common +from common import logging_config, help_text logger = logging_config.logging.getLogger("client.content.config.tabs.models") @@ -198,14 +195,14 @@ def edit_model( def render_model_rows(model_type: str) -> None: """Render rows of the models""" - data_col_widths = [0.07, 0.23, 0.2, 0.28, 0.12] + data_col_widths = [0.08, 0.42, 0.28, 0.12] table_col_format = st.columns(data_col_widths, vertical_alignment="center") - col1, col2, col3, col4, col5 = table_col_format + col1, col2, col3, col4 = table_col_format 
col1.markdown("​", help="Active", unsafe_allow_html=True) - col2.markdown("**Model ID**", unsafe_allow_html=True) - col3.markdown("**Provider**", unsafe_allow_html=True) - col4.markdown("**Provider URL**", unsafe_allow_html=True) - col5.markdown("​") + col2.markdown("**Model**", unsafe_allow_html=True) + col3.markdown("**Provider URL**", unsafe_allow_html=True) + col4.markdown("​") + st.write(state.model_configs) for model in [m for m in state.model_configs if m.get("type") == model_type]: model_id = model["id"] model_provider = model["provider"] @@ -218,29 +215,22 @@ def render_model_rows(model_type: str) -> None: ) col2.text_input( "Model", - value=model_id, + value=f"{model_provider}/{model_id}", label_visibility="collapsed", disabled=True, ) col3.text_input( - "Provider", - value=model["provider"], - key=f"{model_type}_{model_id}_provider", - label_visibility="collapsed", - disabled=True, - ) - col4.text_input( "Server", value=model["api_base"], key=f"{model_type}_{model_id}_server", label_visibility="collapsed", disabled=True, ) - col5.button( + col4.button( "Edit", on_click=edit_model, key=f"{model_type}_{model_id}_edit", - kwargs=dict(model_type=model_type, action="edit", model_id=model_id, model_provider=model_provider), + kwargs={"model_type": model_type, "action": "edit", "model_id": model_id, "model_provider": model_provider}, ) if st.button(label="Add", type="primary", key=f"add_{model_type}_model"): diff --git a/src/client/content/config/tabs/oci.py b/src/client/content/config/tabs/oci.py index ed2203ac..8ac8c177 100644 --- a/src/client/content/config/tabs/oci.py +++ b/src/client/content/config/tabs/oci.py @@ -12,10 +12,8 @@ import streamlit as st from streamlit import session_state as state -import client.utils.api_call as api_call -import client.utils.st_common as st_common - -import common.logging_config as logging_config +from client.utils import api_call, st_common +from common import logging_config logger = logging_config.logging.getLogger("client.content.config.tabs.oci") diff --git a/src/client/content/config/tabs/settings.py b/src/client/content/config/tabs/settings.py index cc635074..e38303ea 100644 --- a/src/client/content/config/tabs/settings.py +++ b/src/client/content/config/tabs/settings.py @@ -23,10 +23,9 @@ from streamlit import session_state as state # Utilities -import client.utils.api_call as api_call -import client.utils.st_common as st_common +from client.utils import api_call, st_common -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("client.content.config.tabs.settings") @@ -60,8 +59,7 @@ def get_settings(include_sensitive: bool = False): }, ) return settings - else: - raise + raise def save_settings(settings): @@ -259,7 +257,7 @@ def langchain_mcp_zip(settings): data = save_settings(settings) settings_path = os.path.join(dst_dir, "optimizer_settings.json") - with open(settings_path, "w") as f: + with open(settings_path, "w", encoding="utf-8") as f: f.write(data) zip_buffer = io.BytesIO() diff --git a/src/client/content/testbed.py b/src/client/content/testbed.py index 728124aa..b8b0cc28 100644 --- a/src/client/content/testbed.py +++ b/src/client/content/testbed.py @@ -17,10 +17,9 @@ from client.content.config.tabs.models import get_models -import client.utils.st_common as st_common -import client.utils.api_call as api_call +from client.utils import st_common, api_call -import common.logging_config as logging_config +from common import logging_config logger = 
logging_config.logging.getLogger("client.content.testbed") diff --git a/src/client/content/tools/tabs/prompt_eng.py b/src/client/content/tools/tabs/prompt_eng.py index 67cb9afb..866cc3d1 100644 --- a/src/client/content/tools/tabs/prompt_eng.py +++ b/src/client/content/tools/tabs/prompt_eng.py @@ -12,10 +12,8 @@ import streamlit as st from streamlit import session_state as state -import client.utils.st_common as st_common -import client.utils.api_call as api_call - -import common.logging_config as logging_config +from client.utils import st_common, api_call +from common import logging_config logger = logging_config.logging.getLogger("client.tools.tabs.prompt_eng") diff --git a/src/client/content/tools/tabs/split_embed.py b/src/client/content/tools/tabs/split_embed.py index 4d647144..07a1640b 100644 --- a/src/client/content/tools/tabs/split_embed.py +++ b/src/client/content/tools/tabs/split_embed.py @@ -14,17 +14,14 @@ import streamlit as st from streamlit import session_state as state -import client.utils.api_call as api_call -import client.utils.st_common as st_common +from client.utils import api_call, st_common from client.content.config.tabs.databases import get_databases from client.content.config.tabs.models import get_models from client.content.config.tabs.oci import get_oci from common.schema import DistanceMetrics, IndexTypes, DatabaseVectorStorage -import common.functions as functions -import common.help_text as help_text -import common.logging_config as logging_config +from common import logging_config, help_text, functions logger = logging_config.logging.getLogger("client.tools.tabs.split_embed") diff --git a/src/client/utils/api_call.py b/src/client/utils/api_call.py index c6a995b9..9255b549 100644 --- a/src/client/utils/api_call.py +++ b/src/client/utils/api_call.py @@ -11,7 +11,7 @@ import streamlit as st from streamlit import session_state as state -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("client.utils.api_call") @@ -37,7 +37,7 @@ def sanitize_sensitive_data(data): else sanitize_sensitive_data(v) for k, v in data.items() } - elif isinstance(data, list): + if isinstance(data, list): return [sanitize_sensitive_data(i) for i in data] return data diff --git a/src/client/utils/client.py b/src/client/utils/client.py index 68947e6b..e2b15934 100644 --- a/src/client/utils/client.py +++ b/src/client/utils/client.py @@ -9,7 +9,7 @@ from langchain_core.messages import ChatMessage from common.schema import ChatRequest -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("client.utils.client") diff --git a/src/client/utils/st_common.py b/src/client/utils/st_common.py index 3a9e8db1..675248ef 100644 --- a/src/client/utils/st_common.py +++ b/src/client/utils/st_common.py @@ -11,11 +11,9 @@ import streamlit as st from streamlit import session_state as state -import client.utils.api_call as api_call - -import common.help_text as help_text -import common.logging_config as logging_config -from common.schema import PromptPromptType, PromptNameType, SelectAISettings, ClientIdType +from client.utils import api_call +from common import logging_config, help_text +from common.schema import PromptPromptType, PromptNameType, SelectAISettings logger = logging_config.logging.getLogger("client.utils.st_common") @@ -76,7 +74,7 @@ def local_file_payload(uploaded_files: Union[BytesIO, list[BytesIO]]) -> list: def switch_prompt(prompt_type: 
PromptPromptType, prompt_name: PromptNameType) -> None: """Auto Switch Prompts when not set to Custom""" current_prompt = state.client_settings["prompts"][prompt_type] - if current_prompt != "Custom" and current_prompt != prompt_name: + if current_prompt not in ("Custom", prompt_name): state.client_settings["prompts"][prompt_type] = prompt_name st.info(f"Prompt Engineering - {prompt_name} Prompt has been set.", icon="ℹ️") diff --git a/src/common/functions.py b/src/common/functions.py index 9798f576..b24515fa 100644 --- a/src/common/functions.py +++ b/src/common/functions.py @@ -10,7 +10,7 @@ import requests -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("common.functions") diff --git a/src/common/schema.py b/src/common/schema.py index 04571669..6e767903 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -11,7 +11,7 @@ from langchain_core.messages import ChatMessage import oracledb -import common.help_text as help_text +from common import help_text ##################################################### # Literals ##################################################### @@ -349,14 +349,13 @@ def recursive_dump_excluding_marked(cls, obj: Any, incl_sensitive: bool, incl_re return output - elif isinstance(obj, list): + if isinstance(obj, list): return [cls.recursive_dump_excluding_marked(item, incl_sensitive, incl_readonly) for item in obj] - elif isinstance(obj, dict): + if isinstance(obj, dict): return {k: cls.recursive_dump_excluding_marked(v, incl_sensitive, incl_readonly) for k, v in obj.items()} - else: - return obj + return obj ##################################################### diff --git a/src/launch_client.py b/src/launch_client.py index 4f587346..8f064e95 100644 --- a/src/launch_client.py +++ b/src/launch_client.py @@ -17,7 +17,7 @@ from common.schema import ClientIdType from common._version import __version__ -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("launch_client") diff --git a/src/launch_server.py b/src/launch_server.py index bc733be4..fde83bda 100644 --- a/src/launch_server.py +++ b/src/launch_server.py @@ -36,11 +36,11 @@ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer # Logging -import common.logging_config as logging_config +from common import logging_config from common._version import __version__ # Configuration -import server.bootstrap.configfile as configfile +from server.bootstrap import configfile logger = logging_config.logging.getLogger("launch_server") diff --git a/src/server/agents/chatbot.py b/src/server/agents/chatbot.py index ee7128ef..fd8a7fef 100644 --- a/src/server/agents/chatbot.py +++ b/src/server/agents/chatbot.py @@ -207,8 +207,8 @@ async def vs_grade(state: OptimizerState, config: RunnableConfig) -> OptimizerSt ) logger.debug("ToolMessage Created") return {"documents": documents_dict} - else: - return {"documents": dict()} + + return {"documents": {}} async def vs_retrieve(state: OptimizerState, config: RunnableConfig) -> OptimizerState: diff --git a/src/server/agents/tools/oraclevs_retriever.py b/src/server/agents/tools/oraclevs_retriever.py index 0dcd2740..aa130add 100644 --- a/src/server/agents/tools/oraclevs_retriever.py +++ b/src/server/agents/tools/oraclevs_retriever.py @@ -16,7 +16,7 @@ from langchain_community.vectorstores.oraclevs import OracleVS from langgraph.prebuilt import InjectedState -import common.logging_config as logging_config +from common import logging_config logger =
logging_config.logging.getLogger("server.tools.oraclevs_retriever") diff --git a/src/server/agents/tools/selectai.py b/src/server/agents/tools/selectai.py index e28ef797..5e1eb447 100644 --- a/src/server/agents/tools/selectai.py +++ b/src/server/agents/tools/selectai.py @@ -10,7 +10,7 @@ from langchain_core.tools import BaseTool, tool from langchain_core.runnables import RunnableConfig -import common.logging_config as logging_config +from common import logging_config from server.api.core.databases import execute_sql logger = logging_config.logging.getLogger("server.tools.selectai_executor") diff --git a/src/server/api/core/bootstrap.py b/src/server/api/core/bootstrap.py index 8f4cb52f..fd970758 100644 --- a/src/server/api/core/bootstrap.py +++ b/src/server/api/core/bootstrap.py @@ -5,7 +5,7 @@ # spell-checker:ignore genai from server.bootstrap import databases, models, oci, prompts, settings -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("api.core.bootstrap") diff --git a/src/server/api/core/databases.py b/src/server/api/core/databases.py index d8fde138..0bfbe8de 100644 --- a/src/server/api/core/databases.py +++ b/src/server/api/core/databases.py @@ -12,7 +12,7 @@ from server.api.core import bootstrap from common.schema import Database, DatabaseAuth, DatabaseNameType, DatabaseVectorStorage, SelectAIProfileType -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("api.core.database") @@ -55,8 +55,7 @@ def connect(config: Database) -> oracledb.Connection: raise DbException(status_code=401, detail="invalid credentials") from ex if "DPY-6005" in str(ex): raise DbException(status_code=503, detail="unable to connect") from ex - else: - raise DbException(status_code=500, detail=str(ex)) from ex + raise DbException(status_code=500, detail=str(ex)) from ex logger.debug("Connected to Databases: %s", config.dsn) return conn diff --git a/src/server/api/core/models.py b/src/server/api/core/models.py index f24efcbb..4be11276 100644 --- a/src/server/api/core/models.py +++ b/src/server/api/core/models.py @@ -9,7 +9,7 @@ from common.schema import Model, ModelIdType, ModelProviderType, ModelTypeType from common.functions import is_url_accessible -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("api.core.models") diff --git a/src/server/api/core/oci.py b/src/server/api/core/oci.py index c88def61..9160f73b 100644 --- a/src/server/api/core/oci.py +++ b/src/server/api/core/oci.py @@ -8,7 +8,7 @@ from server.api.core import bootstrap, settings from common.schema import OracleCloudSettings, ClientIdType, OCIProfileType -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("api.core.oci") diff --git a/src/server/api/core/prompts.py b/src/server/api/core/prompts.py index a5df13ce..78409376 100644 --- a/src/server/api/core/prompts.py +++ b/src/server/api/core/prompts.py @@ -8,7 +8,7 @@ from server.api.core import bootstrap from common.schema import PromptCategoryType, PromptNameType, Prompt -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("api.core.prompts") diff --git a/src/server/api/core/settings.py b/src/server/api/core/settings.py index 8235c639..83595b28 100644 --- a/src/server/api/core/settings.py +++ b/src/server/api/core/settings.py @@ -9,7 
+9,7 @@ from server.api.core import bootstrap from common.schema import Settings, Configuration, ClientIdType -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("api.core.settings") diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py index 9aedea1e..3201b405 100644 --- a/src/server/api/utils/chat.py +++ b/src/server/api/utils/chat.py @@ -17,8 +17,8 @@ from server.agents.chatbot import chatbot_graph import server.api.utils.selectai as utils_selectai -import common.schema as schema -import common.logging_config as logging_config +from common import schema +from common import logging_config logger = logging_config.logging.getLogger("api.utils.chat") diff --git a/src/server/api/utils/databases.py b/src/server/api/utils/databases.py index d65ff7d8..ff7ac2e0 100644 --- a/src/server/api/utils/databases.py +++ b/src/server/api/utils/databases.py @@ -10,8 +10,8 @@ import server.api.core.databases as core_databases import server.api.core.settings as core_settings -import common.schema as schema -import common.logging_config as logging_config +from common import schema +from common import logging_config logger = logging_config.logging.getLogger("api.utils.database") diff --git a/src/server/api/utils/embed.py b/src/server/api/utils/embed.py index c84dded1..24de3a44 100644 --- a/src/server/api/utils/embed.py +++ b/src/server/api/utils/embed.py @@ -15,7 +15,7 @@ import bs4 # Langchain -import langchain_community.document_loaders as document_loaders +from langchain_community import document_loaders from langchain_community.document_loaders import WebBaseLoader from langchain_community.document_loaders.image import UnstructuredImageLoader from langchain_community.vectorstores import oraclevs as LangchainVS @@ -28,10 +28,9 @@ import server.api.utils.databases as utils_databases import server.api.core.databases as core_databases -import common.functions -import common.schema as schema +from common import schema, functions -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("api.utils.embed") @@ -224,7 +223,7 @@ def load_and_split_url( logger.info("Loading %s", url) loader = WebBaseLoader( web_paths=(f"{url}",), - bs_kwargs=dict(parse_only=bs4.SoupStrainer()), + bs_kwargs={"parse_only": bs4.SoupStrainer()}, ) loaded_doc = loader.load() @@ -365,7 +364,7 @@ def json_to_doc(file: str): logger.error("Unable to create vector index: %s", ex) # Comment the VS table - _, store_comment = common.functions.get_vs_table(**vector_store.model_dump(exclude={"database", "vector_store"})) + _, store_comment = functions.get_vs_table(**vector_store.model_dump(exclude={"database", "vector_store"})) comment = f"COMMENT ON TABLE {vector_store.vector_store} IS 'GENAI: {store_comment}'" core_databases.execute_sql(db_conn, comment) core_databases.disconnect(db_conn) diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py index 68cf4173..d3756f4d 100644 --- a/src/server/api/utils/models.py +++ b/src/server/api/utils/models.py @@ -16,8 +16,8 @@ import server.api.core.models as core_models from common.functions import is_url_accessible -import common.schema as schema -import common.logging_config as logging_config +from common import schema +from common import logging_config logger = logging_config.logging.getLogger("api.utils.models") diff --git a/src/server/api/utils/oci.py b/src/server/api/utils/oci.py index d49d8d96..7d728937 100644 --- 
a/src/server/api/utils/oci.py +++ b/src/server/api/utils/oci.py @@ -12,7 +12,7 @@ from server.api.core.oci import OciException from common.schema import OracleCloudSettings -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("api.utils.oci") diff --git a/src/server/api/utils/selectai.py b/src/server/api/utils/selectai.py index 2f45837b..c17491fe 100644 --- a/src/server/api/utils/selectai.py +++ b/src/server/api/utils/selectai.py @@ -10,7 +10,7 @@ import server.api.core.databases as core_databases from common.schema import SelectAIProfileType, DatabaseSelectAIObjects -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("api.utils.selectai") diff --git a/src/server/api/utils/testbed.py b/src/server/api/utils/testbed.py index 9e814244..f6fda789 100644 --- a/src/server/api/utils/testbed.py +++ b/src/server/api/utils/testbed.py @@ -20,8 +20,8 @@ import server.api.core.databases as core_databases import server.api.utils.models as utils_models -import common.schema as schema -import common.logging_config as logging_config +from common import schema +from common import logging_config logger = logging_config.logging.getLogger("api.utils.testbed") diff --git a/src/server/api/v1/chat.py b/src/server/api/v1/chat.py index 9fc6c246..0d4379d2 100644 --- a/src/server/api/v1/chat.py +++ b/src/server/api/v1/chat.py @@ -4,7 +4,7 @@ """ # spell-checker:ignore selectai litellm -from fastapi import APIRouter, Header, HTTPException +from fastapi import APIRouter, Header from fastapi.responses import StreamingResponse from litellm import ModelResponse @@ -19,10 +19,10 @@ from langgraph.graph.message import REMOVE_ALL_MESSAGES from server.api.utils import chat -import server.agents.chatbot as chatbot +from server.agents import chatbot -import common.schema as schema -import common.logging_config as logging_config +from common import schema +from common import logging_config logger = logging_config.logging.getLogger("endpoints.v1.chat") @@ -59,6 +59,7 @@ async def chat_stream( media_type="application/octet-stream", ) + @auth.patch( "/history", description="Delete Chat History", diff --git a/src/server/api/v1/databases.py b/src/server/api/v1/databases.py index 551703c6..89b7a432 100644 --- a/src/server/api/v1/databases.py +++ b/src/server/api/v1/databases.py @@ -8,8 +8,8 @@ import server.api.core.databases as core_databases -import common.schema as schema -import common.logging_config as logging_config +from common import schema +from common import logging_config logger = logging_config.logging.getLogger("endpoints.v1.databases") diff --git a/src/server/api/v1/embed.py b/src/server/api/v1/embed.py index 2398f9b0..7f41a315 100644 --- a/src/server/api/v1/embed.py +++ b/src/server/api/v1/embed.py @@ -21,9 +21,7 @@ import server.api.utils.embed as utils_embed import server.api.utils.models as utils_models -import common.functions as functions -import common.schema as schema -import common.logging_config as logging_config +from common import functions, schema, logging_config logger = logging_config.logging.getLogger("api.v1.embed") diff --git a/src/server/api/v1/models.py b/src/server/api/v1/models.py index 9288fd84..cd660ef1 100644 --- a/src/server/api/v1/models.py +++ b/src/server/api/v1/models.py @@ -11,8 +11,7 @@ import server.api.core.models as core_models import server.api.utils.models as utils_models -import common.schema as schema -import common.logging_config as 
logging_config +from common import schema, logging_config logger = logging_config.logging.getLogger("endpoints.v1.models") diff --git a/src/server/api/v1/oci.py b/src/server/api/v1/oci.py index e7f8d91a..61fe1c5a 100644 --- a/src/server/api/v1/oci.py +++ b/src/server/api/v1/oci.py @@ -12,8 +12,8 @@ import server.api.utils.oci as utils_oci import server.api.utils.models as utils_models -import common.schema as schema -import common.logging_config as logging_config +from common import schema +from common import logging_config logger = logging_config.logging.getLogger("endpoints.v1.oci") diff --git a/src/server/api/v1/prompts.py b/src/server/api/v1/prompts.py index 4ffeecd1..64713b61 100644 --- a/src/server/api/v1/prompts.py +++ b/src/server/api/v1/prompts.py @@ -9,8 +9,8 @@ import server.api.core.prompts as core_prompts -import common.schema as schema -import common.logging_config as logging_config +from common import schema +from common import logging_config logger = logging_config.logging.getLogger("endpoints.v1.prompts") diff --git a/src/server/api/v1/selectai.py b/src/server/api/v1/selectai.py index b06b4432..5c892f29 100644 --- a/src/server/api/v1/selectai.py +++ b/src/server/api/v1/selectai.py @@ -12,8 +12,7 @@ import server.api.utils.databases as utils_databases import server.api.utils.selectai as utils_selectai -import common.schema as schema -import common.logging_config as logging_config +from common import schema, logging_config logger = logging_config.logging.getLogger("endpoints.v1.selectai") diff --git a/src/server/api/v1/settings.py b/src/server/api/v1/settings.py index c8071efe..be810137 100644 --- a/src/server/api/v1/settings.py +++ b/src/server/api/v1/settings.py @@ -11,8 +11,8 @@ import server.api.core.settings as core_settings -import common.schema as schema -import common.logging_config as logging_config +from common import schema +from common import logging_config logger = logging_config.logging.getLogger("endpoints.v1.settings") diff --git a/src/server/api/v1/testbed.py b/src/server/api/v1/testbed.py index 78ab53ef..bd066b64 100644 --- a/src/server/api/v1/testbed.py +++ b/src/server/api/v1/testbed.py @@ -27,8 +27,8 @@ from server.api.v1 import chat -import common.schema as schema -import common.logging_config as logging_config +from common import schema +from common import logging_config logger = logging_config.logging.getLogger("endpoints.v1.testbed") @@ -238,9 +238,7 @@ def get_answer(question: str): judge_config = utils_models.get_litellm_config(model_config={"model": judge}, oci_config=oci_config, giskard=True) set_llm_model(llm_model=judge, **judge_config) try: - # report = evaluate(get_answer, testset=loaded_testset, llm_client=judge_client, metrics=[correctness_metric]) #CDB - report = evaluate(get_answer, testset=loaded_testset, metrics=None) # CDB - + report = evaluate(get_answer, testset=loaded_testset, metrics=None) except KeyError as ex: if str(ex) == "'correctness'": raise HTTPException(status_code=500, detail="Unable to determine the correctness; please retry.") from ex diff --git a/src/server/bootstrap/configfile.py b/src/server/bootstrap/configfile.py index 5509d222..2dc3dbcb 100644 --- a/src/server/bootstrap/configfile.py +++ b/src/server/bootstrap/configfile.py @@ -9,7 +9,7 @@ from threading import Lock from common.schema import Configuration -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("bootstrap.configfile") diff --git a/src/server/bootstrap/databases.py 
b/src/server/bootstrap/databases.py index a3146abc..1a97a18a 100644 --- a/src/server/bootstrap/databases.py +++ b/src/server/bootstrap/databases.py @@ -9,7 +9,7 @@ import server.api.core.databases as core_databases from common.schema import Database -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("bootstrap.databases") diff --git a/src/server/bootstrap/models.py b/src/server/bootstrap/models.py index 76b65463..c3b35aad 100644 --- a/src/server/bootstrap/models.py +++ b/src/server/bootstrap/models.py @@ -13,7 +13,7 @@ from server.bootstrap.configfile import ConfigStore from common.schema import Model from common.functions import is_url_accessible -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("bootstrap.models") @@ -170,7 +170,7 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): configuration = ConfigStore.get() if configuration and configuration.model_configs: logger.debug("Merging model configs from ConfigStore") - + # Use (provider, id) tuple as key config_model_map = {(m.provider, m.id): m.model_dump() for m in configuration.model_configs} existing = {(m["provider"], m["id"]): m for m in models_list} diff --git a/src/server/bootstrap/oci.py b/src/server/bootstrap/oci.py index 522f30c9..8f123be1 100644 --- a/src/server/bootstrap/oci.py +++ b/src/server/bootstrap/oci.py @@ -12,7 +12,7 @@ import server.api.utils.oci as utils_oci import server.api.utils.models as utils_models -import common.logging_config as logging_config +from common import logging_config from common.schema import OracleCloudSettings logger = logging_config.logging.getLogger("bootstrap.oci") diff --git a/src/server/bootstrap/prompts.py b/src/server/bootstrap/prompts.py index 03fec05d..2c27d799 100644 --- a/src/server/bootstrap/prompts.py +++ b/src/server/bootstrap/prompts.py @@ -8,7 +8,7 @@ from server.bootstrap.configfile import ConfigStore from common.schema import Prompt -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("bootstrap.prompts") diff --git a/src/server/bootstrap/settings.py b/src/server/bootstrap/settings.py index 50e219be..dc5d8bb8 100644 --- a/src/server/bootstrap/settings.py +++ b/src/server/bootstrap/settings.py @@ -7,7 +7,7 @@ from server.bootstrap.configfile import ConfigStore from common.schema import Settings -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("bootstrap.settings") diff --git a/src/server/patches/litellm_patch.py b/src/server/patches/litellm_patch.py index 26cde838..175f4b3f 100644 --- a/src/server/patches/litellm_patch.py +++ b/src/server/patches/litellm_patch.py @@ -13,7 +13,7 @@ from litellm.types.utils import ModelResponse from httpx._models import Response -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("patches.litellm_patch") diff --git a/src/server/wip/settings.py b/src/server/wip/settings.py index 315916a0..67691b51 100644 --- a/src/server/wip/settings.py +++ b/src/server/wip/settings.py @@ -9,7 +9,7 @@ from oracledb import Connection import server.api.utils.databases as databases from common.schema import ClientIdType -import common.logging_config as logging_config +from common import logging_config logger = logging_config.logging.getLogger("server.api.utils.settings") 
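Note on the bootstrap merge touched above: the (provider, id) keyed merge in src/server/bootstrap/models.py is easier to follow outside the diff context. The sketch below is illustrative only — the default and override entries are invented, not the shipped bootstrap values — but the keying and override logic mirror the patched code:

    # Illustrative merge of bootstrap defaults with ConfigStore overrides,
    # keyed by (provider, id) as in src/server/bootstrap/models.py.
    defaults = [
        {"provider": "openai", "id": "text-embedding-3-small", "enabled": False, "api_key": ""},
        {"provider": "huggingface", "id": "phi-4", "enabled": False, "api_base": "http://localhost:1234/v1"},
    ]
    overrides = [
        {"provider": "openai", "id": "text-embedding-3-small", "enabled": True, "api_key": "sk-example"},
        {"provider": "cohere", "id": "command-r", "enabled": True},  # net-new model from configuration
    ]

    existing = {(m["provider"], m["id"]): m for m in defaults}
    for override in overrides:
        key = (override["provider"], override["id"])
        if key in existing:
            for field, value in override.items():
                # Skip fields the bootstrap entry does not define; override the rest,
                # mirroring the `if k not in existing[key]: continue` guard above.
                if field in existing[key] and existing[key][field] != value:
                    existing[key][field] = value
        else:
            existing[key] = override

    models_list = list(existing.values())

Keying on the tuple rather than the bare id is what lets the same model name (for example phi-4) be registered once per provider, and it matches the /{model_provider}/{model_id:path} routes the v1 API now exposes.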
From 6628389ba7eabb0d9dc04204fc8453547bddc028 Mon Sep 17 00:00:00 2001
From: gotsysdba
Date: Mon, 1 Sep 2025 09:23:57 +0100
Subject: [PATCH 15/31] Update Tests

---
 src/client/content/config/tabs/models.py      | 21 ++++++----
 src/server/api/core/models.py                 |  2 +-
 src/server/api/utils/chat.py                  | 21 ++++++++--
 .../client/content/config/tabs/test_models.py | 27 +++++++++++--
 .../integration/server/test_endpoints_chat.py | 29 ++++++++------
 .../server/test_endpoints_databases.py        |  2 +-
 .../server/test_endpoints_embed.py            |  2 +-
 .../server/test_endpoints_health.py           |  1 +
 .../server/test_endpoints_models.py           | 38 ++++++++++++-------
 .../integration/server/test_endpoints_oci.py  |  1 +
 .../server/test_endpoints_prompts.py          |  1 +
 .../server/test_endpoints_settings.py         |  1 +
 .../server/test_endpoints_testbed.py          |  2 +-
 13 files changed, 105 insertions(+), 43 deletions(-)

diff --git a/src/client/content/config/tabs/models.py b/src/client/content/config/tabs/models.py
index 54faf956..24f75d3f 100644
--- a/src/client/content/config/tabs/models.py
+++ b/src/client/content/config/tabs/models.py
@@ -202,35 +202,40 @@ def render_model_rows(model_type: str) -> None:
     col2.markdown("**Model**", unsafe_allow_html=True)
     col3.markdown("**Provider URL**", unsafe_allow_html=True)
     col4.markdown("​")
-    st.write(state.model_configs)
     for model in [m for m in state.model_configs if m.get("type") == model_type]:
         model_id = model["id"]
         model_provider = model["provider"]
         col1.text_input(
             "Enabled",
             value=st_common.bool_to_emoji(model["enabled"]),
-            key=f"{model_type}_{model_id}_enabled",
+            key=f"{model_type}_{model_provider}_{model_id}_enabled",
             label_visibility="collapsed",
             disabled=True,
         )
         col2.text_input(
             "Model",
             value=f"{model_provider}/{model_id}",
+            key=f"{model_type}_{model_provider}_{model_id}",
             label_visibility="collapsed",
             disabled=True,
         )
         col3.text_input(
             "Server",
             value=model["api_base"],
-            key=f"{model_type}_{model_id}_server",
+            key=f"{model_type}_{model_provider}_{model_id}_api_base",
             label_visibility="collapsed",
             disabled=True,
         )
         col4.button(
             "Edit",
             on_click=edit_model,
-            key=f"{model_type}_{model_id}_edit",
-            kwargs={"model_type": model_type, "action": "edit", "model_id": model_id, "model_provider": model_provider},
+            key=f"{model_type}_{model_provider}_{model_id}_edit",
+            kwargs={
+                "model_type": model_type,
+                "action": "edit",
+                "model_id": model_id,
+                "model_provider": model_provider,
+            },
         )
 
     if st.button(label="Add", type="primary", key=f"add_{model_type}_model"):
@@ -242,7 +247,7 @@ def render_model_rows(model_type: str) -> None:
 #############################################################################
 def display_models() -> None:
     """Streamlit GUI"""
-    st.header("Models", divider="red")
+    st.title("Models")
     st.write("Update, Add, or Delete model configuration parameters.")
     try:
         get_models()
@@ -250,11 +255,11 @@ def display_models() -> None:
         st.stop()
 
     st.divider()
-    st.subheader("Language Models")
+    st.header("Language Models")
     render_model_rows("ll")
 
     st.divider()
-    st.subheader("Embedding Models")
+    st.header("Embedding Models")
     render_model_rows("embed")
 
 
diff --git a/src/server/api/core/models.py b/src/server/api/core/models.py
index 4be11276..a8114224 100644
--- a/src/server/api/core/models.py
+++ b/src/server/api/core/models.py
@@ -74,7 +74,7 @@ def create_model(model: Model, check_url: bool = True) -> Model:
 
     try:
         _ = get_model(model_id=model.id, model_provider=model.provider, model_type=model.type)
-        raise ExistsModelError(f"Model: {model.id} already exists.")
+        raise ExistsModelError(f"Model: {model.provider}/{model.id} already exists.")
     except UnknownModelError:
         pass
 
diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py
index 3201b405..526822ff 100644
--- a/src/server/api/utils/chat.py
+++ b/src/server/api/utils/chat.py
@@ -2,10 +2,11 @@
 Copyright (c) 2024, 2025, Oracle and/or its affiliates.
 Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl.
 """
-# spell-checker:ignore astream selectai litellm 
+# spell-checker:ignore astream selectai litellm
 
 from typing import Literal, AsyncGenerator
 
+from litellm import completion
 from langchain_core.messages import HumanMessage
 from langchain_core.runnables import RunnableConfig
 
@@ -14,9 +15,12 @@
 import server.api.core.prompts as core_prompts
 import server.api.utils.models as utils_models
 import server.api.utils.databases as utils_databases
-from server.agents.chatbot import chatbot_graph
 import server.api.utils.selectai as utils_selectai
 
+from server.agents.chatbot import chatbot_graph
+
+from server.api.core.models import UnknownModelError
+
 from common import schema
 from common import logging_config
 
@@ -40,7 +44,18 @@ async def completion_generator(
     oci_config = core_oci.get_oci(client=client)
 
     # Setup Client Model
-    ll_config = utils_models.get_litellm_config(model, oci_config)
+    try:
+        ll_config = utils_models.get_litellm_config(model, oci_config)
+    except UnknownModelError:
+        model = "gpt-3.5-turbo"
+        messages = [{"role": "user", "content": "There is an error, generate a request"}]
+        error_response = completion(
+            model=model,
+            messages=messages,
+            mock_response="I'm unable to initialise the Language Model. Please refresh the application.",
+        )
+        yield error_response
+        return
 
     # Start to establish our LangGraph Args
     kwargs = {
diff --git a/tests/integration/client/content/config/tabs/test_models.py b/tests/integration/client/content/config/tabs/test_models.py
index 3c52070a..a4b7a16f 100644
--- a/tests/integration/client/content/config/tabs/test_models.py
+++ b/tests/integration/client/content/config/tabs/test_models.py
@@ -5,6 +5,7 @@
 # spell-checker: disable
 # pylint: disable=import-error
 
+
 #############################################################################
 # Test Streamlit UI
 #############################################################################
@@ -14,16 +15,34 @@ class TestStreamlit:
     # Streamlit File
     ST_FILE = "../src/client/content/config/tabs/models.py"
 
+    def test_model_page(self, app_server, app_test):
+        """Test basic page layout"""
+        assert app_server is not None
+        at = app_test(self.ST_FILE).run()
+
+        titles = at.get("title")
+        assert any("Models" in t.value for t in titles)
+
+        headers = at.get("header")
+        assert any("Language Models" in h.value for h in headers)
+        assert any("Embedding Models" in h.value for h in headers)
+
     def test_model_tables(self, app_server, app_test):
         """Test that the model tables are setup"""
         assert app_server is not None
         at = app_test(self.ST_FILE).run()
         assert at.session_state.model_configs is not None
         for model in at.session_state.model_configs:
-            assert at.text_input(key=f"{model['type']}_{model['id']}_enabled").value == "⚪"
-            assert at.text_input(key=f"{model['type']}_{model['id']}_provider").value == model["provider"]
-            assert at.text_input(key=f"{model['type']}_{model['id']}_server").value == model["api_base"]
-            assert at.button(key=f"{model['type']}_{model['id']}_edit") is not None
+            assert at.text_input(key=f"{model['type']}_{model['provider']}_{model['id']}_enabled").value == "⚪"
+            assert (
at.text_input(key=f"{model['type']}_{model['provider']}_{model['id']}").value + == f"{model['provider']}/{model['id']}" + ) + assert ( + at.text_input(key=f"{model['type']}_{model['provider']}_{model['id']}_api_base").value + == model["api_base"] + ) + assert at.button(key=f"{model['type']}_{model['provider']}_{model['id']}_edit") is not None for model_type in {item["type"] for item in at.session_state.model_configs}: assert at.button(key=f"add_{model_type}_model") is not None diff --git a/tests/integration/server/test_endpoints_chat.py b/tests/integration/server/test_endpoints_chat.py index 172849d1..b322cc84 100644 --- a/tests/integration/server/test_endpoints_chat.py +++ b/tests/integration/server/test_endpoints_chat.py @@ -2,10 +2,12 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ +# pylint: disable=too-many-arguments,too-many-positional-arguments,too-few-public-methods, import-error # spell-checker: disable -# pylint: disable=import-error from unittest.mock import patch, MagicMock +import warnings + import pytest from langchain_core.messages import ChatMessage from common.schema import ChatRequest @@ -47,14 +49,19 @@ class TestEndpoints: def test_chat_completion_no_model(self, client, auth_headers): """Test no model chat completion request""" - request = ChatRequest( - messages=[ChatMessage(content="Hello", role="user")], - model="test-model", - temperature=1.0, - max_completion_tokens=256, - ) + with warnings.catch_warnings(): + # Enable the catch_warnings context + warnings.simplefilter("ignore", category=UserWarning) + request = ChatRequest( + messages=[ChatMessage(content="Hello", role="user")], + model="test-provider/test-model", + temperature=1.0, + max_completion_tokens=256, + ) + response = client.post( + "/v1/chat/completions", headers=auth_headers["valid_auth"], json=request.model_dump() + ) - response = client.post("/v1/chat/completions", headers=auth_headers["valid_auth"], json=request.model_dump()) assert response.status_code == 200 assert "choices" in response.json() assert ( @@ -75,7 +82,7 @@ def test_chat_completion_valid_mock(self, client, auth_headers): } ], "created": 1234567890, - "model": "test-model", + "model": "test-provider/test-model", "object": "chat.completion", "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, } @@ -90,7 +97,7 @@ def test_chat_completion_valid_mock(self, client, auth_headers): request = ChatRequest( messages=[ChatMessage(content="Hello", role="user")], - model="test-model", + model="test-provider/test-model", temperature=1.0, max_completion_tokens=256, ) @@ -115,7 +122,7 @@ def test_chat_stream_valid_mock(self, client, auth_headers): request = ChatRequest( messages=[ChatMessage(content="Hello", role="user")], - model="test-model", + model="test-provider/test-model", temperature=1.0, max_completion_tokens=256, streaming=True, diff --git a/tests/integration/server/test_endpoints_databases.py b/tests/integration/server/test_endpoints_databases.py index 76a4feed..8a0317d2 100644 --- a/tests/integration/server/test_endpoints_databases.py +++ b/tests/integration/server/test_endpoints_databases.py @@ -2,8 +2,8 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
""" +# pylint: disable=too-many-arguments,too-many-positional-arguments,too-few-public-methods, import-error # spell-checker: disable -# pylint: disable=import-error import pytest from conftest import TEST_CONFIG diff --git a/tests/integration/server/test_endpoints_embed.py b/tests/integration/server/test_endpoints_embed.py index 202d975f..102972ab 100644 --- a/tests/integration/server/test_endpoints_embed.py +++ b/tests/integration/server/test_endpoints_embed.py @@ -2,8 +2,8 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ +# pylint: disable=too-many-arguments,too-many-positional-arguments,too-few-public-methods, import-error # spell-checker: disable -# pylint: disable=import-error from io import BytesIO from pathlib import Path diff --git a/tests/integration/server/test_endpoints_health.py b/tests/integration/server/test_endpoints_health.py index 716cb69d..27658ee0 100644 --- a/tests/integration/server/test_endpoints_health.py +++ b/tests/integration/server/test_endpoints_health.py @@ -2,6 +2,7 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ +# pylint: disable=too-many-arguments,too-many-positional-arguments,too-few-public-methods # spell-checker: disable import pytest diff --git a/tests/integration/server/test_endpoints_models.py b/tests/integration/server/test_endpoints_models.py index ece97647..a0a5f695 100644 --- a/tests/integration/server/test_endpoints_models.py +++ b/tests/integration/server/test_endpoints_models.py @@ -2,6 +2,7 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. 
""" +# pylint: disable=too-many-arguments,too-many-positional-arguments,too-few-public-methods # spell-checker: disable from typing import get_args @@ -29,10 +30,10 @@ class TestInvalidAuthEndpoints: [ pytest.param("/v1/models/api", "get", id="models_list_api"), pytest.param("/v1/models", "get", id="models_list"), - pytest.param("/v1/models/model_id", "get", id="models_get"), - pytest.param("/v1/models/model_id", "patch", id="models_update"), + pytest.param("/v1/models/model_provider/model_id", "get", id="models_get"), + pytest.param("/v1/models/model_provider/model_id", "patch", id="models_update"), pytest.param("/v1/models", "post", id="models_create"), - pytest.param("/v1/models/model_id", "delete", id="models_delete"), + pytest.param("/v1/models/model_provider/model_id", "delete", id="models_delete"), ], ) def test_endpoints(self, client, auth_headers, endpoint, api_method, auth_type, status_code): @@ -58,7 +59,7 @@ def test_models_get_before(self, client, auth_headers): all_models = client.get("/v1/models?include_disabled=true", headers=auth_headers["valid_auth"]) assert len(all_models.json()) > 0 for model in all_models.json(): - response = client.get(f"/v1/models/{model['id']}", headers=auth_headers["valid_auth"]) + response = client.get(f"/v1/models/{model['provider']}/{model['id']}", headers=auth_headers["valid_auth"]) assert response.status_code == 200 def test_models_delete_add(self, client, auth_headers): @@ -68,17 +69,19 @@ def test_models_delete_add(self, client, auth_headers): # Delete all models for model in all_models.json(): - response = client.delete(f"/v1/models/{model['id']}", headers=auth_headers["valid_auth"]) + response = client.delete( + f"/v1/models/{model['provider']}/{model['id']}", headers=auth_headers["valid_auth"] + ) assert response.status_code == 200 - assert response.json() == {"message": f"Model: {model['id']} deleted."} + assert response.json() == {"message": f"Model: {model['provider']}/{model['id']} deleted."} # Check that no models exists deleted_models = client.get("/v1/models?include_disabled=true", headers=auth_headers["valid_auth"]) assert len(deleted_models.json()) == 0 # Delete a non-existent model - response = client.delete("/v1/models/test_model", headers=auth_headers["valid_auth"]) + response = client.delete("/v1/models/test_provider/test_model", headers=auth_headers["valid_auth"]) assert response.status_code == 200 - assert response.json() == {"message": "Model: test_model deleted."} + assert response.json() == {"message": "Model: test_provider/test_model deleted."} # Add all models back for model in all_models.json(): @@ -97,7 +100,7 @@ def test_models_add_dupl(self, client, auth_headers): payload = model response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) assert response.status_code == 409 - assert response.json() == {"detail": f"Model: {model['id']} already exists."} + assert response.json() == {"detail": f"Model: {model['provider']}/{model['id']} already exists."} test_cases = [ pytest.param( @@ -120,6 +123,7 @@ def test_models_add_dupl(self, client, auth_headers): pytest.param( { "id": "invalid_ll_model", + "provider": "invalid_ll_model", "enabled": False, }, 422, @@ -168,11 +172,15 @@ def test_model_create(self, client, auth_headers, payload, add_status_code, _, r print(response.json()) assert all(item in response.json().items() for item in payload.items()) # Model was added, should get 200 back - response = client.get(f"/v1/models/{payload['id']}", headers=auth_headers["valid_auth"]) + response = 
client.get( + f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"] + ) assert response.status_code == 200 else: # Model wasn't added, should get a 404 back - response = client.get(f"/v1/models/{payload['id']}", headers=auth_headers["valid_auth"]) + response = client.get( + f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"] + ) assert response.status_code == 404 @pytest.mark.parametrize("payload, add_status_code, update_status_code", test_cases) @@ -181,12 +189,16 @@ def test_model_update(self, client, auth_headers, payload, add_status_code, upda if add_status_code == 201: # Create the model when we know it will succeed _ = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload) - response = client.get(f"/v1/models/{payload['id']}", headers=auth_headers["valid_auth"]) + response = client.get( + f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"] + ) old_enabled = response.json()["enabled"] # Switch up the enabled for the update payload["enabled"] = not old_enabled - response = client.patch(f"/v1/models/{payload['id']}", headers=auth_headers["valid_auth"], json=payload) + response = client.patch( + f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload + ) assert response.status_code == update_status_code if update_status_code == 200: new_enabled = response.json()["enabled"] diff --git a/tests/integration/server/test_endpoints_oci.py b/tests/integration/server/test_endpoints_oci.py index 1b3fd762..0c8f6ceb 100644 --- a/tests/integration/server/test_endpoints_oci.py +++ b/tests/integration/server/test_endpoints_oci.py @@ -2,6 +2,7 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ +# pylint: disable=too-many-arguments,too-many-positional-arguments,too-few-public-methods # spell-checker: disable from unittest.mock import patch, MagicMock diff --git a/tests/integration/server/test_endpoints_prompts.py b/tests/integration/server/test_endpoints_prompts.py index dd569ac4..f2de4fed 100644 --- a/tests/integration/server/test_endpoints_prompts.py +++ b/tests/integration/server/test_endpoints_prompts.py @@ -2,6 +2,7 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ +# pylint: disable=too-many-arguments,too-many-positional-arguments,too-few-public-methods # spell-checker: disable import pytest diff --git a/tests/integration/server/test_endpoints_settings.py b/tests/integration/server/test_endpoints_settings.py index a652c95a..3de3263c 100644 --- a/tests/integration/server/test_endpoints_settings.py +++ b/tests/integration/server/test_endpoints_settings.py @@ -2,6 +2,7 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ +# pylint: disable=too-many-arguments,too-many-positional-arguments,too-few-public-methods # spell-checker: disable import pytest diff --git a/tests/integration/server/test_endpoints_testbed.py b/tests/integration/server/test_endpoints_testbed.py index 54b72ebe..86de334f 100644 --- a/tests/integration/server/test_endpoints_testbed.py +++ b/tests/integration/server/test_endpoints_testbed.py @@ -2,8 +2,8 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. 
Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ +# pylint: disable=too-many-arguments,too-many-positional-arguments,too-few-public-methods, import-error # spell-checker: disable -# pylint: disable=import-error import json import io From db9f63e31b825f7dbfdc2d58f1cf8b9ad456b58d Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Mon, 1 Sep 2025 16:53:31 +0100 Subject: [PATCH 16/31] removed cyclic import on databases --- src/server/agents/chatbot.py | 2 +- src/server/agents/tools/selectai.py | 2 +- src/server/api/core/databases.py | 271 ++++++------------ src/server/api/core/oci.py | 12 - src/server/api/utils/README.md | 2 +- src/server/api/utils/chat.py | 2 +- src/server/api/utils/databases.py | 225 ++++++++++++++- src/server/api/utils/embed.py | 9 +- src/server/api/utils/oci.py | 14 +- src/server/api/utils/selectai.py | 6 +- src/server/api/utils/testbed.py | 22 +- src/server/api/v1/databases.py | 40 ++- src/server/api/v1/embed.py | 9 +- src/server/api/v1/oci.py | 14 +- src/server/api/v1/selectai.py | 6 +- src/server/api/v1/testbed.py | 14 +- src/server/bootstrap/databases.py | 28 +- src/server/patches/litellm_patch.py | 5 +- .../content/config/tabs/test_databases.py | 2 +- 19 files changed, 393 insertions(+), 292 deletions(-) diff --git a/src/server/agents/chatbot.py b/src/server/agents/chatbot.py index fd8a7fef..a5b0d5e7 100644 --- a/src/server/agents/chatbot.py +++ b/src/server/agents/chatbot.py @@ -24,7 +24,7 @@ from litellm import acompletion, completion from litellm.exceptions import APIConnectionError -from server.api.core.databases import execute_sql +from server.api.utils.databases import execute_sql from common import logging_config diff --git a/src/server/agents/tools/selectai.py b/src/server/agents/tools/selectai.py index 5e1eb447..fb0f40ac 100644 --- a/src/server/agents/tools/selectai.py +++ b/src/server/agents/tools/selectai.py @@ -11,7 +11,7 @@ from langchain_core.runnables import RunnableConfig from common import logging_config -from server.api.core.databases import execute_sql +from server.api.utils.databases import execute_sql logger = logging_config.logging.getLogger("server.tools.selectai_executor") diff --git a/src/server/api/core/databases.py b/src/server/api/core/databases.py index 0bfbe8de..0f9e0211 100644 --- a/src/server/api/core/databases.py +++ b/src/server/api/core/databases.py @@ -5,202 +5,111 @@ # spell-checker:ignore clob genai nclob privs selectai from typing import Optional, Union -import json - -import oracledb - from server.api.core import bootstrap -from common.schema import Database, DatabaseAuth, DatabaseNameType, DatabaseVectorStorage, SelectAIProfileType +from common.schema import Database, DatabaseNameType from common import logging_config logger = logging_config.logging.getLogger("api.core.database") -##################################################### -# Exceptions -##################################################### -class DbException(Exception): - """Custom Database Exceptions to be passed to HTTPException""" - - def __init__(self, status_code: int, detail: str): - self.status_code = status_code - self.detail = detail - super().__init__(detail) - - ##################################################### # Functions ##################################################### -def connect(config: Database) -> oracledb.Connection: - """Establish a connection to an Oracle Database""" - logger.info("Connecting to Database: %s", config.dsn) - include_fields = set(DatabaseAuth.model_fields.keys()) - 
db_config = config.model_dump(include=include_fields) - logger.debug("Database Config: %s", db_config) - # If a wallet password is provided but no wallet location is set - # default the wallet location to the config directory - if db_config.get("wallet_password") and not db_config.get("wallet_location"): - db_config["wallet_location"] = db_config["config_dir"] - # Check if connection settings are configured - if any(not db_config[key] for key in ("user", "password", "dsn")): - raise DbException(status_code=400, detail="missing connection details") - - # Attempt to Connect - try: - logger.debug("Attempting Database Connection...") - conn = oracledb.connect(**db_config) - except oracledb.DatabaseError as ex: - if "ORA-01017" in str(ex): - raise DbException(status_code=401, detail="invalid credentials") from ex - if "DPY-6005" in str(ex): - raise DbException(status_code=503, detail="unable to connect") from ex - raise DbException(status_code=500, detail=str(ex)) from ex - logger.debug("Connected to Databases: %s", config.dsn) - return conn - - -def disconnect(conn: oracledb.Connection) -> None: - """Disconnect from an Oracle Database""" - logger.debug("Disconnecting Databases Connection: %s", conn) - return conn.close() - - -def execute_sql(conn: oracledb.Connection, run_sql: str, binds: dict = None) -> list: - """Execute SQL against Oracle Database""" - logger.debug("SQL: %s with binds %s", run_sql, binds) - try: - # Use context manager to ensure the cursor is closed properly - with conn.cursor() as cursor: - rows = None - cursor.callproc("dbms_output.enable") - status_var = cursor.var(int) - text_var = cursor.var(str) - cursor.execute(run_sql, binds) - if cursor.description: # Check if the query returns rows - rows = cursor.fetchall() - lob_columns = [ - idx - for idx, fetch_info in enumerate(cursor.description) - if fetch_info.type_code in (oracledb.DB_TYPE_CLOB, oracledb.DB_TYPE_BLOB, oracledb.DB_TYPE_NCLOB) - ] - if lob_columns: - # Convert rows to list of dictionaries with LOB handling - rows = [ - { - cursor.description[idx].name: (value.read() if idx in lob_columns else value) - for idx, value in enumerate(row) - } - for row in rows - ] - else: - cursor.callproc("dbms_output.get_line", (text_var, status_var)) - if status_var.getvalue() == 0: - logger.info("Returning DBMS_OUTPUT.") - rows = text_var.getvalue() - return rows - except oracledb.DatabaseError as ex: - if ex.args: - error_obj = ex.args[0] - if hasattr(error_obj, "code") and error_obj.code == 955: - logger.info("Table exists") - if hasattr(error_obj, "code") and error_obj.code == 942: - logger.info("Table does not exist") - else: - logger.exception("Database error: %s", ex) - logger.info("Failed SQL: %s", run_sql) - raise - else: - logger.exception("Database error: %s", ex) - raise - - except oracledb.InterfaceError as ex: - logger.exception("Interface error: %s", ex) - raise - - -def get_vs(conn: oracledb.Connection) -> DatabaseVectorStorage: - """Retrieve Vector Storage Tables""" - logger.info("Looking for Vector Storage Tables") - vector_stores = [] - sql = """SELECT ut.table_name, - REPLACE(utc.comments, 'GENAI: ', '') AS comments - FROM all_tab_comments utc, all_tables ut - WHERE utc.table_name = ut.table_name - AND utc.comments LIKE 'GENAI:%'""" - results = execute_sql(conn, sql) - for table_name, comments in results: - comments_dict = json.loads(comments) - vector_stores.append(DatabaseVectorStorage(vector_store=table_name, **comments_dict)) - logger.debug("Found Vector Stores: %s", vector_stores) - - return 
vector_stores - - -def selectai_enabled(conn: oracledb.Connection) -> bool: - """Determine if SelectAI can be used""" - logger.debug("Checking %s for SelectAI", conn) - is_enabled = False - sql = """ - SELECT COUNT(*) - FROM ALL_TAB_PRIVS - WHERE TYPE = 'PACKAGE' - AND PRIVILEGE = 'EXECUTE' - AND GRANTEE = USER - AND TABLE_NAME IN ('DBMS_CLOUD','DBMS_CLOUD_AI','DBMS_CLOUD_PIPELINE') - """ - result = execute_sql(conn, sql) - if result[0][0] == 3: - is_enabled = True - logger.debug("SelectAI enabled (results: %s): %s", result[0][0], is_enabled) - - return is_enabled - - -def get_selectai_profiles(conn: oracledb.Connection) -> SelectAIProfileType: - """Retrieve SelectAI Profiles""" - logger.info("Looking for SelectAI Profiles") - selectai_profiles = [] - sql = """ - SELECT profile_name - FROM USER_CLOUD_AI_PROFILES - """ - results = execute_sql(conn, sql) - if results: - selectai_profiles = [row[0] for row in results] - logger.debug("Found SelectAI Profiles: %s", selectai_profiles) - - return selectai_profiles - - -def get_databases( - name: Optional[DatabaseNameType] = None, validate: bool = True -) -> Union[list[Database], Database, None]: +def get_database(name: Optional[DatabaseNameType] = None) -> Union[list[Database], None]: """ Return all Database objects if `name` is not provided, - or the single Database if `name` is provided and successfully connected. + or the single Database if `name` is provided. If a `name` is provided and not found, raise exception """ database_objects = bootstrap.DATABASE_OBJECTS - for db in database_objects: - if name and db.name != name: - continue - if validate: - try: - db_conn = connect(db) - db.vector_stores = get_vs(db_conn) - db.selectai = selectai_enabled(db_conn) - if db.selectai: - db.selectai_profiles = get_selectai_profiles(db_conn) - except DbException as ex: - logger.debug("Skipping Database %s - exception: %s", db.name, str(ex)) - db.connected = False - if name: - return db # Return the matched, connected DB immediately - - if name: - # If we got here with a `name` then we didn't find it + logger.debug("%i databases are defined", len(database_objects)) + database_filtered = [db for db in database_objects if (name is None or db.name == name)] + logger.debug("%i databases after filtering", len(database_filtered)) + + if name and not database_filtered: raise ValueError(f"{name} not found") - return database_objects + return database_filtered + + +def create_database(database: Database) -> Database: + """Create a new Model definition""" + database_objects = bootstrap.DATABASE_OBJECTS + + _ = get_database(name=database.name) + + if any(not getattr(database_objects, key) for key in ("user", "password", "dsn")): + raise ValueError("'user', 'password', and 'dsn' are required") + + database_objects.append(database) + return get_database(name=database.name) + + +def delete_database(name: DatabaseNameType) -> None: + """Remove database from database objects""" + database_objects = bootstrap.DATABASE_OBJECTS + bootstrap.DATABASE_OBJECTS = [d for d in database_objects if d.name != name] + + +# for db in database_objects: +# if name and db.name != name: +# continue +# if validate: +# try: +# db_conn = connect(db) +# db.vector_stores = get_vs(db_conn) +# db.selectai = selectai_enabled(db_conn) +# if db.selectai: +# db.selectai_profiles = get_selectai_profiles(db_conn) +# except Exception as ex: +# logger.debug("Skipping Database %s - exception: %s", db.name, str(ex)) +# db.connected = False +# if name: +# return db # Return the matched, connected DB 
immediately + +# if name: +# # If we got here with a `name` then we didn't find it +# raise ValueError(f"{name} not found") + +# return database_objects + + +# create_database + + +# delete_database + + +# def get_databases( +# name: Optional[DatabaseNameType] = None, validate: bool = True +# ) -> Union[list[Database], Database, None]: +# """ +# Return all Database objects if `name` is not provided, +# or the single Database if `name` is provided and successfully connected. +# If a `name` is provided and not found, raise exception +# """ +# database_objects = bootstrap.DATABASE_OBJECTS + +# for db in database_objects: +# if name and db.name != name: +# continue +# if validate: +# try: +# db_conn = connect(db) +# db.vector_stores = get_vs(db_conn) +# db.selectai = selectai_enabled(db_conn) +# if db.selectai: +# db.selectai_profiles = get_selectai_profiles(db_conn) +# except Exception as ex: +# logger.debug("Skipping Database %s - exception: %s", db.name, str(ex)) +# db.connected = False +# if name: +# return db # Return the matched, connected DB immediately + +# if name: +# # If we got here with a `name` then we didn't find it +# raise ValueError(f"{name} not found") + +# return database_objects diff --git a/src/server/api/core/oci.py b/src/server/api/core/oci.py index 9160f73b..e235cb11 100644 --- a/src/server/api/core/oci.py +++ b/src/server/api/core/oci.py @@ -13,18 +13,6 @@ logger = logging_config.logging.getLogger("api.core.oci") -##################################################### -# Exceptions -##################################################### -class OciException(Exception): - """Custom OCI Exceptions to be passed to HTTPException""" - - def __init__(self, status_code: int, detail: str): - self.status_code = status_code - self.detail = detail - super().__init__(detail) - - ##################################################### # Functions ##################################################### diff --git a/src/server/api/utils/README.md b/src/server/api/utils/README.md index a6de45d4..e22a5c58 100644 --- a/src/server/api/utils/README.md +++ b/src/server/api/utils/README.md @@ -1,3 +1,3 @@ # Utils -Utils relies on core, which establishes the bootstrap objects/settings. Scripts here will reference core and other utils. \ No newline at end of file +Utils relies on core, which establishes the bootstrap objects/settings. Scripts here will reference other utils. 
\ No newline at end of file diff --git a/src/server/api/utils/chat.py b/src/server/api/utils/chat.py index 526822ff..58b71780 100644 --- a/src/server/api/utils/chat.py +++ b/src/server/api/utils/chat.py @@ -77,7 +77,7 @@ async def completion_generator( # Add DB Conn to KWargs when needed if client_settings.vector_search.enabled or client_settings.selectai.enabled: - db_conn = utils_databases.get_client_db(client, False).connection + db_conn = utils_databases.get_client_database(client, False).connection kwargs["config"]["configurable"]["db_conn"] = db_conn # Setup Vector Search diff --git a/src/server/api/utils/databases.py b/src/server/api/utils/databases.py index ff7ac2e0..3388a943 100644 --- a/src/server/api/utils/databases.py +++ b/src/server/api/utils/databases.py @@ -4,38 +4,239 @@ """ # spell-checker:ignore selectai clob nclob vectorstores oraclevs +from typing import Optional, Union +import json import oracledb from langchain_community.vectorstores import oraclevs as LangchainVS import server.api.core.databases as core_databases import server.api.core.settings as core_settings -from common import schema +from common.schema import ( + Database, + DatabaseNameType, + VectorStoreTableType, + ClientIdType, + DatabaseAuth, + DatabaseVectorStorage, + SelectAIProfileType, +) from common import logging_config logger = logging_config.logging.getLogger("api.utils.database") -def test(config: schema.Database) -> None: +##################################################### +# Exceptions +##################################################### +class DbException(Exception): + """Custom Database Exceptions to be passed to HTTPException""" + + def __init__(self, status_code: int, detail: str): + self.status_code = status_code + self.detail = detail + super().__init__(detail) + + +##################################################### +# Protected Functions +##################################################### +def _test(config: Database) -> None: """Test connection and re-establish if no longer open""" + config.connected = False try: config.connection.ping() logger.info("%s database connection is active.", config.name) + config.connected = True except oracledb.DatabaseError: - db_conn = core_databases.connect(config) logger.info("Refreshing %s database connection.", config.name) - config.set_connection(db_conn) - except AttributeError as ex: - raise core_databases.DbException(status_code=400, detail="missing connection details") from ex + _ = connect(config) + except ValueError as ex: + raise DbException(status_code=400, detail=f"Database: {str(ex)}") from ex + except PermissionError as ex: + raise DbException(status_code=401, detail=f"Database: {str(ex)}") from ex + except ConnectionError as ex: + raise DbException(status_code=503, detail=f"Database: {str(ex)}") from ex + except Exception as ex: + raise DbException(status_code=500, detail=str(ex)) from ex + + +def _get_vs(conn: oracledb.Connection) -> DatabaseVectorStorage: + """Retrieve Vector Storage Tables""" + logger.info("Looking for Vector Storage Tables") + vector_stores = [] + sql = """SELECT ut.table_name, + REPLACE(utc.comments, 'GENAI: ', '') AS comments + FROM all_tab_comments utc, all_tables ut + WHERE utc.table_name = ut.table_name + AND utc.comments LIKE 'GENAI:%'""" + results = execute_sql(conn, sql) + for table_name, comments in results: + comments_dict = json.loads(comments) + vector_stores.append(DatabaseVectorStorage(vector_store=table_name, **comments_dict)) + logger.debug("Found Vector Stores: %s", vector_stores) + + return 
vector_stores + + +def _selectai_enabled(conn: oracledb.Connection) -> bool: + """Determine if SelectAI can be used""" + logger.debug("Checking %s for SelectAI", conn) + is_enabled = False + sql = """ + SELECT COUNT(*) + FROM ALL_TAB_PRIVS + WHERE TYPE = 'PACKAGE' + AND PRIVILEGE = 'EXECUTE' + AND GRANTEE = USER + AND TABLE_NAME IN ('DBMS_CLOUD','DBMS_CLOUD_AI','DBMS_CLOUD_PIPELINE') + """ + result = execute_sql(conn, sql) + if result[0][0] == 3: + is_enabled = True + logger.debug("SelectAI enabled (results: %s): %s", result[0][0], is_enabled) + + return is_enabled + + +def _get_selectai_profiles(conn: oracledb.Connection) -> SelectAIProfileType: + """Retrieve SelectAI Profiles""" + logger.info("Looking for SelectAI Profiles") + selectai_profiles = [] + sql = """ + SELECT profile_name + FROM USER_CLOUD_AI_PROFILES + """ + results = execute_sql(conn, sql) + if results: + selectai_profiles = [row[0] for row in results] + logger.debug("Found SelectAI Profiles: %s", selectai_profiles) + + return selectai_profiles + + +##################################################### +# Functions +##################################################### +def connect(config: Database) -> oracledb.Connection: + """Establish a connection to an Oracle Database""" + include_fields = set(DatabaseAuth.model_fields.keys()) + db_authn = config.model_dump(include=include_fields) + if any(not db_authn[key] for key in ("user", "password", "dsn")): + raise ValueError("missing connection details") + logger.info("Connecting to Database: %s", config.dsn) + # If a wallet password is provided but no wallet location is set + # default the wallet location to the config directory + if db_authn.get("wallet_password") and not db_authn.get("wallet_location"): + db_authn["wallet_location"] = db_authn["config_dir"] -def drop_vs(conn: oracledb.Connection, vs: schema.VectorStoreTableType) -> None: + # Attempt to Connect + logger.debug("Database AuthN: %s", db_authn) + try: + logger.debug("Attempting Database Connection...") + conn = oracledb.connect(**db_authn) + except oracledb.DatabaseError as ex: + if "ORA-01017" in str(ex): + raise PermissionError("invalid credentials") from ex + if "DPY-6005" in str(ex): + raise ConnectionError("unable to connect") from ex + if "DPY-4000" in str(ex): + raise LookupError("not resolvable") from ex + raise + logger.debug("Connected to Databases: %s", config.dsn) + + return conn + + +def disconnect(conn: oracledb.Connection) -> None: + """Disconnect from an Oracle Database""" + logger.debug("Disconnecting Databases Connection: %s", conn) + return conn.close() + + +def execute_sql(conn: oracledb.Connection, run_sql: str, binds: dict = None) -> list: + """Execute SQL against Oracle Database""" + logger.debug("SQL: %s with binds %s", run_sql, binds) + try: + # Use context manager to ensure the cursor is closed properly + with conn.cursor() as cursor: + rows = None + cursor.callproc("dbms_output.enable") + status_var = cursor.var(int) + text_var = cursor.var(str) + cursor.execute(run_sql, binds) + if cursor.description: # Check if the query produces rows + rows = cursor.fetchall() + lob_columns = [ + idx + for idx, fetch_info in enumerate(cursor.description) + if fetch_info.type_code in (oracledb.DB_TYPE_CLOB, oracledb.DB_TYPE_BLOB, oracledb.DB_TYPE_NCLOB) + ] + if lob_columns: + # Convert rows to list of dictionaries with LOB handling + rows = [ + { + cursor.description[idx].name: (value.read() if idx in lob_columns else value) + for idx, value in enumerate(row) + } + for row in rows + ] + else: + 
cursor.callproc("dbms_output.get_line", (text_var, status_var)) + if status_var.getvalue() == 0: + logger.info("Returning DBMS_OUTPUT.") + rows = text_var.getvalue() + except oracledb.DatabaseError as ex: + if ex.args: + error_obj = ex.args[0] + if hasattr(error_obj, "code") and error_obj.code == 955: + logger.info("Table exists") + if hasattr(error_obj, "code") and error_obj.code == 942: + logger.info("Table does not exist") + else: + logger.exception("Database error: %s", ex) + logger.info("Failed SQL: %s", run_sql) + raise + else: + logger.exception("Database error: %s", ex) + raise + except oracledb.InterfaceError as ex: + logger.exception("Interface error: %s", ex) + raise + + return rows + +def drop_vs(conn: oracledb.Connection, vs: VectorStoreTableType) -> None: """Drop Vector Storage""" logger.info("Dropping Vector Store: %s", vs) LangchainVS.drop_table_purge(conn, vs) -def get_client_db(client: schema.ClientIdType, validate: bool = True) -> schema.Database: +def get_databases( + db_name: Optional[DatabaseNameType] = None, validate: bool = False +) -> Union[list[Database], Database, None]: + """Return list of Database Objects""" + databases = core_databases.get_database(db_name) + if validate: + for db in databases: + try: + db_conn = connect(config=db) + except (ValueError, PermissionError, ConnectionError): + continue + db.vector_stores = _get_vs(db_conn) + db.selectai = _selectai_enabled(db_conn) + if db.selectai: + db.selectai_profiles = _get_selectai_profiles(db_conn) + db.connected = True + db.set_connection(db_conn) + if db_name: + return databases[0] + + return databases + + +def get_client_database(client: ClientIdType, validate: bool = False) -> Database: """Return a Database Object based on client settings""" client_settings = core_settings.get_client_settings(client) @@ -46,9 +247,5 @@ def get_client_db(client: schema.ClientIdType, validate: bool = True) -> schema. 
):
         db_name = getattr(client_settings.vector_search, "database", "DEFAULT")
 
-    # Return the Database Object
-    db = core_databases.get_databases(name=db_name, validate=validate)
-    # Ping the Database
-    test(db)
-
-    return db
+    # Return the single Database Object
+    return get_databases(db_name=db_name, validate=validate)
diff --git a/src/server/api/utils/embed.py b/src/server/api/utils/embed.py
index 24de3a44..173774b9 100644
--- a/src/server/api/utils/embed.py
+++ b/src/server/api/utils/embed.py
@@ -26,7 +26,6 @@
 from langchain_text_splitters import HTMLHeaderTextSplitter, CharacterTextSplitter
 
 import server.api.utils.databases as utils_databases
-import server.api.core.databases as core_databases
 
 from common import schema, functions
 
@@ -300,7 +299,7 @@ def json_to_doc(file: str):
 
     # Creates a TEMP Vector Store Table; which may already exist
     # Establish a dedicated connection to the database
-    db_conn = core_databases.connect(db_details)
+    db_conn = utils_databases.connect(db_details)
     # This is to allow re-using an existing VS; will merge this over later
     utils_databases.drop_vs(db_conn, vector_store_tmp.vector_store)
     logger.info("Establishing initial vector store")
@@ -351,7 +350,7 @@ def json_to_doc(file: str):
         WHERE NOT EXISTS (SELECT 1 FROM {vector_store.vector_store} tgt WHERE tgt.ID = src.ID)
     """
     logger.info("Merging %s into %s", vector_store_tmp.vector_store, vector_store.vector_store)
-    core_databases.execute_sql(db_conn, merge_sql)
+    utils_databases.execute_sql(db_conn, merge_sql)
     utils_databases.drop_vs(db_conn, vector_store_tmp.vector_store)
 
     # Build the Index
@@ -366,5 +365,5 @@ def json_to_doc(file: str):
     # Comment the VS table
     _, store_comment = functions.get_vs_table(**vector_store.model_dump(exclude={"database", "vector_store"}))
     comment = f"COMMENT ON TABLE {vector_store.vector_store} IS 'GENAI: {store_comment}'"
-    core_databases.execute_sql(db_conn, comment)
-    core_databases.disconnect(db_conn)
+    utils_databases.execute_sql(db_conn, comment)
+    utils_databases.disconnect(db_conn)
diff --git a/src/server/api/utils/oci.py b/src/server/api/utils/oci.py
index 7d728937..6afd0257 100644
--- a/src/server/api/utils/oci.py
+++ b/src/server/api/utils/oci.py
@@ -9,14 +9,26 @@
 import urllib3.exceptions
 import oci
 
-from server.api.core.oci import OciException
 from common.schema import OracleCloudSettings
 from common import logging_config
 
 logger = logging_config.logging.getLogger("api.utils.oci")
 
+#####################################################
+# Exceptions
+#####################################################
+class OciException(Exception):
+    """Custom OCI Exceptions to be passed to HTTPException"""
+    def __init__(self, status_code: int, detail: str):
+        self.status_code = status_code
+        self.detail = detail
+        super().__init__(detail)
+
+#####################################################
+# Functions
+#####################################################
 def init_client(
     client_type: Union[
         oci.object_storage.ObjectStorageClient,
diff --git a/src/server/api/utils/selectai.py b/src/server/api/utils/selectai.py
index c17491fe..2c029fbe 100644
--- a/src/server/api/utils/selectai.py
+++ b/src/server/api/utils/selectai.py
@@ -7,7 +7,7 @@
 from typing import Union
 import oracledb
 
-import server.api.core.databases as core_databases
+import server.api.utils.databases as utils_databases
 
 from common.schema import SelectAIProfileType, DatabaseSelectAIObjects
 from common import logging_config
 
@@ -39,7 +39,7 @@ def set_profile(
         );
     END;
     """
-    _ = core_databases.execute_sql(conn, sql, binds)
+    _ =
utils_databases.execute_sql(conn, sql, binds) def get_objects(conn: oracledb.Connection, profile_name: SelectAIProfileType) -> DatabaseSelectAIObjects: @@ -67,7 +67,7 @@ def get_objects(conn: oracledb.Connection, profile_name: SelectAIProfileType) -> 'RMAN$CATALOG','ADMIN','ODI_REPO_USER','C##CLOUD$SERVICE') ORDER BY owner, table_name """ - results = core_databases.execute_sql(conn, sql, binds) + results = utils_databases.execute_sql(conn, sql, binds) for owner, table_name, object_enabled in results: selectai_objects.append(DatabaseSelectAIObjects(owner=owner, name=table_name, enabled=object_enabled)) logger.debug("Found SelectAI Objects: %s", selectai_objects) diff --git a/src/server/api/utils/testbed.py b/src/server/api/utils/testbed.py index f6fda789..84528ecb 100644 --- a/src/server/api/utils/testbed.py +++ b/src/server/api/utils/testbed.py @@ -18,7 +18,7 @@ from giskard.rag import generate_testset, KnowledgeBase, QATestset from giskard.rag.question_generators import simple_questions, complex_questions -import server.api.core.databases as core_databases +import server.api.utils.databases as utils_databases import server.api.utils.models as utils_models from common import schema from common import logging_config @@ -82,11 +82,11 @@ def create_testset_objects(db_conn: Connection) -> None: ) """ logger.info("Creating testsets Table") - _ = core_databases.execute_sql(db_conn, testsets_tbl) + _ = utils_databases.execute_sql(db_conn, testsets_tbl) logger.info("Creating testset_qa Table") - _ = core_databases.execute_sql(db_conn, testset_qa_tbl) + _ = utils_databases.execute_sql(db_conn, testset_qa_tbl) logger.info("Creating evaluations Table") - _ = core_databases.execute_sql(db_conn, evaluation_tbl) + _ = utils_databases.execute_sql(db_conn, evaluation_tbl) def get_testsets(db_conn: Connection) -> list: @@ -94,7 +94,7 @@ def get_testsets(db_conn: Connection) -> list: logger.info("Getting All TestSets") testsets = [] sql = "SELECT tid, name, to_char(created) FROM oai_testsets ORDER BY created" - results = core_databases.execute_sql(db_conn, sql) + results = utils_databases.execute_sql(db_conn, sql) try: testsets = [schema.TestSets(tid=tid.hex(), name=name, created=created) for tid, name, created in results] except TypeError: @@ -108,7 +108,7 @@ def get_testset_qa(db_conn: Connection, tid: schema.TestSetsIdType) -> schema.Te logger.info("Getting TestSet Q&A for TID: %s", tid) binds = {"tid": tid} sql = "SELECT qa_data FROM oai_testset_qa where tid=:tid" - results = core_databases.execute_sql(db_conn, sql, binds) + results = utils_databases.execute_sql(db_conn, sql, binds) qa_data = [qa_data[0] for qa_data in results] return schema.TestSetQA(qa_data=qa_data) @@ -120,7 +120,7 @@ def get_evaluations(db_conn: Connection, tid: schema.TestSetsIdType) -> list[sch evaluations = [] binds = {"tid": tid} sql = "SELECT eid, to_char(evaluated), correctness FROM oai_evaluations WHERE tid=:tid ORDER BY evaluated DESC" - results = core_databases.execute_sql(db_conn, sql, binds) + results = utils_databases.execute_sql(db_conn, sql, binds) try: evaluations = [ schema.Evaluation(eid=eid.hex(), evaluated=evaluated, correctness=correctness) @@ -139,7 +139,7 @@ def delete_qa( """Delete Q&A""" binds = {"tid": tid} sql = "DELETE FROM oai_testsets WHERE TID = :tid" - core_databases.execute_sql(db_conn, sql, binds) + utils_databases.execute_sql(db_conn, sql, binds) db_conn.commit() @@ -191,7 +191,7 @@ def upsert_qa( END; """ logger.debug("Upsert PLSQL: %s", plsql) - return core_databases.execute_sql(db_conn, plsql, 
binds) + return utils_databases.execute_sql(db_conn, plsql, binds) def insert_evaluation(db_conn, tid, evaluated, correctness, settings, rag_report): @@ -218,7 +218,7 @@ def insert_evaluation(db_conn, tid, evaluated, correctness, settings, rag_report END; """ logger.debug("Insert PLSQL: %s", plsql) - return core_databases.execute_sql(db_conn, plsql, binds) + return utils_databases.execute_sql(db_conn, plsql, binds) def load_and_split(eval_file, chunk_size=2048): @@ -299,7 +299,7 @@ def clean(orig_html): FROM oai_evaluations WHERE eid=:eid ORDER BY evaluated """ - results = core_databases.execute_sql(db_conn, sql, binds) + results = utils_databases.execute_sql(db_conn, sql, binds) report = pickle.loads(results[0]["RAG_REPORT"]) full_report = report.to_pandas() html_report = report.to_html() diff --git a/src/server/api/v1/databases.py b/src/server/api/v1/databases.py index 89b7a432..b93bf63a 100644 --- a/src/server/api/v1/databases.py +++ b/src/server/api/v1/databases.py @@ -6,13 +6,16 @@ from fastapi import APIRouter, HTTPException -import server.api.core.databases as core_databases +import server.api.utils.databases as utils_databases from common import schema from common import logging_config logger = logging_config.logging.getLogger("endpoints.v1.databases") +# Validate the DEFAULT Databases +_ = utils_databases.get_databases(db_name="DEFAULT", validate=True) + auth = APIRouter() @@ -25,7 +28,7 @@ async def databases_list() -> list[schema.Database]: """List all databases""" logger.debug("Received databases_list") try: - database_objects = core_databases.get_databases() + database_objects = utils_databases.get_databases(validate=False) except ValueError as ex: # This is a problem, there should always be a "DEFAULT" database even if not configured raise HTTPException(status_code=404, detail=f"Database: {str(ex)}.") from ex @@ -42,7 +45,8 @@ async def databases_get(name: schema.DatabaseNameType) -> schema.Database: """Get single database""" logger.debug("Received databases_get - name: %s", name) try: - db = core_databases.get_databases(name) + # Validate when looking at a single database + db = utils_databases.get_databases(db_name=name, validate=True) except ValueError as ex: raise HTTPException(status_code=404, detail=f"Database: {str(ex)}.") from ex @@ -62,29 +66,39 @@ async def databases_update( logger.debug("Received databases_update - name: %s; payload: %s", name, payload) try: - db = core_databases.get_databases(name) + db = utils_databases.get_databases(db_name=name, validate=False) except ValueError as ex: raise HTTPException(status_code=404, detail=f"Database: {str(ex)}.") from ex + db.connected = False try: payload.config_dir = db.config_dir payload.wallet_location = db.wallet_location logger.debug("Testing Payload: %s", payload) - db_conn = core_databases.connect(payload) - except core_databases.DbException as ex: - db.connected = False - raise HTTPException(status_code=ex.status_code, detail=f"Database: {name} {ex.detail}.") from ex - + db_conn = utils_databases.connect(payload) + except (ValueError, PermissionError, ConnectionError, LookupError) as ex: + status_code = 500 + if isinstance(ex, ValueError): + status_code = 400 + elif isinstance(ex, PermissionError): + status_code = 401 + elif isinstance(ex, LookupError): + status_code = 404 + elif isinstance(ex, ConnectionError): + status_code = 503 + else: + raise + raise HTTPException(status_code=status_code, detail=f"Database: {db.name} {ex}.") from ex for key, value in payload.model_dump().items(): setattr(db, key, value) + + 
# Manage Connections; Unset and disconnect other databases db.connected = True db.set_connection(db_conn) - - # Unset and disconnect other databases - database_objects = core_databases.get_databases(validate=False) + database_objects = utils_databases.get_databases() for other_db in database_objects: if other_db.name != name and other_db.connection: - other_db.set_connection(core_databases.disconnect(db.connection)) + other_db.set_connection(utils_databases.disconnect(db.connection)) other_db.connected = False return db diff --git a/src/server/api/v1/embed.py b/src/server/api/v1/embed.py index 7f41a315..ee44ddb3 100644 --- a/src/server/api/v1/embed.py +++ b/src/server/api/v1/embed.py @@ -14,7 +14,6 @@ from pydantic import HttpUrl import requests -import server.api.core.databases as core_databases import server.api.core.oci as core_oci import server.api.utils.databases as utils_databases @@ -39,10 +38,10 @@ async def embed_drop_vs( """Drop Vector Storage""" logger.debug("Received %s embed_drop_vs: %s", client, vs) try: - client_db = utils_databases.get_client_db(client) - db_conn = core_databases.connect(client_db) + client_db = utils_databases.get_client_database(client) + db_conn = utils_databases.connect(client_db) utils_databases.drop_vs(db_conn, vs) - except core_databases.DbException as ex: + except utils_databases.DbException as ex: raise HTTPException(status_code=400, detail=f"Embed: {str(ex)}.") from ex return JSONResponse(status_code=200, content={"message": f"Vector Store: {vs} dropped."}) @@ -149,7 +148,7 @@ async def split_embed( utils_embed.populate_vs( vector_store=request, - db_details=utils_databases.get_client_db(client), + db_details=utils_databases.get_client_database(client), embed_client=embed_client, input_data=split_docos, rate_limit=rate_limit, diff --git a/src/server/api/v1/oci.py b/src/server/api/v1/oci.py index 61fe1c5a..311a830c 100644 --- a/src/server/api/v1/oci.py +++ b/src/server/api/v1/oci.py @@ -64,7 +64,7 @@ async def oci_list_regions( oci_config = await oci_get(auth_profile=auth_profile) regions = utils_oci.get_regions(oci_config) return regions - except core_oci.OciException as ex: + except utils_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex @@ -82,7 +82,7 @@ async def oci_list_genai( oci_config = await oci_get(auth_profile=auth_profile) all_models = utils_oci.get_genai_models(oci_config, regional=False) return all_models - except core_oci.OciException as ex: + except utils_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex @@ -100,7 +100,7 @@ async def oci_list_compartments( oci_config = await oci_get(auth_profile=auth_profile) compartments = utils_oci.get_compartments(oci_config) return compartments - except core_oci.OciException as ex: + except utils_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex @@ -120,7 +120,7 @@ async def oci_list_buckets( oci_config = await oci_get(auth_profile=auth_profile) buckets = utils_oci.get_buckets(compartment_obj.ocid, oci_config) return buckets - except core_oci.OciException as ex: + except utils_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex @@ -139,7 +139,7 @@ async def oci_list_bucket_objects( oci_config = await oci_get(auth_profile=auth_profile) objects = utils_oci.get_bucket_objects(bucket_name, oci_config) return objects - except core_oci.OciException as ex: + except 
utils_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex @@ -163,7 +163,7 @@ async def oci_profile_update( for key, value in payload.model_dump().items(): if value not in ("", None): setattr(oci_config, key, value) - except core_oci.OciException as ex: + except utils_oci.OciException as ex: oci_config.namespace = None raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex except AttributeError as ex: @@ -214,5 +214,5 @@ async def oci_create_genai_models( oci_config = await oci_get(auth_profile=auth_profile) enabled_models = utils_models.create_genai(oci_config) return enabled_models - except core_oci.OciException as ex: + except utils_oci.OciException as ex: raise HTTPException(status_code=ex.status_code, detail=f"OCI: {ex.detail}.") from ex diff --git a/src/server/api/v1/selectai.py b/src/server/api/v1/selectai.py index 5c892f29..f712ffe9 100644 --- a/src/server/api/v1/selectai.py +++ b/src/server/api/v1/selectai.py @@ -29,8 +29,8 @@ async def selectai_get_objects( ) -> list[schema.DatabaseSelectAIObjects]: """Get DatabaseSelectAIObjects""" client_settings = core_settings.get_client_settings(client) - db_conn = utils_databases.get_client_db(client).connection - select_ai_objects = utils_selectai.get_objects(db_conn, client_settings.selectai.profile) + database = utils_databases.get_client_database(client=client, validate=False) + select_ai_objects = utils_selectai.get_objects(database.connection, client_settings.selectai.profile) return select_ai_objects @@ -47,6 +47,6 @@ async def selectai_update_objects( logger.debug("Received selectai_update - payload: %s", payload) client_settings = core_settings.get_client_settings(client) object_list = json.dumps([obj.model_dump(include={"owner", "name"}) for obj in payload]) - db_conn = utils_databases.get_client_db(client).connection + db_conn = utils_databases.get_client_database(client).connection utils_selectai.set_profile(db_conn, client_settings.selectai.profile, "object_list", object_list) return utils_selectai.get_objects(db_conn, client_settings.selectai.profile) diff --git a/src/server/api/v1/testbed.py b/src/server/api/v1/testbed.py index bd066b64..0f38fbce 100644 --- a/src/server/api/v1/testbed.py +++ b/src/server/api/v1/testbed.py @@ -44,7 +44,7 @@ async def testbed_testsets( client: schema.ClientIdType = Header(default="server"), ) -> list[schema.TestSets]: """Get a list of stored TestSets, create TestSet objects if they don't exist""" - testsets = utils_testbed.get_testsets(db_conn=utils_databases.get_client_db(client).connection) + testsets = utils_testbed.get_testsets(db_conn=utils_databases.get_client_database(client).connection) return testsets @@ -59,7 +59,7 @@ async def testbed_evaluations( ) -> list[schema.Evaluation]: """Get Evaluations""" evaluations = utils_testbed.get_evaluations( - db_conn=utils_databases.get_client_db(client).connection, tid=tid.upper() + db_conn=utils_databases.get_client_database(client).connection, tid=tid.upper() ) return evaluations @@ -75,7 +75,7 @@ async def testbed_evaluation( ) -> schema.EvaluationReport: """Get Evaluations""" evaluation = utils_testbed.process_report( - db_conn=utils_databases.get_client_db(client).connection, eid=eid.upper() + db_conn=utils_databases.get_client_database(client).connection, eid=eid.upper() ) return evaluation @@ -90,7 +90,7 @@ async def testbed_testset_qa( client: schema.ClientIdType = Header(default="server"), ) -> schema.TestSetQA: """Get TestSet Q&A""" - return 
utils_testbed.get_testset_qa(db_conn=utils_databases.get_client_db(client).connection, tid=tid.upper()) + return utils_testbed.get_testset_qa(db_conn=utils_databases.get_client_database(client).connection, tid=tid.upper()) @auth.delete( @@ -102,7 +102,7 @@ async def testbed_delete_testset( client: schema.ClientIdType = Header(default="server"), ) -> JSONResponse: """Delete TestSet""" - utils_testbed.delete_qa(utils_databases.get_client_db(client).connection, tid.upper()) + utils_testbed.delete_qa(utils_databases.get_client_database(client).connection, tid.upper()) return JSONResponse(status_code=200, content={"message": f"TestSet: {tid} deleted."}) @@ -119,7 +119,7 @@ async def testbed_upsert_testsets( ) -> schema.TestSetQA: """Update stored TestSet data""" created = datetime.now().isoformat() - db_conn = utils_databases.get_client_db(client).connection + db_conn = utils_databases.get_client_database(client).connection try: for file in files: file_content = await file.read() @@ -222,7 +222,7 @@ def get_answer(question: str): # Change Grade vector_search client_settings.vector_search.grading = False - db_conn = utils_databases.get_client_db(client).connection + db_conn = utils_databases.get_client_database(client).connection testset = utils_testbed.get_testset_qa(db_conn=db_conn, tid=tid.upper()) qa_test = "\n".join(json.dumps(item) for item in testset.qa_data) temp_directory = utils_embed.get_temp_directory(client, "testbed") diff --git a/src/server/bootstrap/databases.py b/src/server/bootstrap/databases.py index 1a97a18a..cb2b7b1d 100644 --- a/src/server/bootstrap/databases.py +++ b/src/server/bootstrap/databases.py @@ -7,7 +7,6 @@ import os from server.bootstrap.configfile import ConfigStore -import server.api.core.databases as core_databases from common.schema import Database from common import logging_config @@ -28,7 +27,7 @@ def main() -> list[Database]: raise ValueError(f"Duplicate database name found in config: '{db.name}'") seen.add(db_name_lower) - db_objects = [] + database_objects = [] default_found = False for db in db_configs: @@ -46,9 +45,9 @@ def main() -> list[Database]: if updated.wallet_password: updated.wallet_location = updated.config_dir logger.info("Setting WALLET_LOCATION: %s", updated.config_dir) - db_objects.append(updated) + database_objects.append(updated) else: - db_objects.append(db) + database_objects.append(db) # If DEFAULT wasn't in config, create it from env vars if not default_found: @@ -63,26 +62,7 @@ def main() -> list[Database]: if data["wallet_password"]: data["wallet_location"] = data["config_dir"] logger.info("Setting WALLET_LOCATION: %s", data["config_dir"]) - db_objects.append(Database(**data)) - - # Validate Configuration and set vector_stores/status - database_objects = [] - for db in db_objects: - database_objects.append(db) - try: - conn = core_databases.connect(db) - db.connected = True - except core_databases.DbException: - db.connected = False - continue - db.vector_stores = core_databases.get_vs(conn) - db.selectai = core_databases.selectai_enabled(conn) - if db.selectai: - db.selectai_profiles = core_databases.get_selectai_profiles(conn) - if not db.connection and len(database_objects) > 1: - db.set_connection = core_databases.disconnect(conn) - else: - db.set_connection(conn) + database_objects.append(Database(**data)) logger.debug("Bootstrapped Databases: %s", database_objects) logger.debug("*** Bootstrapping Database - End") diff --git a/src/server/patches/litellm_patch.py b/src/server/patches/litellm_patch.py index 175f4b3f..de15f0fa 
100644 --- a/src/server/patches/litellm_patch.py +++ b/src/server/patches/litellm_patch.py @@ -40,7 +40,10 @@ def custom_transform_response( api_key: Optional[str] = None, json_mode: Optional[bool] = None, ): - """Custom transform response from .venv/lib/python3.11/site-packages/litellm/llms/ollama/completion/transformation.py""" + """ + Custom transform response from + .venv/lib/python3.11/site-packages/litellm/llms/ollama/completion/transformation.py + """ logger.info("Custom transform_response is running") response_json = raw_response.json() diff --git a/tests/integration/client/content/config/tabs/test_databases.py b/tests/integration/client/content/config/tabs/test_databases.py index f99f7cfd..39eadb58 100644 --- a/tests/integration/client/content/config/tabs/test_databases.py +++ b/tests/integration/client/content/config/tabs/test_databases.py @@ -129,7 +129,7 @@ def test_connected(self, app_server, app_test, db_container): "username": TEST_CONFIG["db_username"], "password": TEST_CONFIG["db_password"], "dsn": "WRONG_TP", - "expected": "Update Failed - Database: DEFAULT DPY-*", + "expected": "Update Failed - Database: DEFAULT not resolvable.", }, id="bad_dsn", ), From 7bfa8272eefefb6553bc32dae2660bc75db9d963 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Mon, 1 Sep 2025 17:36:07 +0100 Subject: [PATCH 17/31] resolve oci cyclic imports --- src/server/api/core/oci.py | 2 +- src/server/api/utils/oci.py | 10 ++++++---- src/server/api/v1/databases.py | 1 + src/server/api/v1/oci.py | 8 ++++++++ src/server/bootstrap/oci.py | 17 ----------------- 5 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/server/api/core/oci.py b/src/server/api/core/oci.py index e235cb11..07b863f6 100644 --- a/src/server/api/core/oci.py +++ b/src/server/api/core/oci.py @@ -31,7 +31,7 @@ def get_oci( raise ValueError("provide either 'client' or 'auth_profile', not both") oci_objects = bootstrap.OCI_OBJECTS - + print(f"********** {oci_objects}") if client is not None: client_settings = settings.get_client_settings(client) derived_auth_profile = ( diff --git a/src/server/api/utils/oci.py b/src/server/api/utils/oci.py index 6afd0257..60c465e1 100644 --- a/src/server/api/utils/oci.py +++ b/src/server/api/utils/oci.py @@ -15,6 +15,7 @@ logger = logging_config.logging.getLogger("api.utils.oci") + ##################################################### # Exceptions ##################################################### @@ -26,6 +27,7 @@ def __init__(self, status_code: int, detail: str): self.detail = detail super().__init__(detail) + ##################################################### # Functions ##################################################### @@ -94,14 +96,14 @@ def init_genai_client(config: OracleCloudSettings) -> oci.generative_ai_inferenc return init_client(client_type, config) -def get_namespace(config: OracleCloudSettings = None) -> str: +def get_namespace(config: OracleCloudSettings) -> str: """Get the Object Storage Namespace. 
Also used for testing AuthN""" logger.info("Getting Object Storage Namespace") client_type = oci.object_storage.ObjectStorageClient try: client = init_client(client_type, config) - namespace = client.get_namespace().data - logger.info("OCI: Namespace = %s", namespace) + config.namespace = client.get_namespace().data + logger.info("OCI: Namespace = %s", config.namespace) except oci.exceptions.InvalidConfig as ex: raise OciException(status_code=400, detail="Invalid Config") from ex except oci.exceptions.ServiceError as ex: @@ -115,7 +117,7 @@ def get_namespace(config: OracleCloudSettings = None) -> str: except Exception as ex: raise OciException(status_code=500, detail=str(ex)) from ex - return namespace + return config.namespace def get_regions(config: OracleCloudSettings = None) -> list[dict]: diff --git a/src/server/api/v1/databases.py b/src/server/api/v1/databases.py index b93bf63a..87b2a6f3 100644 --- a/src/server/api/v1/databases.py +++ b/src/server/api/v1/databases.py @@ -16,6 +16,7 @@ # Validate the DEFAULT Databases _ = utils_databases.get_databases(db_name="DEFAULT", validate=True) + auth = APIRouter() diff --git a/src/server/api/v1/oci.py b/src/server/api/v1/oci.py index 311a830c..3072bac6 100644 --- a/src/server/api/v1/oci.py +++ b/src/server/api/v1/oci.py @@ -17,6 +17,14 @@ logger = logging_config.logging.getLogger("endpoints.v1.oci") +# Validate the DEFAULT OCI Profile and get models +try: + default_config = core_oci.get_oci(auth_profile="DEFAULT") + _ = utils_oci.get_namespace(config=default_config) + _ = utils_models.create_genai(default_config) +except utils_oci.OciException: + pass + auth = APIRouter() diff --git a/src/server/bootstrap/oci.py b/src/server/bootstrap/oci.py index 8f123be1..c40c297f 100644 --- a/src/server/bootstrap/oci.py +++ b/src/server/bootstrap/oci.py @@ -9,8 +9,6 @@ import oci from server.bootstrap.configfile import ConfigStore -import server.api.utils.oci as utils_oci -import server.api.utils.models as utils_models from common import logging_config from common.schema import OracleCloudSettings @@ -110,21 +108,6 @@ def override(profile: dict, key: str, env_key: str, env: dict, overrides: dict, oci_config = OracleCloudSettings(**profile_data) oci_objects.append(oci_config) - if oci_config.auth_profile == oci.config.DEFAULT_PROFILE: - try: - oci_config.namespace = utils_oci.get_namespace(oci_config) - except Exception: - logger.warning("Failed to get namespace for DEFAULT OCI profile") - continue - - # Attempt to load OCI GenAI Models after OCI and MODELs are Bootstrapped - try: - oci_config = [o for o in oci_objects if o.auth_profile == "DEFAULT"] - if oci_config: - utils_models.create_genai(oci_config[0]) - except Exception as ex: - logger.info("Unable to bootstrap OCI GenAI Models: %s", str(ex)) - logger.debug("*** Bootstrapping OCI - End") return oci_objects From 325871e2e633713bfc24601787a848805ff6e10f Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Mon, 1 Sep 2025 18:12:41 +0100 Subject: [PATCH 18/31] Update SpringAI --- src/client/content/config/tabs/databases.py | 9 +++++++-- src/client/content/config/tabs/settings.py | 3 ++- src/client/content/tools/tabs/split_embed.py | 2 +- src/client/spring_ai/templates/obaas.yaml | 2 +- src/client/spring_ai/templates/start.sh | 4 ++-- 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/client/content/config/tabs/databases.py b/src/client/content/config/tabs/databases.py index 06b1bd79..6d393b0a 100644 --- a/src/client/content/config/tabs/databases.py +++ 
b/src/client/content/config/tabs/databases.py @@ -22,11 +22,16 @@ ##################################################### # Functions ##################################################### -def get_databases(force: bool = False) -> None: +def get_databases(validate: bool = False, force: bool = False) -> None: """Get Databases from API Server""" if force or "database_configs" not in state or not state.database_configs: try: logger.info("Refreshing state.database_configs") + # Validation will be done on currently configured client database + # validation includes new vector_stores, etc. + if validate: + client_database = state.client_settings.get("database", {}).get("alias", {}) + _ = api_call.get(endpoint=f"v1/databases/{client_database}") state.database_configs = api_call.get(endpoint="v1/databases") except api_call.ApiError as ex: logger.error("Unable to populate state.database_configs: %s", ex) @@ -61,7 +66,7 @@ def patch_database(name: str, supplied: dict, connected: bool) -> bool: def drop_vs(vs: dict) -> None: """Drop a Vector Storage Table""" api_call.delete(endpoint=f"v1/embed/{vs['vector_store']}") - get_databases(force=True) + get_databases(validate=True, force=True) def select_ai_profile() -> None: diff --git a/src/client/content/config/tabs/settings.py b/src/client/content/config/tabs/settings.py index 90c18cc0..c71832d7 100644 --- a/src/client/content/config/tabs/settings.py +++ b/src/client/content/config/tabs/settings.py @@ -328,12 +328,13 @@ def display_settings(): st.header("Source Code Templates", divider="red") # Merge the User Settings into the Model Config - model_lookup = st_common.state_configs_lookup("model_configs", "id") try: + model_lookup = st_common.enabled_models_lookup(model_type="ll") ll_config = model_lookup[state.client_settings["ll_model"]["model"]] | state.client_settings["ll_model"] except KeyError: ll_config = {} try: + model_lookup = st_common.enabled_models_lookup(model_type="embed") embed_config = ( model_lookup[state.client_settings["vector_search"]["model"]] | state.client_settings["vector_search"] ) diff --git a/src/client/content/tools/tabs/split_embed.py b/src/client/content/tools/tabs/split_embed.py index 07a1640b..d984841b 100644 --- a/src/client/content/tools/tabs/split_embed.py +++ b/src/client/content/tools/tabs/split_embed.py @@ -391,7 +391,7 @@ def display_split_embed() -> None: ) st.success(f"Vector Store Populated: {response['message']}", icon="✅") # Refresh database_configs state to reflect new vector stores - get_databases(force="True") + get_databases(validate=True, force=True) except api_call.ApiError as ex: st.error(ex, icon="🚨") diff --git a/src/client/spring_ai/templates/obaas.yaml b/src/client/spring_ai/templates/obaas.yaml index 5ad36615..ae25cae4 100644 --- a/src/client/spring_ai/templates/obaas.yaml +++ b/src/client/spring_ai/templates/obaas.yaml @@ -14,7 +14,7 @@ spring: initialize-schema: True index-type: {vector_search[index_type]} openai: - base-url: \"{ll_model[url]}\" + base-url: \"{ll_model[api_base]}\" api-key: \"{ll_model[api_key]}\" chat: options: diff --git a/src/client/spring_ai/templates/start.sh b/src/client/spring_ai/templates/start.sh index 30341432..33ab7339 100644 --- a/src/client/spring_ai/templates/start.sh +++ b/src/client/spring_ai/templates/start.sh @@ -6,14 +6,14 @@ if [[ "{provider}" == "ollama" ]]; then export OPENAI_CHAT_MODEL="" unset OPENAI_EMBEDDING_MODEL unset OPENAI_URL - export OLLAMA_BASE_URL="{ll_model[url]}" + export OLLAMA_BASE_URL="{ll_model[api_base]}" export 
OLLAMA_CHAT_MODEL="{ll_model[model]}" export OLLAMA_EMBEDDING_MODEL="{vector_search[model]}" else PREFIX="OP"; UNSET_PREFIX="OL" export OPENAI_CHAT_MODEL="{ll_model[model]}" export OPENAI_EMBEDDING_MODEL="{vector_search[model]}" - export OPENAI_URL="{ll_model[url]}" + export OPENAI_URL="{ll_model[api_base]}" export OLLAMA_CHAT_MODEL="" unset OLLAMA_EMBEDDING_MODEL fi From 7b70525d0e57d422b9011d2c5fbe7cbb3fe8a8d8 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Mon, 1 Sep 2025 19:49:25 +0100 Subject: [PATCH 19/31] bump versions --- src/pyproject.toml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/pyproject.toml b/src/pyproject.toml index 14b14d77..b0675ea6 100644 --- a/src/pyproject.toml +++ b/src/pyproject.toml @@ -43,11 +43,11 @@ server = [ "langchain-openai==0.3.32", "langchain-perplexity==0.1.2", "langchain-xai==0.2.5", - "langgraph==0.6.4", + "langgraph==0.6.6", "litellm==1.76.1", - "llama-index==0.13.1", + "llama-index==0.13.3", "lxml==6.0.0", - "matplotlib==3.10.5", + "matplotlib==3.10.6", "oci~=2.0", "psutil==7.0.0", "python-multipart==0.0.20", @@ -63,6 +63,7 @@ client = [ # Test dependencies test = [ + "pylint", "pytest", "pytest-asyncio", "docker", From 1e47c940ebd4502a02a4edbeedc02a60b0853566 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Mon, 1 Sep 2025 20:34:48 +0100 Subject: [PATCH 20/31] Fix tests --- src/server/api/utils/databases.py | 2 +- src/server/api/v1/databases.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/server/api/utils/databases.py b/src/server/api/utils/databases.py index 3388a943..0555c919 100644 --- a/src/server/api/utils/databases.py +++ b/src/server/api/utils/databases.py @@ -141,7 +141,7 @@ def connect(config: Database) -> oracledb.Connection: raise PermissionError("invalid credentials") from ex if "DPY-6005" in str(ex): raise ConnectionError("unable to connect") from ex - if "DPY-4000" in str(ex): + if any(code in str(ex) for code in ("DPY-4000", "DPY-4026")): raise LookupError("not resolvable") from ex raise logger.debug("Connected to Databases: %s", config.dsn) diff --git a/src/server/api/v1/databases.py b/src/server/api/v1/databases.py index 87b2a6f3..221c2ccf 100644 --- a/src/server/api/v1/databases.py +++ b/src/server/api/v1/databases.py @@ -14,8 +14,10 @@ logger = logging_config.logging.getLogger("endpoints.v1.databases") # Validate the DEFAULT Databases -_ = utils_databases.get_databases(db_name="DEFAULT", validate=True) - +try: + _ = utils_databases.get_databases(db_name="DEFAULT", validate=True) +except Exception: + pass auth = APIRouter() From ce5d07a48218112c81a2ef245593b83ef9fea9bb Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Tue, 2 Sep 2025 07:52:55 +0100 Subject: [PATCH 21/31] Resolves @cjbj conversation --- opentofu/modules/vm/iam.tf | 4 ++++ src/server/api/utils/databases.py | 18 ++++++++++++------ .../content/config/tabs/test_databases.py | 12 ++++++------ .../server/test_endpoints_databases.py | 11 +++++++---- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/opentofu/modules/vm/iam.tf b/opentofu/modules/vm/iam.tf index 8ecd2382..22cf0bd4 100644 --- a/opentofu/modules/vm/iam.tf +++ b/opentofu/modules/vm/iam.tf @@ -34,6 +34,10 @@ resource "oci_identity_policy" "identity_node_policies" { "allow dynamic-group %s to read objects in compartment id %s", oci_identity_dynamic_group.compute_dynamic_group.name, var.compartment_id ), + format( + "allow dynamic-group %s to use generative-ai-family in compartment id %s", + 
oci_identity_dynamic_group.workers_dynamic_group.name, var.compartment_id + ), ] provider = oci.home_region } \ No newline at end of file diff --git a/src/server/api/utils/databases.py b/src/server/api/utils/databases.py index 0555c919..55f7c649 100644 --- a/src/server/api/utils/databases.py +++ b/src/server/api/utils/databases.py @@ -137,12 +137,17 @@ def connect(config: Database) -> oracledb.Connection: logger.debug("Attempting Database Connection...") conn = oracledb.connect(**db_authn) except oracledb.DatabaseError as ex: - if "ORA-01017" in str(ex): - raise PermissionError("invalid credentials") from ex - if "DPY-6005" in str(ex): - raise ConnectionError("unable to connect") from ex - if any(code in str(ex) for code in ("DPY-4000", "DPY-4026")): - raise LookupError("not resolvable") from ex + error = ex.args[0] if ex.args else None + code = getattr(error, "full_code", None) + mapping = { + "ORA-01017": PermissionError, + "DPY-6005": ConnectionError, + "DPY-4000": LookupError, + "DPY-4026": LookupError, + } + if code in mapping: + raise mapping[code](f"- {error.message}") from ex + # If not recognized, re-raise untouched raise logger.debug("Connected to Databases: %s", config.dsn) @@ -207,6 +212,7 @@ def execute_sql(conn: oracledb.Connection, run_sql: str, binds: dict = None) -> return rows + def drop_vs(conn: oracledb.Connection, vs: VectorStoreTableType) -> None: """Drop Vector Storage""" logger.info("Dropping Vector Store: %s", vs) diff --git a/tests/integration/client/content/config/tabs/test_databases.py b/tests/integration/client/content/config/tabs/test_databases.py index 39eadb58..263d6426 100644 --- a/tests/integration/client/content/config/tabs/test_databases.py +++ b/tests/integration/client/content/config/tabs/test_databases.py @@ -60,7 +60,7 @@ def test_wrong_details(self, app_server, app_test): at.button(key="save_database").click().run() assert at.error[0].value == "Current Status: Disconnected" - assert at.error[1].value == "Update Failed - Database: DEFAULT unable to connect." 
and at.error[1].icon == "🚨" + assert "cannot connect to database" in at.error[1].value and at.error[1].icon == "🚨" def test_connected(self, app_server, app_test, db_container): """Submits with good DSN""" @@ -99,7 +99,7 @@ def test_connected(self, app_server, app_test, db_container): "username": "ADMIN", "password": TEST_CONFIG["db_password"], "dsn": TEST_CONFIG["db_dsn"], - "expected": "Update Failed - Database: DEFAULT invalid credentials.", + "expected": "invalid credential or not authorized", }, id="bad_user", ), @@ -109,7 +109,7 @@ "username": TEST_CONFIG["db_username"], "password": "Wr0ng_P4ssW0rd", "dsn": TEST_CONFIG["db_dsn"], - "expected": "Update Failed - Database: DEFAULT invalid credentials.", + "expected": "invalid credential or not authorized", }, id="bad_password", ), @@ -119,7 +119,7 @@ "username": TEST_CONFIG["db_username"], "password": TEST_CONFIG["db_password"], "dsn": "//localhost:1521/WRONG_TP", - "expected": "Update Failed - Database: DEFAULT unable to connect.", + "expected": "cannot connect to database", }, id="bad_dsn_easy", ), @@ -129,7 +129,7 @@ "username": TEST_CONFIG["db_username"], "password": TEST_CONFIG["db_password"], "dsn": "WRONG_TP", - "expected": "Update Failed - Database: DEFAULT not resolvable.", + "expected": "DPY-4026", }, id="bad_dsn", ), @@ -147,7 +147,7 @@ def test_disconnected(self, app_server, app_test, db_container, test_case): at.text_input(key="database_dsn").set_value(test_case["dsn"]).run() at.button(key="save_database").click().run() assert at.error[0].value == "Current Status: Disconnected" - assert re.match(test_case["expected"], at.error[1].value) and at.error[1].icon == "🚨" + assert test_case["expected"] in at.error[1].value and at.error[1].icon == "🚨" # Due to the connection error, the settings should NOT be updated and be set # to previous successful test connection; connected will be False for error handling assert at.session_state.database_configs[0]["name"] == "DEFAULT" diff --git a/tests/integration/server/test_endpoints_databases.py b/tests/integration/server/test_endpoints_databases.py index 8a0317d2..3b341afe 100644 --- a/tests/integration/server/test_endpoints_databases.py +++ b/tests/integration/server/test_endpoints_databases.py @@ -84,7 +84,7 @@ def test_databases_update_db_down(self, client, auth_headers): } response = client.patch("/v1/databases/DEFAULT", headers=auth_headers["valid_auth"], json=payload) assert response.status_code == 503 - assert response.json() == {"detail": "Database: DEFAULT unable to connect."} + assert "cannot connect to database" in response.json().get("detail", "") test_cases = [ pytest.param( @@ -125,7 +125,7 @@ def test_databases_update_db_down(self, client, auth_headers): "DEFAULT", 503, {"user": "user", "password": "password", "dsn": "//localhost:1521/dsn"}, - {"detail": "Database: DEFAULT unable to connect."}, + {"detail": "cannot connect to database"}, id="invalid_connection", ), pytest.param( @@ -136,7 +136,7 @@ "password": "Wr0ng_P4sswOrd", "dsn": TEST_CONFIG["db_dsn"], }, - {"detail": "Database: DEFAULT invalid credentials."}, + {"detail": "invalid credential or not authorized"}, id="wrong_password", ), pytest.param( @@ -174,7 +174,10 @@ assert response.status_code == status_code if response.status_code !=
200: - assert response.json() == expected + if response.status_code == 422: + assert response.json() == expected + else: + assert expected["detail"] in response.json().get("detail", "") else: data = response.json() data.pop("config_dir", None) # Remove config_dir as it's environment-specific From 7b1e7033ed3db52220dfd0c7e28fd80796d08c16 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Tue, 2 Sep 2025 07:55:08 +0100 Subject: [PATCH 22/31] Linted OpenTofu --- opentofu/modules/vm/iam.tf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/opentofu/modules/vm/iam.tf b/opentofu/modules/vm/iam.tf index 22cf0bd4..b4b913ac 100644 --- a/opentofu/modules/vm/iam.tf +++ b/opentofu/modules/vm/iam.tf @@ -35,7 +35,7 @@ resource "oci_identity_policy" "identity_node_policies" { oci_identity_dynamic_group.compute_dynamic_group.name, var.compartment_id ), format( - "allow dynamic-group %s to use generative-ai-family in compartment id %s", + "allow dynamic-group %s to use generative-ai-family in compartment id %s", oci_identity_dynamic_group.workers_dynamic_group.name, var.compartment_id ), ] From 78e7c2cefa4849918d8822937fe9139c0bf22ced Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Tue, 2 Sep 2025 08:13:26 +0100 Subject: [PATCH 23/31] IaC Updates --- opentofu/main.tf | 1 + opentofu/modules/vm/iam.tf | 2 +- opentofu/modules/vm/locals.tf | 13 +++++++------ opentofu/modules/vm/templates/cloudinit-compute.tpl | 9 +++++++-- opentofu/modules/vm/variables.tf | 4 ++++ opentofu/schema.yaml | 1 + opentofu/variables.tf | 10 ++++++++++ 7 files changed, 31 insertions(+), 9 deletions(-) diff --git a/opentofu/main.tf b/opentofu/main.tf index e0c0ded1..44ddfed9 100644 --- a/opentofu/main.tf +++ b/opentofu/main.tf @@ -80,6 +80,7 @@ resource "oci_database_autonomous_database" "default_adb" { module "vm" { count = var.infrastructure == "VM" ? 1 : 0 source = "./modules/vm" + optimizer_version = var.optimizer_version label_prefix = local.label_prefix tenancy_id = var.tenancy_ocid compartment_id = local.compartment_ocid diff --git a/opentofu/modules/vm/iam.tf b/opentofu/modules/vm/iam.tf index b4b913ac..70864c53 100644 --- a/opentofu/modules/vm/iam.tf +++ b/opentofu/modules/vm/iam.tf @@ -36,7 +36,7 @@ resource "oci_identity_policy" "identity_node_policies" { ), format( "allow dynamic-group %s to use generative-ai-family in compartment id %s", - oci_identity_dynamic_group.workers_dynamic_group.name, var.compartment_id + oci_identity_dynamic_group.compute_dynamic_group.name, var.compartment_id ), ] provider = oci.home_region diff --git a/opentofu/modules/vm/locals.tf b/opentofu/modules/vm/locals.tf index 00fa9d23..90bc2b61 100644 --- a/opentofu/modules/vm/locals.tf +++ b/opentofu/modules/vm/locals.tf @@ -4,12 +4,13 @@ locals { cloud_init_compute = templatefile("${path.module}/templates/cloudinit-compute.tpl", { - tenancy_id = var.tenancy_id - compartment_id = var.compartment_id - oci_region = var.region - db_name = var.adb_name - db_password = var.adb_password - install_ollama = var.vm_is_gpu_shape ? true : false + tenancy_id = var.tenancy_id + compartment_id = var.compartment_id + oci_region = var.region + db_name = var.adb_name + db_password = var.adb_password + optimizer_version = var.optimizer_version + install_ollama = var.vm_is_gpu_shape ? 
true : false }) cloud_init_database = templatefile("${path.module}/templates/cloudinit-database.tpl", { diff --git a/opentofu/modules/vm/templates/cloudinit-compute.tpl b/opentofu/modules/vm/templates/cloudinit-compute.tpl index e5619ba6..603f3aa3 100644 --- a/opentofu/modules/vm/templates/cloudinit-compute.tpl +++ b/opentofu/modules/vm/templates/cloudinit-compute.tpl @@ -60,8 +60,13 @@ write_files: # Setup for Instance Principles # Download/Setup Source Code - curl -L https://github.com/oracle/ai-optimizer/releases/latest/download/ai-optimizer-src.tar.gz \ - | tar -xz -C /app + if [ "${optimizer_version}" = "main" ]; then + URL="https://github.com/oracle/ai-optimizer/archive/refs/heads/main.tar.gz" + else + URL="https://github.com/oracle/ai-optimizer/releases/latest/download/ai-optimizer-src.tar.gz" + fi + # Download and extract + curl -L "$URL" | tar -xz -C /app cd /app python3.11 -m venv .venv source .venv/bin/activate diff --git a/opentofu/modules/vm/variables.tf b/opentofu/modules/vm/variables.tf index 1fb6e75e..bd15033c 100644 --- a/opentofu/modules/vm/variables.tf +++ b/opentofu/modules/vm/variables.tf @@ -2,6 +2,10 @@ # All rights reserved. The Universal Permissive License (UPL), Version 1.0 as shown at http://oss.oracle.com/licenses/upl # spell-checker: disable +variable "optimizer_version" { + type = string +} + variable "tenancy_id" { type = string } diff --git a/opentofu/schema.yaml b/opentofu/schema.yaml index 8fc21bbb..0f1f8d05 100644 --- a/opentofu/schema.yaml +++ b/opentofu/schema.yaml @@ -28,6 +28,7 @@ variableGroups: - title: "Hidden (Defaults)" variables: + - optimizer_version - adb_version - k8s_run_cfgmgt visible: false diff --git a/opentofu/variables.tf b/opentofu/variables.tf index c18efd79..9f844aee 100644 --- a/opentofu/variables.tf +++ b/opentofu/variables.tf @@ -3,6 +3,16 @@ # spell-checker: disable // Standard Default Vars +variable "optimizer_version" { + description = "Determines if latest release or main is used" + type = string + default = "latest" + validation { + condition = var.optimizer_version == "latest" || var.optimizer_version == "main" + error_message = "optimizer_version must be either 'latest' or 'main'." + } +} + variable "tenancy_ocid" { description = "The Tenancy ID of the OCI Cloud Account in which to create the resources." 
type = string From 9c5a82039a92dece9e5751fe52972738256452a3 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Tue, 2 Sep 2025 10:45:52 +0100 Subject: [PATCH 24/31] Updates for instance principals --- .../vm/templates/cloudinit-compute.tpl | 6 +- src/client/content/api_server.py | 12 +-- src/client/content/config/tabs/oci.py | 92 +++++++++++-------- src/launch_client.py | 11 ++- src/launch_server.py | 6 +- src/server/api/utils/oci.py | 13 ++- src/server/api/v1/oci.py | 4 +- 7 files changed, 82 insertions(+), 62 deletions(-) diff --git a/opentofu/modules/vm/templates/cloudinit-compute.tpl b/opentofu/modules/vm/templates/cloudinit-compute.tpl index 603f3aa3..f0e95545 100644 --- a/opentofu/modules/vm/templates/cloudinit-compute.tpl +++ b/opentofu/modules/vm/templates/cloudinit-compute.tpl @@ -57,16 +57,14 @@ write_files: permissions: '0755' content: | #!/bin/bash - # Setup for Instance Principles - # Download/Setup Source Code if [ "${optimizer_version}" = "main" ]; then URL="https://github.com/oracle/ai-optimizer/archive/refs/heads/main.tar.gz" + curl -L "$URL" | tar -xz -C /app --strip-components=2 ai-optimizer-main/src else URL="https://github.com/oracle/ai-optimizer/releases/latest/download/ai-optimizer-src.tar.gz" + curl -L "$URL" | tar -xz -C /app fi - # Download and extract - curl -L "$URL" | tar -xz -C /app cd /app python3.11 -m venv .venv source .venv/bin/activate diff --git a/src/client/content/api_server.py b/src/client/content/api_server.py index de48c91e..ddede36b 100644 --- a/src/client/content/api_server.py +++ b/src/client/content/api_server.py @@ -21,10 +21,8 @@ try: import launch_server - - REMOTE_SERVER = False except ImportError: - REMOTE_SERVER = True + pass ##################################################### @@ -52,7 +50,7 @@ def server_restart() -> None: state.server["key"] = os.getenv("API_SERVER_KEY") launch_server.stop_server(state.server["pid"]) - state.server["pid"] = launch_server.start_server(state.server["port"]) + _, state.server["pid"] = launch_server.start_server(state.server["port"]) time.sleep(10) state.pop("server_client", None) @@ -71,16 +69,16 @@ async def main() -> None: key="user_server_port", min_value=1, max_value=65535, - disabled=REMOTE_SERVER, + disabled=state.server["remote"], ) right.text_input( "API Server Key:", value=state.server["key"], key="user_server_key", type="password", - disabled=REMOTE_SERVER, + disabled=state.server["remote"], ) - if not REMOTE_SERVER: + if not state.server["remote"]: st.button("Restart Server", type="primary", on_click=server_restart) st.header("Server Settings", divider="red") diff --git a/src/client/content/config/tabs/oci.py b/src/client/content/config/tabs/oci.py index 8ac8c177..e4543da7 100644 --- a/src/client/content/config/tabs/oci.py +++ b/src/client/content/config/tabs/oci.py @@ -55,7 +55,14 @@ def patch_oci(auth_profile: str, supplied: dict, namespace: str, toast: bool = T if differences or not namespace: rerun = True try: - if supplied["security_token_file"]: + if ( + supplied.get("authentication") + not in ( + "instance_principal", + "oke_workload_identity", + ) + and supplied["security_token_file"] + ): supplied["authentication"] = "security_token" with st.spinner(text="Updating OCI Profile...", show_time=True): @@ -85,7 +92,20 @@ def display_oci() -> None: st.stop() st.subheader("Configuration") + # Store supplied values in dictionary + supplied = {} + + disable_config = False oci_lookup = st_common.state_configs_lookup("oci_configs", "auth_profile") + # Handle instance_principal and 
oke_workload_identity + if len(oci_lookup) == 1 and state.oci_configs[0]["authentication"] in ( + "instance_principal", + "oke_workload_identity", + ): + st.info("Using OCI Authentication Principals", icon="ℹ️") + supplied["authentication"] = state.oci_configs[0]["authentication"] + supplied["tenancy"] = state.oci_configs[0]["tenancy"] + disable_config = True if len(oci_lookup) > 0: selected_oci_auth_profile = st.selectbox( "Profile:", @@ -93,52 +113,49 @@ def display_oci() -> None: index=list(oci_lookup.keys()).index(state.client_settings["oci"]["auth_profile"]), key="selected_oci", on_change=st_common.update_client_settings("oci"), + disabled=disable_config, ) else: selected_oci_auth_profile = "DEFAULT" - token_auth = st.checkbox( - "Use token authentication?", - key="oci_token_auth", - value=False, - ) + token_auth = st.checkbox("Use token authentication?", key="oci_token_auth", value=False, disabled=disable_config) namespace = oci_lookup[selected_oci_auth_profile]["namespace"] - # Store supplied values in dictionary - supplied = {} with st.container(border=True): - supplied["user"] = st.text_input( - "User OCID:", - value=oci_lookup[selected_oci_auth_profile]["user"], - disabled=token_auth, - key="oci_user", - ) - supplied["security_token_file"] = st.text_input( - "Security Token File:", - value=oci_lookup[selected_oci_auth_profile]["security_token_file"], - disabled=not token_auth, - key="oci_security_token_file", - ) - supplied["fingerprint"] = st.text_input( - "Fingerprint:", - value=oci_lookup[selected_oci_auth_profile]["fingerprint"], - key="oci_fingerprint", - ) - supplied["tenancy"] = st.text_input( - "Tenancy OCID:", - value=oci_lookup[selected_oci_auth_profile]["tenancy"], - key="oci_tenancy", - ) + if not disable_config: + supplied["user"] = st.text_input( + "User OCID:", + value=oci_lookup[selected_oci_auth_profile]["user"], + disabled=token_auth, + key="oci_user", + ) + supplied["security_token_file"] = st.text_input( + "Security Token File:", + value=oci_lookup[selected_oci_auth_profile]["security_token_file"], + disabled=not token_auth, + key="oci_security_token_file", + ) + supplied["key_file"] = st.text_input( + "Key File:", + value=oci_lookup[selected_oci_auth_profile]["key_file"], + key="oci_key_file", + ) + supplied["fingerprint"] = st.text_input( + "Fingerprint:", + value=oci_lookup[selected_oci_auth_profile]["fingerprint"], + key="oci_fingerprint", + ) + supplied["tenancy"] = st.text_input( + "Tenancy OCID:", + value=oci_lookup[selected_oci_auth_profile]["tenancy"], + key="oci_tenancy", + ) supplied["region"] = st.text_input( "Region:", value=oci_lookup[selected_oci_auth_profile]["region"], help="Region of Source Bucket", key="oci_region", ) - supplied["key_file"] = st.text_input( - "Key File:", - value=oci_lookup[selected_oci_auth_profile]["key_file"], - key="oci_key_file", - ) + if namespace: st.success(f"Current Status: Validated - Namespace: {namespace}") else: @@ -148,8 +165,9 @@ if st.button("Save Configuration", key="save_oci"): # Modify based on token usage - supplied["security_token_file"] = None if not token_auth else supplied["security_token_file"] - supplied["user"] = None if token_auth else supplied["user"] + if not disable_config: + supplied["security_token_file"] = None if not token_auth else supplied["security_token_file"] + supplied["user"] = None if token_auth else supplied["user"] if patch_oci(selected_oci_auth_profile, supplied, namespace): st.rerun() diff --git a/src/launch_client.py b/src/launch_client.py index
8f064e95..f50406dc 100644 --- a/src/launch_client.py +++ b/src/launch_client.py @@ -22,7 +22,7 @@ logger = logging_config.logging.getLogger("launch_client") # Import launch_server if it exists -REMOTE_SERVER = False +LAUNCH_SERVER_EXISTS = True try: from launch_server import start_server, get_api_key @@ -30,7 +30,7 @@ logger.debug("Imported API Server.") except ImportError as ex: logger.debug("API Server not present: %s", ex) - REMOTE_SERVER = True + LAUNCH_SERVER_EXISTS = False ############################################################################# @@ -43,6 +43,7 @@ def init_server_state() -> None: state.server = {"url": os.getenv("API_SERVER_URL", "http://localhost")} state.server["port"] = int(os.getenv("API_SERVER_PORT", "8000")) state.server["key"] = os.getenv("API_SERVER_KEY") + state.server["remote"] = True logger.debug("Server State: %s", state.server) @@ -153,9 +154,11 @@ def main() -> None: if __name__ == "__main__": # Start Server if not running init_server_state() - if not REMOTE_SERVER: + if LAUNCH_SERVER_EXISTS: try: logger.debug("Server PID: %i", state.server["pid"]) except KeyError: - state.server["pid"] = start_server(logfile=True) + server_state, pid = start_server(logfile=True) + state.server["pid"] = pid + state.server["remote"] = server_state != "started" main() diff --git a/src/launch_server.py b/src/launch_server.py index fde83bda..7727f834 100644 --- a/src/launch_server.py +++ b/src/launch_server.py @@ -48,7 +48,7 @@ ########################################## # Process Control ########################################## -def start_server(port: int = 8000, logfile: bool = False) -> int: +def start_server(port: int = 8000, logfile: bool = False) -> tuple[str, int]: """Start the uvicorn server for FastAPI""" logger.info("Starting Oracle AI Optimizer and Toolkit") @@ -96,7 +96,7 @@ def start_subprocess(port: int, logfile: bool) -> subprocess.Popen: existing_pid = get_pid_using_port(port) if existing_pid: logger.info("API server already running on port: %i (PID: %i)", port, existing_pid) - return existing_pid + return ("existing", existing_pid) popen_queue = queue.Queue() thread = threading.Thread( @@ -105,7 +105,7 @@ def start_subprocess(port: int, logfile: bool) -> subprocess.Popen: ) thread.start() - return popen_queue.get().pid + return ("started", popen_queue.get().pid) def stop_server(pid: int) -> None: diff --git a/src/server/api/utils/oci.py b/src/server/api/utils/oci.py index 60c465e1..8071472f 100644 --- a/src/server/api/utils/oci.py +++ b/src/server/api/utils/oci.py @@ -36,12 +36,14 @@ def init_client( oci.object_storage.ObjectStorageClient, oci.identity.IdentityClient, oci.generative_ai_inference.GenerativeAiInferenceClient, + oci.generative_ai.GenerativeAiClient, ], config: OracleCloudSettings = None, ) -> Union[ oci.object_storage.ObjectStorageClient, oci.identity.IdentityClient, oci.generative_ai_inference.GenerativeAiInferenceClient, + oci.generative_ai.GenerativeAiClient, ]: """Initialize OCI Client with either user or Token""" # connection timeout to 1 seconds and the read timeout to 60 seconds @@ -51,7 +53,7 @@ def init_client( "timeout": (1, 180), } - # OCI GenAI + # OCI GenAI (for model calling) if ( client_type == oci.generative_ai_inference.GenerativeAiInferenceClient and config.genai_compartment_id @@ -91,7 +93,7 @@ def init_client( def init_genai_client(config: OracleCloudSettings) -> oci.generative_ai_inference.GenerativeAiInferenceClient: - """Initialise OCI GenAI Client""" + """Initialise OCI GenAI Client; used by models""" client_type = 
oci.generative_ai_inference.GenerativeAiInferenceClient return init_client(client_type, config) @@ -154,9 +156,10 @@ def get_genai_models(config: OracleCloudSettings, regional: bool = False) -> lis regions = get_regions(config) for region in regions: - region_config = dict(config) - region_config["region"] = region["region_name"] - client = oci.generative_ai.GenerativeAiClient(region_config) + region_config = config + region_config.region = region["region_name"] + client_type = oci.generative_ai.GenerativeAiClient + client = init_client(client_type, region_config) logger.info( "Checking Region: %s; Compartment: %s for GenAI services", region["region_name"], diff --git a/src/server/api/v1/oci.py b/src/server/api/v1/oci.py index 3072bac6..56a7f240 100644 --- a/src/server/api/v1/oci.py +++ b/src/server/api/v1/oci.py @@ -21,7 +21,7 @@ try: default_config = core_oci.get_oci(auth_profile="DEFAULT") _ = utils_oci.get_namespace(config=default_config) - _ = utils_models.create_genai(default_config) + _ = utils_models.create_genai(config=default_config) except utils_oci.OciException: pass @@ -85,7 +85,7 @@ async def oci_list_genai( auth_profile: schema.OCIProfileType, ) -> list: """Return a list of compartments""" - logger.debug("Received oci_list_regions - auth_profile: %s", auth_profile) + logger.debug("Received oci_list_genai - auth_profile: %s", auth_profile) try: oci_config = await oci_get(auth_profile=auth_profile) all_models = utils_oci.get_genai_models(oci_config, regional=False) From a9e5be271cbec3cba8b0f2f285781b1a565f3635 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Tue, 2 Sep 2025 11:41:34 +0100 Subject: [PATCH 25/31] Update IAM for instance principals --- opentofu/modules/kubernetes/iam.tf | 2 +- opentofu/modules/vm/iam.tf | 4 ++-- src/server/agents/chatbot.py | 3 --- src/server/api/core/oci.py | 1 - src/server/api/utils/oci.py | 4 ++-- 5 files changed, 5 insertions(+), 9 deletions(-) diff --git a/opentofu/modules/kubernetes/iam.tf b/opentofu/modules/kubernetes/iam.tf index 24f41f5d..2b818f9d 100644 --- a/opentofu/modules/kubernetes/iam.tf +++ b/opentofu/modules/kubernetes/iam.tf @@ -42,8 +42,8 @@ resource "oci_identity_policy" "workers_policies" { format("allow any-user to read objects in compartment id %s where all {request.principal.type = 'workload', request.principal.namespace = '%s', request.principal.cluster_id = '%s'}", var.compartment_id, var.label_prefix, oci_containerengine_cluster.default_cluster.id), format("allow any-user to manage repos in compartment id %s where all {request.principal.type = 'workload', request.principal.namespace = '%s', request.principal.cluster_id = '%s'}", var.compartment_id, var.label_prefix, oci_containerengine_cluster.default_cluster.id), # Instance Principles - format("allow dynamic-group %s to use generative-ai-family in compartment id %s", oci_identity_dynamic_group.workers_dynamic_group.name, var.compartment_id), format("allow dynamic-group %s to manage repos in compartment id %s", oci_identity_dynamic_group.workers_dynamic_group.name, var.compartment_id), + format("allow dynamic-group %s to use generative-ai-family in tenancy", oci_identity_dynamic_group.workers_dynamic_group.name), ] provider = oci.home_region } \ No newline at end of file diff --git a/opentofu/modules/vm/iam.tf b/opentofu/modules/vm/iam.tf index 70864c53..cad78b27 100644 --- a/opentofu/modules/vm/iam.tf +++ b/opentofu/modules/vm/iam.tf @@ -35,8 +35,8 @@ resource "oci_identity_policy" "identity_node_policies" { oci_identity_dynamic_group.compute_dynamic_group.name, 
var.compartment_id ), format( - "allow dynamic-group %s to use generative-ai-family in compartment id %s", - oci_identity_dynamic_group.compute_dynamic_group.name, var.compartment_id + "allow dynamic-group %s to use generative-ai-family in tenancy", + oci_identity_dynamic_group.compute_dynamic_group.name ), ] provider = oci.home_region diff --git a/src/server/agents/chatbot.py b/src/server/agents/chatbot.py index a5b0d5e7..9b4faa89 100644 --- a/src/server/agents/chatbot.py +++ b/src/server/agents/chatbot.py @@ -117,8 +117,6 @@ def rephrase(state: OptimizerState, config: RunnableConfig) -> str: ll_raw = config["configurable"]["ll_config"] try: response = completion(messages=[{"role": "system", "content": formatted_prompt}], stream=False, **ll_raw) - print(f"************ {response}") - context_question = response.choices[0].message.content except APIConnectionError as ex: logger.error("Failed to rephrase: %s", str(ex)) @@ -183,7 +181,6 @@ async def vs_grade(state: OptimizerState, config: RunnableConfig) -> OptimizerSt response = await acompletion( messages=[{"role": "system", "content": formatted_prompt}], stream=False, **ll_raw ) - print(f"************ {response}") relevant = response["choices"][0]["message"]["content"] logger.info("Grading completed. Relevant: %s", relevant) if relevant not in ("yes", "no"): diff --git a/src/server/api/core/oci.py b/src/server/api/core/oci.py index 07b863f6..86b55542 100644 --- a/src/server/api/core/oci.py +++ b/src/server/api/core/oci.py @@ -31,7 +31,6 @@ def get_oci( raise ValueError("provide either 'client' or 'auth_profile', not both") oci_objects = bootstrap.OCI_OBJECTS - print(f"********** {oci_objects}") if client is not None: client_settings = settings.get_client_settings(client) derived_auth_profile = ( diff --git a/src/server/api/utils/oci.py b/src/server/api/utils/oci.py index 8071472f..662b550e 100644 --- a/src/server/api/utils/oci.py +++ b/src/server/api/utils/oci.py @@ -194,8 +194,8 @@ def get_genai_models(config: OracleCloudSettings, regional: bool = False) -> lis "id": model.id, } ) - except oci.exceptions.ServiceError: - logger.info("Region: %s has no GenAI services", region["region_name"]) + except oci.exceptions.ServiceError as ex: + logger.info("Unable to get GenAI Models in Region: %s (%s)", region["region_name"], ex.message) except (oci.exceptions.RequestException, urllib3.exceptions.MaxRetryError): logger.error("Timeout: Error querying GenAI services in %s", region["region_name"]) From 84afaa275dce25c14d566352698e9ad8f847d994 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Tue, 2 Sep 2025 12:07:51 +0100 Subject: [PATCH 26/31] Fix after merge --- src/common/schema.py | 1 + src/server/bootstrap/models.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/common/schema.py b/src/common/schema.py index 6e767903..40bde4be 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -34,6 +34,7 @@ "google_vertexai", "groq", "huggingface", + "meta-llama", "mistralai", "ollama", "openai", diff --git a/src/server/bootstrap/models.py b/src/server/bootstrap/models.py index 2852a2b4..0d58690a 100644 --- a/src/server/bootstrap/models.py +++ b/src/server/bootstrap/models.py @@ -103,7 +103,7 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "type": "ll", "provider": "meta-llama", "api_key": "", - "url": os.environ.get("ON_PREM_VLLM_URL", default="http://gpu:8000/v1"), + "api_base": os.environ.get("ON_PREM_VLLM_URL", default="http://gpu:8000/v1"), "context_length": 131072, "temperature": 1.0, 
"max_completion_tokens": 2048, From 90a6130ae36706644928cee730a6d0af995c55a2 Mon Sep 17 00:00:00 2001 From: corradodebari Date: Tue, 2 Sep 2025 16:00:07 +0200 Subject: [PATCH 27/31] fix vllm support chat, embeddings, split&embed --- src/client/mcp/rag/optimizer_utils/config.py | 2 +- src/common/schema.py | 1 + src/server/api/utils/models.py | 14 ++++++++++++-- src/server/bootstrap/models.py | 10 +++++----- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/client/mcp/rag/optimizer_utils/config.py b/src/client/mcp/rag/optimizer_utils/config.py index b7b16efe..f8e8dd13 100644 --- a/src/client/mcp/rag/optimizer_utils/config.py +++ b/src/client/mcp/rag/optimizer_utils/config.py @@ -66,7 +66,7 @@ def get_embeddings(data): elif (provider == "openai"): embeddings = OpenAIEmbeddings(model=model, api_key=api_key) logger.info("OpenAI embeddings connection successful") - elif (provider == "openai_compatible"): + elif (provider == "hosted_vllm"): embeddings = OpenAIEmbeddings(model=model, api_key=api_key,base_url=url,check_embedding_ctx_length=False) logger.info("OpenAI compatible embeddings connection successful") diff --git a/src/common/schema.py b/src/common/schema.py index 40bde4be..cc273e89 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -40,6 +40,7 @@ "openai", "perplexity", "xai", + "hosted_vllm" ] diff --git a/src/server/api/utils/models.py b/src/server/api/utils/models.py index d3756f4d..ae2dd93d 100644 --- a/src/server/api/utils/models.py +++ b/src/server/api/utils/models.py @@ -157,11 +157,21 @@ def get_client_embed(model_config: dict, oci_config: schema.OracleCloudSettings) compartment_id=oci_config.genai_compartment_id, ) else: - kwargs = { + if provider == "hosted_vllm": + kwargs = { + "provider": "openai", + "model": full_model_config["id"], + "base_url": full_model_config.get("api_base"), + "check_embedding_ctx_length":False #To avoid Tiktoken pre-transform on not OpenAI provided server + } + else: + kwargs = { "provider": provider, "model": full_model_config["id"], "base_url": full_model_config.get("api_base"), - } + } + + if full_model_config.get("api_key"): # only add if set kwargs["api_key"] = full_model_config["api_key"] client = init_embeddings(**kwargs) diff --git a/src/server/bootstrap/models.py b/src/server/bootstrap/models.py index 0d58690a..d6e845f2 100644 --- a/src/server/bootstrap/models.py +++ b/src/server/bootstrap/models.py @@ -98,12 +98,12 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "frequency_penalty": 0.0, }, { - "id": "Llama-3.2-1B-Instruct", + "id": "meta-llama/Llama-3.2-1B-Instruct", "enabled": os.getenv("ON_PREM_VLLM_URL") is not None, "type": "ll", - "provider": "meta-llama", + "provider": "hosted_vllm", "api_key": "", - "api_base": os.environ.get("ON_PREM_VLLM_URL", default="http://gpu:8000/v1"), + "api_base": os.environ.get("ON_PREM_VLLM_URL", default="http://localhost:8000/v1"), "context_length": 131072, "temperature": 1.0, "max_completion_tokens": 2048, @@ -153,8 +153,8 @@ def update_env_var(model: Model, provider: str, model_key: str, env_var: str): "id": "nomic-ai/nomic-embed-text-v1", "enabled": False, "type": "embed", - "provider": "huggingface", - "api_base": "http://localhost:1234/v1", + "provider": "hosted_vllm", + "api_base": "http://localhost:8001/v1", "api_key": "", "max_chunk_size": 8192, }, From 6cd5f67ec99ff0ad458bc12e18788179ba11eebd Mon Sep 17 00:00:00 2001 From: corradodebari Date: Tue, 2 Sep 2025 17:47:53 +0200 Subject: [PATCH 28/31] LangChain MCP template fix --- 
src/client/content/config/tabs/settings.py | 6 +++--- src/client/mcp/rag/optimizer_utils/config.py | 19 +++++++++++-------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/client/content/config/tabs/settings.py b/src/client/content/config/tabs/settings.py index 688e0de7..dffdca4f 100644 --- a/src/client/content/config/tabs/settings.py +++ b/src/client/content/config/tabs/settings.py @@ -160,8 +160,8 @@ def spring_ai_conf_check(ll_model: dict, embed_model: dict) -> str: ll_provider = ll_model.get("provider", "") embed_provider = embed_model.get("provider", "") logger.info(f"llm chat:{ll_provider} - embeddings:{embed_provider}") - if all("openai_compatible" in p for p in (ll_provider, embed_provider)): - return "openai_compatible" + if all("hosted_vllm" in p for p in (ll_provider, embed_provider)): + return "hosted_vllm" if all("openai" in p for p in (ll_provider, embed_provider)): return "openai" if all("ollama" in p for p in (ll_provider, embed_provider)): @@ -364,7 +364,7 @@ def display_settings(): disabled=spring_ai_conf == "hybrid", ) with col_centre: - if (spring_ai_conf != "openai_compatible"): + if (spring_ai_conf != "hosted_vllm"): st.download_button( label="Download SpringAI", data=spring_ai_zip(spring_ai_conf, ll_config, embed_config), # Generate zip on the fly diff --git a/src/client/mcp/rag/optimizer_utils/config.py b/src/client/mcp/rag/optimizer_utils/config.py index f8e8dd13..df9058ac 100644 --- a/src/client/mcp/rag/optimizer_utils/config.py +++ b/src/client/mcp/rag/optimizer_utils/config.py @@ -25,14 +25,16 @@ def get_llm(data): logger.info("llm data:") logger.info(data["client_settings"]["ll_model"]["model"]) + model_full = data["client_settings"]["ll_model"]["model"] + model = model_full.partition('/')[2] or model_full llm = {} models_by_id = {m["id"]: m for m in data.get("model_configs", [])} llm_config= models_by_id.get(model) logger.info(llm_config) provider = llm_config["provider"] url = llm_config["api_base"] - api_key = llm_config["api_key"] - model = data["client_settings"]["ll_model"]["model"] + api_key = llm_config["api_key"] logger.info(f"CHAT_MODEL: {model} {provider} {url} {api_key}") if provider == "ollama": # Initialize the LLM @@ -41,21 +43,22 @@ def get_llm(data): elif provider == "openai": llm = ChatOpenAI(model=model, api_key=api_key) logger.info("OpenAI LLM created") - elif provider =="openai_compatible": + elif provider == "hosted_vllm": llm = ChatOpenAI(model=model, api_key=api_key,base_url=url) - logger.info("OpenAI compatible LLM created") + logger.info("hosted_vllm compatible LLM created") return llm def get_embeddings(data): embeddings = {} logger.info("getting embeddings..") - model = data["client_settings"]["vector_search"]["model"] + model_full = data["client_settings"]["vector_search"]["model"] + model = model_full.partition('/')[2] or model_full logger.info(f"embedding model: {model}") models_by_id = {m["id"]: m for m in data.get("model_configs", [])} model_params= models_by_id.get(model) provider = model_params["provider"] - url = model_params["url"] + url = model_params["api_base"] api_key = model_params["api_key"] logger.info(f"Embeddings Model: {model} {provider} {url} {api_key}") @@ -68,7 +71,7 @@ logger.info("OpenAI embeddings connection successful") elif (provider == "hosted_vllm"): embeddings = OpenAIEmbeddings(model=model, api_key=api_key,base_url=url,check_embedding_ctx_length=False) - logger.info("OpenAI compatible embeddings connection successful") + logger.info("hosted_vllm compatible
embeddings connection successful") + logger.info("hosted_vllm compatible embeddings connection successful") return embeddings From d29b70108f6e2c823ffdb81c42fcea7b1bd728a5 Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Tue, 2 Sep 2025 12:44:36 +0100 Subject: [PATCH 29/31] API Server Control --- opentofu/modules/vm/templates/cloudinit-compute.tpl | 1 + src/client/content/api_server.py | 8 ++++---- src/launch_client.py | 8 ++++---- src/launch_server.py | 6 +++--- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/opentofu/modules/vm/templates/cloudinit-compute.tpl b/opentofu/modules/vm/templates/cloudinit-compute.tpl index f0e95545..5e4a983b 100644 --- a/opentofu/modules/vm/templates/cloudinit-compute.tpl +++ b/opentofu/modules/vm/templates/cloudinit-compute.tpl @@ -87,6 +87,7 @@ write_files: content: | #!/bin/bash export OCI_CLI_AUTH=instance_principal + export API_SERVER_CONTROL="True" export DB_USERNAME='AI_OPTIMIZER' export DB_PASSWORD='${db_password}' export DB_DSN='${db_name}_TP' diff --git a/src/client/content/api_server.py b/src/client/content/api_server.py index ddede36b..76b0fe40 100644 --- a/src/client/content/api_server.py +++ b/src/client/content/api_server.py @@ -50,7 +50,7 @@ def server_restart() -> None: state.server["key"] = os.getenv("API_SERVER_KEY") launch_server.stop_server(state.server["pid"]) - _, state.server["pid"] = launch_server.start_server(state.server["port"]) + state.server["pid"] = launch_server.start_server(state.server["port"]) time.sleep(10) state.pop("server_client", None) @@ -69,16 +69,16 @@ async def main() -> None: key="user_server_port", min_value=1, max_value=65535, - disabled=state.server["remote"], + disabled=not state.server["control"], ) right.text_input( "API Server Key:", value=state.server["key"], key="user_server_key", type="password", - disabled=state.server["remote"], + disabled=not state.server["control"], ) - if not state.server["remote"]: + if state.server["control"]: st.button("Restart Server", type="primary", on_click=server_restart) st.header("Server Settings", divider="red") diff --git a/src/launch_client.py b/src/launch_client.py index f50406dc..d4051512 100644 --- a/src/launch_client.py +++ b/src/launch_client.py @@ -30,6 +30,7 @@ logger.debug("Imported API Server.") except ImportError as ex: logger.debug("API Server not present: %s", ex) + os.environ.pop("API_SERVER_CONTROL", None) LAUNCH_SERVER_EXISTS = False @@ -40,10 +41,11 @@ def init_server_state() -> None: """initialize Streamlit State server""" if "server" not in state: logger.info("Initializing state.server") + api_server_control: bool = os.getenv("API_SERVER_CONTROL") is not None state.server = {"url": os.getenv("API_SERVER_URL", "http://localhost")} state.server["port"] = int(os.getenv("API_SERVER_PORT", "8000")) state.server["key"] = os.getenv("API_SERVER_KEY") - state.server["remote"] = True + state.server["control"] = api_server_control logger.debug("Server State: %s", state.server) @@ -158,7 +160,5 @@ def main() -> None: try: logger.debug("Server PID: %i", state.server["pid"]) except KeyError: - server_state, pid = start_server(logfile=True) - state.server["pid"] = pid - state.server["remote"] = server_state != "started" + state.server["pid"] = start_server(logfile=True) main() diff --git a/src/launch_server.py b/src/launch_server.py index 7727f834..fde83bda 100644 --- a/src/launch_server.py +++ b/src/launch_server.py @@ -48,7 +48,7 @@ ########################################## # Process Control ########################################## -def 
start_server(port: int = 8000, logfile: bool = False) -> tuple[str, int]: +def start_server(port: int = 8000, logfile: bool = False) -> int: """Start the uvicorn server for FastAPI""" logger.info("Starting Oracle AI Optimizer and Toolkit") @@ -96,7 +96,7 @@ def start_subprocess(port: int, logfile: bool) -> subprocess.Popen: existing_pid = get_pid_using_port(port) if existing_pid: logger.info("API server already running on port: %i (PID: %i)", port, existing_pid) - return ("existing", existing_pid) + return existing_pid popen_queue = queue.Queue() thread = threading.Thread( @@ -105,7 +105,7 @@ def start_subprocess(port: int, logfile: bool) -> subprocess.Popen: ) thread.start() - return ("started", popen_queue.get().pid) + return popen_queue.get().pid def stop_server(pid: int) -> None: From 488af6b90d25ea41b8b2a4c9942888af1ad4d26e Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Tue, 2 Sep 2025 17:01:54 +0100 Subject: [PATCH 30/31] add hosted_vllm --- src/common/schema.py | 71 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 58 insertions(+), 13 deletions(-) diff --git a/src/common/schema.py b/src/common/schema.py index cc273e89..4480a7b3 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -2,8 +2,7 @@ Copyright (c) 2024, 2025, Oracle and/or its affiliates. Licensed under the Universal Permissive License v1.0 as shown at http://oss.oracle.com/licenses/upl. """ -# spell-checker:ignore ollama hnsw mult ocid testset selectai explainsql showsql vector_search aioptimizer genai -# spell-checker:ignore deepseek groq huggingface mistralai ocigenai vertexai +# spell-checker:ignore hnsw ocid aioptimizer explainsql genai mult ollama selectai showsql import time from typing import Optional, Literal, get_args, Any @@ -20,28 +19,74 @@ IndexTypes = Literal["HNSW", "IVF"] # Model Providers +# spell-checker:disable ModelProviders = Literal[ - "oci", + "ai21", + "aiohttp_openai", "anthropic", + "azure", "azure_ai", - "azure_openai", + "base_llm", "bedrock", - "bedrock_converse", - "cohere", + "baseten", + "bytez", + "cloudflare", + "clarifai", + "codestral", + "databricks", + "datarobot", + "deepgram", + "deepinfra", "deepseek", - "google_anthropic_vertex", - "google_genai", - "google_vertexai", + "empower", + "fireworks_ai", + "featherless_ai", + "galadriel", + "gemini", + "github", + "github_copilot", "groq", + "hosted_vllm", "huggingface", - "meta-llama", - "mistralai", - "ollama", + "hyperbolic", + "infinity", + "jina_ai", + "lambda_ai", + "litellm_proxy", + "lm_studio", + "llamafile", + "maritalk", + "meta_llama", + "moonshot", + "mistral", + "morph", + "nebius", + "nlp_cloud", + "nscale", + "nvidia_nim", + "novita", + "oci", "openai", + "openrouter", + "ollama", "perplexity", + "petals", + "pg_vector", + "predibase", + "recraft", + "sambanova", + "sagemaker", + "snowflake", + "together_ai", + "topaz", + "triton", + "vertex_ai", + "vllm", + "voyage", + "watsonx", "xai", - "hosted_vllm" ] +# spell-checker:enable ##################################################### From 87e76b1d34c0059bce25b52f0bbc79eac23464ec Mon Sep 17 00:00:00 2001 From: gotsysdba Date: Tue, 2 Sep 2025 17:19:15 +0100 Subject: [PATCH 31/31] Fix Tests --- src/client/content/config/tabs/oci.py | 5 +++-- src/common/schema.py | 32 ++++++++++++++++++--------- tests/conftest.py | 1 + 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/client/content/config/tabs/oci.py b/src/client/content/config/tabs/oci.py index e4543da7..20db63c2 100644 --- a/src/client/content/config/tabs/oci.py +++ 
b/src/client/content/config/tabs/oci.py @@ -98,12 +98,13 @@ def display_oci() -> None: disable_config = False oci_lookup = st_common.state_configs_lookup("oci_configs", "auth_profile") # Handle instance_principal and oke_workload_identity - if len(oci_lookup) == 1 and state.oci_configs[0]["authentication"] in ( + oci_auth = state.oci_configs[0].get("authentication") + if len(oci_lookup) == 1 and oci_auth in ( "instance_principal", "oke_workload_identity", ): st.info("Using OCI Authentication Principals", icon="ℹ️") - supplied["authentication"] = state.oci_configs[0]["authentication"] + supplied["authentication"] = oci_auth supplied["tenancy"] = state.oci_configs[0]["tenancy"] disable_config = True if len(oci_lookup) > 0: diff --git a/src/common/schema.py b/src/common/schema.py index 4480a7b3..5cc42fcf 100644 --- a/src/common/schema.py +++ b/src/common/schema.py @@ -22,29 +22,38 @@ # spell-checker:disable ModelProviders = Literal[ "ai21", + "aiml", "aiohttp_openai", "anthropic", "azure", "azure_ai", "base_llm", - "bedrock", + "base.py", "baseten", + "bedrock", "bytez", - "cloudflare", + "cerebras", "clarifai", + "cloudflare", "codestral", + "cohere", + "cometapi", + "dashscope", "databricks", "datarobot", "deepgram", "deepinfra", "deepseek", + "elevenlabs", "empower", - "fireworks_ai", "featherless_ai", + "fireworks_ai", + "friendliai", "galadriel", "gemini", "github", "github_copilot", + "gradient_ai", "groq", "hosted_vllm", "huggingface", @@ -53,38 +62,42 @@ "jina_ai", "lambda_ai", "litellm_proxy", - "lm_studio", "llamafile", - "maritalk", + "lm_studio", "meta_llama", - "moonshot", "mistral", + "moonshot", "morph", "nebius", "nlp_cloud", + "novita", "nscale", "nvidia_nim", - "novita", "oci", + "ollama", + "oobabooga", "openai", "openrouter", - "ollama", "perplexity", "petals", "pg_vector", "predibase", "recraft", - "sambanova", + "replicate", "sagemaker", + "sambanova", "snowflake", "together_ai", "topaz", "triton", - "vertex_ai", + "v0", + "vercel_ai_gateway", + "vertex_ai", "vllm", "voyage", "watsonx", "xai", + "xinference", ] # spell-checker:enable diff --git a/tests/conftest.py b/tests/conftest.py index 3897c4a7..3c8b94f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -137,6 +137,7 @@ def _app_test(page): "key": os.environ.get("API_SERVER_KEY"), "url": os.environ.get("API_SERVER_URL"), "port": int(os.environ.get("API_SERVER_PORT")), + "control": True } response = requests.get( url=f"{at.session_state.server['url']}:{at.session_state.server['port']}/v1/settings",
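
Two techniques that recur in the series above are worth a short closing illustration; neither snippet below is part of the patches, and the helper names are illustrative rather than the project's API.

PATCH 21 dispatches on python-oracledb error codes: oracledb.DatabaseError carries an error object whose full_code attribute holds the "ORA-"/"DPY-" prefixed code, which the patch maps onto built-in exceptions. A minimal sketch of that pattern, reusing the same mapping:

    import oracledb

    CODE_TO_EXCEPTION = {
        "ORA-01017": PermissionError,  # invalid credential or not authorized
        "DPY-6005": ConnectionError,   # cannot connect to database
        "DPY-4000": LookupError,       # DSN/TNS alias not resolvable
        "DPY-4026": LookupError,       # DSN/TNS alias not resolvable
    }

    def connect_or_raise(user: str, password: str, dsn: str) -> oracledb.Connection:
        """Connect, translating known driver codes into built-in exceptions."""
        try:
            return oracledb.connect(user=user, password=password, dsn=dsn)
        except oracledb.DatabaseError as ex:
            error = ex.args[0] if ex.args else None
            code = getattr(error, "full_code", None)
            if code in CODE_TO_EXCEPTION:
                raise CODE_TO_EXCEPTION[code](error.message) from ex
            raise  # unrecognized codes propagate untouched

PATCH 28 strips a litellm-style provider prefix (for example "hosted_vllm/meta-llama/Llama-3.2-1B-Instruct") from the configured model id before looking it up. str.partition splits on the first "/" only and returns an empty tail for bare ids such as "gpt-4o", so the fallback matters; a hedged sketch of that helper under the same assumption:

    def strip_provider_prefix(model_full: str) -> str:
        # partition("/") -> (head, sep, tail); tail is "" when no "/" is present
        _, _, tail = model_full.partition("/")
        return tail or model_full

    assert strip_provider_prefix("hosted_vllm/meta-llama/Llama-3.2-1B-Instruct") == "meta-llama/Llama-3.2-1B-Instruct"
    assert strip_provider_prefix("gpt-4o") == "gpt-4o"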