From 19dae1d870e04cc0dfb157e311163a998176c7b1 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Fri, 20 Sep 2024 12:00:03 -0700 Subject: [PATCH] Wrote tests for the chat apis (#2525) * Wrote tests for the chat apis * slight changes to the case --- .../common_utils/managers/document.py | 42 +++- .../connector/test_connector_deletion.py | 8 +- .../tests/dev_apis/test_knowledge_chat.py | 188 ++++++++++++++++++ .../tests/dev_apis/test_simple_chat_api.py | 87 +++++++- .../tests/document_set/test_syncing.py | 6 +- .../tests/usergroup/test_usergroup_syncing.py | 4 +- 6 files changed, 319 insertions(+), 16 deletions(-) create mode 100644 backend/tests/integration/tests/dev_apis/test_knowledge_chat.py diff --git a/backend/tests/integration/common_utils/managers/document.py b/backend/tests/integration/common_utils/managers/document.py index dcd8def5c..28234a9be 100644 --- a/backend/tests/integration/common_utils/managers/document.py +++ b/backend/tests/integration/common_utils/managers/document.py @@ -55,13 +55,18 @@ def _verify_document_permissions( ) -def _generate_dummy_document(document_id: str, cc_pair_id: int) -> dict: +def _generate_dummy_document( + document_id: str, + cc_pair_id: int, + content: str | None = None, +) -> dict: + text = content if content else f"This is test document {document_id}" return { "document": { "id": document_id, "sections": [ { - "text": f"This is test document {document_id}", + "text": text, "link": f"{document_id}", } ], @@ -77,12 +82,12 @@ def _generate_dummy_document(document_id: str, cc_pair_id: int) -> dict: class DocumentManager: @staticmethod - def seed_and_attach_docs( + def seed_dummy_docs( cc_pair: TestCCPair, num_docs: int = NUM_DOCS, document_ids: list[str] | None = None, api_key: TestAPIKey | None = None, - ) -> TestCCPair: + ) -> list[SimpleTestDocument]: # Use provided document_ids if available, otherwise generate random UUIDs if document_ids is None: document_ids = [f"test-doc-{uuid4()}" for _ in range(num_docs)] @@ -101,14 +106,39 @@ class DocumentManager: response.raise_for_status() print("Seeding completed successfully.") - cc_pair.documents = [ + return [ SimpleTestDocument( id=document["document"]["id"], content=document["document"]["sections"][0]["text"], ) for document in documents ] - return cc_pair + + @staticmethod + def seed_doc_with_content( + cc_pair: TestCCPair, + content: str, + document_id: str | None = None, + api_key: TestAPIKey | None = None, + ) -> SimpleTestDocument: + # Use provided document_ids if available, otherwise generate random UUIDs + if document_id is None: + document_id = f"test-doc-{uuid4()}" + # Create and ingest some documents + document: dict = _generate_dummy_document(document_id, cc_pair.id, content) + response = requests.post( + f"{API_SERVER_URL}/danswer-api/ingestion", + json=document, + headers=api_key.headers if api_key else GENERAL_HEADERS, + ) + response.raise_for_status() + + print("Seeding completed successfully.") + + return SimpleTestDocument( + id=document["document"]["id"], + content=document["document"]["sections"][0]["text"], + ) @staticmethod def verify( diff --git a/backend/tests/integration/tests/connector/test_connector_deletion.py b/backend/tests/integration/tests/connector/test_connector_deletion.py index f0a83034b..f2d1e6910 100644 --- a/backend/tests/integration/tests/connector/test_connector_deletion.py +++ b/backend/tests/integration/tests/connector/test_connector_deletion.py @@ -47,12 +47,12 @@ def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None: ) # seed documents - cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair_1.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_1, num_docs=NUM_DOCS, api_key=api_key, ) - cc_pair_2 = DocumentManager.seed_and_attach_docs( + cc_pair_2.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_2, num_docs=NUM_DOCS, api_key=api_key, @@ -197,12 +197,12 @@ def test_connector_deletion_for_overlapping_connectors( ) doc_ids = [str(uuid4())] - cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair_1.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_1, document_ids=doc_ids, api_key=api_key, ) - cc_pair_2 = DocumentManager.seed_and_attach_docs( + cc_pair_2.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_2, document_ids=doc_ids, api_key=api_key, diff --git a/backend/tests/integration/tests/dev_apis/test_knowledge_chat.py b/backend/tests/integration/tests/dev_apis/test_knowledge_chat.py new file mode 100644 index 000000000..7cb7a3199 --- /dev/null +++ b/backend/tests/integration/tests/dev_apis/test_knowledge_chat.py @@ -0,0 +1,188 @@ +import requests + +from danswer.configs.constants import MessageType +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.llm import LLMProviderManager +from tests.integration.common_utils.managers.api_key import APIKeyManager +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestCCPair +from tests.integration.common_utils.test_models import TestUser + + +def test_all_stream_chat_message_objects_outputs(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # create connector + cc_pair_1: TestCCPair = CCPairManager.create_from_scratch( + user_performing_action=admin_user, + ) + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, + ) + LLMProviderManager.create(user_performing_action=admin_user) + + # SEEDING DOCUMENTS + cc_pair_1.documents = [] + cc_pair_1.documents.append( + DocumentManager.seed_doc_with_content( + cc_pair=cc_pair_1, + content="Pablo's favorite color is blue", + api_key=api_key, + ) + ) + cc_pair_1.documents.append( + DocumentManager.seed_doc_with_content( + cc_pair=cc_pair_1, + content="Chris's favorite color is red", + api_key=api_key, + ) + ) + cc_pair_1.documents.append( + DocumentManager.seed_doc_with_content( + cc_pair=cc_pair_1, + content="Pika's favorite color is green", + api_key=api_key, + ) + ) + + # TESTING RESPONSE FOR QUESTION 1 + response = requests.post( + f"{API_SERVER_URL}/chat/send-message-simple-with-history", + json={ + "messages": [ + { + "message": "What is Pablo's favorite color?", + "role": MessageType.USER.value, + } + ], + "persona_id": 0, + "prompt_id": 0, + }, + headers=admin_user.headers, + ) + assert response.status_code == 200 + response_json = response.json() + + # check that the answer is correct + answer_1 = response_json["answer"] + assert "blue" in answer_1.lower() + + # check that the llm selected a document + assert 0 in response_json["llm_selected_doc_indices"] + + # check that the final context documents are correct + # (it should contain all documents because there arent enough to exclude any) + assert 0 in response_json["final_context_doc_indices"] + assert 1 in response_json["final_context_doc_indices"] + assert 2 in response_json["final_context_doc_indices"] + + # check that the cited documents are correct + assert cc_pair_1.documents[0].id in response_json["cited_documents"].values() + + # check that the top documents are correct + assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id + print("response 1/3 passed") + + # TESTING RESPONSE FOR QUESTION 2 + response = requests.post( + f"{API_SERVER_URL}/chat/send-message-simple-with-history", + json={ + "messages": [ + { + "message": "What is Pablo's favorite color?", + "role": MessageType.USER.value, + }, + { + "message": answer_1, + "role": MessageType.ASSISTANT.value, + }, + { + "message": "What is Chris's favorite color?", + "role": MessageType.USER.value, + }, + ], + "persona_id": 0, + "prompt_id": 0, + }, + headers=admin_user.headers, + ) + assert response.status_code == 200 + response_json = response.json() + + # check that the answer is correct + answer_2 = response_json["answer"] + assert "red" in answer_2.lower() + + # check that the llm selected a document + assert 0 in response_json["llm_selected_doc_indices"] + + # check that the final context documents are correct + # (it should contain all documents because there arent enough to exclude any) + assert 0 in response_json["final_context_doc_indices"] + assert 1 in response_json["final_context_doc_indices"] + assert 2 in response_json["final_context_doc_indices"] + + # check that the cited documents are correct + assert cc_pair_1.documents[1].id in response_json["cited_documents"].values() + + # check that the top documents are correct + assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[1].id + print("response 2/3 passed") + + # TESTING RESPONSE FOR QUESTION 3 + response = requests.post( + f"{API_SERVER_URL}/chat/send-message-simple-with-history", + json={ + "messages": [ + { + "message": "What is Pablo's favorite color?", + "role": MessageType.USER.value, + }, + { + "message": answer_1, + "role": MessageType.ASSISTANT.value, + }, + { + "message": "What is Chris's favorite color?", + "role": MessageType.USER.value, + }, + { + "message": answer_2, + "role": MessageType.ASSISTANT.value, + }, + { + "message": "What is Pika's favorite color?", + "role": MessageType.USER.value, + }, + ], + "persona_id": 0, + "prompt_id": 0, + }, + headers=admin_user.headers, + ) + assert response.status_code == 200 + response_json = response.json() + + # check that the answer is correct + answer_3 = response_json["answer"] + assert "green" in answer_3.lower() + + # check that the llm selected a document + assert 0 in response_json["llm_selected_doc_indices"] + + # check that the final context documents are correct + # (it should contain all documents because there arent enough to exclude any) + assert 0 in response_json["final_context_doc_indices"] + assert 1 in response_json["final_context_doc_indices"] + assert 2 in response_json["final_context_doc_indices"] + + # check that the cited documents are correct + assert cc_pair_1.documents[2].id in response_json["cited_documents"].values() + + # check that the top documents are correct + assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id + print("response 3/3 passed") diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py index d4edcc583..b712bf528 100644 --- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py +++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py @@ -25,7 +25,7 @@ def test_send_message_simple_with_history(reset: None) -> None: user_performing_action=admin_user, ) LLMProviderManager.create(user_performing_action=admin_user) - cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair_1.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_1, num_docs=NUM_DOCS, api_key=api_key, @@ -60,3 +60,88 @@ def test_send_message_simple_with_history(reset: None) -> None: ) assert found_doc assert found_doc["metadata"]["document_id"] == doc.id + + +def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # create connector + cc_pair_1: TestCCPair = CCPairManager.create_from_scratch( + user_performing_action=admin_user, + ) + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, + ) + LLMProviderManager.create(user_performing_action=admin_user) + + # SEEDING DOCUMENTS + cc_pair_1.documents = [] + cc_pair_1.documents.append( + DocumentManager.seed_doc_with_content( + cc_pair=cc_pair_1, + content="Chris's favorite color is blue", + api_key=api_key, + ) + ) + cc_pair_1.documents.append( + DocumentManager.seed_doc_with_content( + cc_pair=cc_pair_1, + content="Hagen's favorite color is red", + api_key=api_key, + ) + ) + cc_pair_1.documents.append( + DocumentManager.seed_doc_with_content( + cc_pair=cc_pair_1, + content="Pablo's favorite color is green", + api_key=api_key, + ) + ) + + # SEINDING MESSAGE 1 + response = requests.post( + f"{API_SERVER_URL}/chat/send-message-simple-with-history", + json={ + "messages": [ + { + "message": "What is Pablo's favorite color?", + "role": MessageType.USER.value, + } + ], + "persona_id": 0, + "prompt_id": 0, + }, + headers=admin_user.headers, + ) + assert response.status_code == 200 + response_json = response.json() + # get the db_doc_id of the top document to use as a search doc id for second message + first_db_doc_id = response_json["top_documents"][0]["db_doc_id"] + + # SEINDING MESSAGE 2 + response = requests.post( + f"{API_SERVER_URL}/chat/send-message-simple-with-history", + json={ + "messages": [ + { + "message": "What is Pablo's favorite color?", + "role": MessageType.USER.value, + } + ], + "persona_id": 0, + "prompt_id": 0, + "search_doc_ids": [first_db_doc_id], + }, + headers=admin_user.headers, + ) + assert response.status_code == 200 + response_json = response.json() + + # since we only gave it one search doc, all responses should only contain that doc + assert response_json["final_context_doc_indices"] == [0] + assert response_json["llm_selected_doc_indices"] == [0] + assert cc_pair_1.documents[2].id in response_json["cited_documents"].values() + # This ensures the the document we think we are referencing when we send the search_doc_ids in the second + # message is the document that we expect it to be + assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id diff --git a/backend/tests/integration/tests/document_set/test_syncing.py b/backend/tests/integration/tests/document_set/test_syncing.py index 217d106af..95425b862 100644 --- a/backend/tests/integration/tests/document_set/test_syncing.py +++ b/backend/tests/integration/tests/document_set/test_syncing.py @@ -28,7 +28,7 @@ def test_multiple_document_sets_syncing_same_connnector( ) # seed documents - cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair_1.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_1, num_docs=NUM_DOCS, api_key=api_key, @@ -86,13 +86,13 @@ def test_removing_connector(reset: None, vespa_client: TestVespaClient) -> None: ) # seed documents - cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair_1.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_1, num_docs=NUM_DOCS, api_key=api_key, ) - cc_pair_2 = DocumentManager.seed_and_attach_docs( + cc_pair_2.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_2, num_docs=NUM_DOCS, api_key=api_key, diff --git a/backend/tests/integration/tests/usergroup/test_usergroup_syncing.py b/backend/tests/integration/tests/usergroup/test_usergroup_syncing.py index fbb976f9f..a56975950 100644 --- a/backend/tests/integration/tests/usergroup/test_usergroup_syncing.py +++ b/backend/tests/integration/tests/usergroup/test_usergroup_syncing.py @@ -31,13 +31,13 @@ def test_removing_connector(reset: None, vespa_client: TestVespaClient) -> None: ) # seed documents - cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair_1.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_1, num_docs=NUM_DOCS, api_key=api_key, ) - cc_pair_2 = DocumentManager.seed_and_attach_docs( + cc_pair_2.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_2, num_docs=NUM_DOCS, api_key=api_key,