-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into PSL-US-7770-UnitTest
- Loading branch information
Showing
14 changed files
with
1,873 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# db.py | ||
import os | ||
|
||
import pymssql | ||
from dotenv import load_dotenv | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import base64 | ||
import json | ||
from unittest.mock import patch | ||
|
||
from backend.auth.auth_utils import (get_authenticated_user_details, | ||
get_tenantid) | ||
|
||
|
||
def test_get_authenticated_user_details_no_principal_id(): | ||
request_headers = {} | ||
sample_user_data = { | ||
"X-Ms-Client-Principal-Id": "default-id", | ||
"X-Ms-Client-Principal-Name": "default-name", | ||
"X-Ms-Client-Principal-Idp": "default-idp", | ||
"X-Ms-Token-Aad-Id-Token": "default-token", | ||
"X-Ms-Client-Principal": "default-b64", | ||
} | ||
with patch("backend.auth.sample_user.sample_user", sample_user_data): | ||
user_details = get_authenticated_user_details(request_headers) | ||
assert user_details["user_principal_id"] == "default-id" | ||
assert user_details["user_name"] == "default-name" | ||
assert user_details["auth_provider"] == "default-idp" | ||
assert user_details["auth_token"] == "default-token" | ||
assert user_details["client_principal_b64"] == "default-b64" | ||
|
||
|
||
def test_get_authenticated_user_details_with_principal_id(): | ||
request_headers = { | ||
"X-Ms-Client-Principal-Id": "test-id", | ||
"X-Ms-Client-Principal-Name": "test-name", | ||
"X-Ms-Client-Principal-Idp": "test-idp", | ||
"X-Ms-Token-Aad-Id-Token": "test-token", | ||
"X-Ms-Client-Principal": "test-b64", | ||
} | ||
user_details = get_authenticated_user_details(request_headers) | ||
assert user_details["user_principal_id"] == "test-id" | ||
assert user_details["user_name"] == "test-name" | ||
assert user_details["auth_provider"] == "test-idp" | ||
assert user_details["auth_token"] == "test-token" | ||
assert user_details["client_principal_b64"] == "test-b64" | ||
|
||
|
||
def test_get_tenantid_valid_b64(): | ||
user_info = {"tid": "test-tenant-id"} | ||
client_principal_b64 = base64.b64encode( | ||
json.dumps(user_info).encode("utf-8") | ||
).decode("utf-8") | ||
tenant_id = get_tenantid(client_principal_b64) | ||
assert tenant_id == "test-tenant-id" | ||
|
||
|
||
def test_get_tenantid_invalid_b64(): | ||
client_principal_b64 = "invalid-b64" | ||
with patch("backend.auth.auth_utils.logging") as mock_logging: | ||
tenant_id = get_tenantid(client_principal_b64) | ||
assert tenant_id == "" | ||
mock_logging.exception.assert_called_once() | ||
|
||
|
||
def test_get_tenantid_no_tid(): | ||
user_info = {"some_other_key": "value"} | ||
client_principal_b64 = base64.b64encode( | ||
json.dumps(user_info).encode("utf-8") | ||
).decode("utf-8") | ||
tenant_id = get_tenantid(client_principal_b64) | ||
assert tenant_id is None |
184 changes: 184 additions & 0 deletions
184
ClientAdvisor/App/tests/backend/history/test_cosmosdb_service.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
from unittest.mock import AsyncMock, MagicMock, patch | ||
|
||
import pytest | ||
from azure.cosmos import exceptions | ||
|
||
from backend.history.cosmosdbservice import CosmosConversationClient | ||
|
||
|
||
# Helper function to create an async iterable | ||
class AsyncIterator: | ||
def __init__(self, items): | ||
self.items = items | ||
self.index = 0 | ||
|
||
def __aiter__(self): | ||
return self | ||
|
||
async def __anext__(self): | ||
if self.index < len(self.items): | ||
item = self.items[self.index] | ||
self.index += 1 | ||
return item | ||
else: | ||
raise StopAsyncIteration | ||
|
||
|
||
@pytest.fixture | ||
def cosmos_client(): | ||
return CosmosConversationClient( | ||
cosmosdb_endpoint="https://fake.endpoint", | ||
credential="fake_credential", | ||
database_name="test_db", | ||
container_name="test_container", | ||
) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_init_invalid_credentials(): | ||
with patch( | ||
"azure.cosmos.aio.CosmosClient.__init__", | ||
side_effect=exceptions.CosmosHttpResponseError( | ||
status_code=401, message="Unauthorized" | ||
), | ||
): | ||
with pytest.raises(ValueError, match="Invalid credentials"): | ||
CosmosConversationClient( | ||
cosmosdb_endpoint="https://fake.endpoint", | ||
credential="fake_credential", | ||
database_name="test_db", | ||
container_name="test_container", | ||
) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_init_invalid_endpoint(): | ||
with patch( | ||
"azure.cosmos.aio.CosmosClient.__init__", | ||
side_effect=exceptions.CosmosHttpResponseError( | ||
status_code=404, message="Not Found" | ||
), | ||
): | ||
with pytest.raises(ValueError, match="Invalid CosmosDB endpoint"): | ||
CosmosConversationClient( | ||
cosmosdb_endpoint="https://fake.endpoint", | ||
credential="fake_credential", | ||
database_name="test_db", | ||
container_name="test_container", | ||
) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_ensure_success(cosmos_client): | ||
cosmos_client.database_client.read = AsyncMock() | ||
cosmos_client.container_client.read = AsyncMock() | ||
success, message = await cosmos_client.ensure() | ||
assert success | ||
assert message == "CosmosDB client initialized successfully" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_ensure_failure(cosmos_client): | ||
cosmos_client.database_client.read = AsyncMock(side_effect=Exception) | ||
success, message = await cosmos_client.ensure() | ||
assert not success | ||
assert "CosmosDB database" in message | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_create_conversation(cosmos_client): | ||
cosmos_client.container_client.upsert_item = AsyncMock(return_value={"id": "123"}) | ||
response = await cosmos_client.create_conversation("user_1", "Test Conversation") | ||
assert response["id"] == "123" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_create_conversation_failure(cosmos_client): | ||
cosmos_client.container_client.upsert_item = AsyncMock(return_value=None) | ||
response = await cosmos_client.create_conversation("user_1", "Test Conversation") | ||
assert not response | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_upsert_conversation(cosmos_client): | ||
cosmos_client.container_client.upsert_item = AsyncMock(return_value={"id": "123"}) | ||
response = await cosmos_client.upsert_conversation({"id": "123"}) | ||
assert response["id"] == "123" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_delete_conversation(cosmos_client): | ||
cosmos_client.container_client.read_item = AsyncMock(return_value={"id": "123"}) | ||
cosmos_client.container_client.delete_item = AsyncMock(return_value=True) | ||
response = await cosmos_client.delete_conversation("user_1", "123") | ||
assert response | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_delete_conversation_not_found(cosmos_client): | ||
cosmos_client.container_client.read_item = AsyncMock(return_value=None) | ||
response = await cosmos_client.delete_conversation("user_1", "123") | ||
assert response | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_delete_messages(cosmos_client): | ||
cosmos_client.get_messages = AsyncMock( | ||
return_value=[{"id": "msg_1"}, {"id": "msg_2"}] | ||
) | ||
cosmos_client.container_client.delete_item = AsyncMock(return_value=True) | ||
response = await cosmos_client.delete_messages("conv_1", "user_1") | ||
assert len(response) == 2 | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_get_conversations(cosmos_client): | ||
items = [{"id": "conv_1"}, {"id": "conv_2"}] | ||
cosmos_client.container_client.query_items = MagicMock( | ||
return_value=AsyncIterator(items) | ||
) | ||
response = await cosmos_client.get_conversations("user_1", 10) | ||
assert len(response) == 2 | ||
assert response[0]["id"] == "conv_1" | ||
assert response[1]["id"] == "conv_2" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_get_conversation(cosmos_client): | ||
items = [{"id": "conv_1"}] | ||
cosmos_client.container_client.query_items = MagicMock( | ||
return_value=AsyncIterator(items) | ||
) | ||
response = await cosmos_client.get_conversation("user_1", "conv_1") | ||
assert response["id"] == "conv_1" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_create_message(cosmos_client): | ||
cosmos_client.container_client.upsert_item = AsyncMock(return_value={"id": "msg_1"}) | ||
cosmos_client.get_conversation = AsyncMock(return_value={"id": "conv_1"}) | ||
cosmos_client.upsert_conversation = AsyncMock() | ||
response = await cosmos_client.create_message( | ||
"msg_1", "conv_1", "user_1", {"role": "user", "content": "Hello"} | ||
) | ||
assert response["id"] == "msg_1" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_update_message_feedback(cosmos_client): | ||
cosmos_client.container_client.read_item = AsyncMock(return_value={"id": "msg_1"}) | ||
cosmos_client.container_client.upsert_item = AsyncMock(return_value={"id": "msg_1"}) | ||
response = await cosmos_client.update_message_feedback( | ||
"user_1", "msg_1", "positive" | ||
) | ||
assert response["id"] == "msg_1" | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_get_messages(cosmos_client): | ||
items = [{"id": "msg_1"}, {"id": "msg_2"}] | ||
cosmos_client.container_client.query_items = MagicMock( | ||
return_value=AsyncIterator(items) | ||
) | ||
response = await cosmos_client.get_messages("user_1", "conv_1") | ||
assert len(response) == 2 |
Oops, something went wrong.