From e2c37d6847fbe2f67b609353b32800980dc06481 Mon Sep 17 00:00:00 2001
From: pablodanswer
Date: Sun, 15 Sep 2024 12:40:48 -0700
Subject: [PATCH] Test stream + Update Copy (#2317)

* update copy + conditional ordering

* answer stream checks

* update

* add basic tests for chat streams

* slightly simplify

* fix typing

* quick typing updates + nits
---
 .../integration/common_utils/managers/chat.py | 160 ++++++++++++++++++
 .../integration/common_utils/test_models.py   |  25 +++
 .../streaming_endpoints/test_answer_stream.py |  25 +++
 .../streaming_endpoints/test_chat_stream.py   |  19 +++
 .../app/admin/configuration/search/page.tsx   |  23 +--
 web/src/app/admin/embeddings/interfaces.ts    |  29 +++-
 .../components/embedding/ModelSelector.tsx    |  24 +--
 7 files changed, 277 insertions(+), 28 deletions(-)
 create mode 100644 backend/tests/integration/common_utils/managers/chat.py
 create mode 100644 backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py
 create mode 100644 backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py

diff --git a/backend/tests/integration/common_utils/managers/chat.py b/backend/tests/integration/common_utils/managers/chat.py
new file mode 100644
index 000000000..3d6281764
--- /dev/null
+++ b/backend/tests/integration/common_utils/managers/chat.py
@@ -0,0 +1,160 @@
+import json
+
+import requests
+from requests.models import Response
+
+from danswer.file_store.models import FileDescriptor
+from danswer.llm.override_models import LLMOverride
+from danswer.llm.override_models import PromptOverride
+from danswer.one_shot_answer.models import DirectQARequest
+from danswer.one_shot_answer.models import ThreadMessage
+from danswer.search.models import RetrievalDetails
+from danswer.server.query_and_chat.models import ChatSessionCreationRequest
+from danswer.server.query_and_chat.models import CreateChatMessageRequest
+from tests.integration.common_utils.constants import API_SERVER_URL
+from tests.integration.common_utils.constants import GENERAL_HEADERS
+from tests.integration.common_utils.test_models import StreamedResponse
+from tests.integration.common_utils.test_models import TestChatMessage
+from tests.integration.common_utils.test_models import TestChatSession
+from tests.integration.common_utils.test_models import TestUser
+
+
+class ChatSessionManager:
+    @staticmethod
+    def create(
+        persona_id: int = -1,
+        description: str = "Test chat session",
+        user_performing_action: TestUser | None = None,
+    ) -> TestChatSession:
+        chat_session_creation_req = ChatSessionCreationRequest(
+            persona_id=persona_id, description=description
+        )
+        response = requests.post(
+            f"{API_SERVER_URL}/chat/create-chat-session",
+            json=chat_session_creation_req.model_dump(),
+            headers=user_performing_action.headers
+            if user_performing_action
+            else GENERAL_HEADERS,
+        )
+        response.raise_for_status()
+        chat_session_id = response.json()["chat_session_id"]
+        return TestChatSession(
+            id=chat_session_id, persona_id=persona_id, description=description
+        )
+
+    @staticmethod
+    def send_message(
+        chat_session_id: int,
+        message: str,
+        parent_message_id: int | None = None,
+        user_performing_action: TestUser | None = None,
+        file_descriptors: list[FileDescriptor] = [],
+        prompt_id: int | None = None,
+        search_doc_ids: list[int] | None = None,
+        retrieval_options: RetrievalDetails | None = None,
+        query_override: str | None = None,
+        regenerate: bool | None = None,
+        llm_override: LLMOverride | None = None,
+        prompt_override: PromptOverride | None = None,
+        alternate_assistant_id: int | None = None,
+        use_existing_user_message: bool = False,
+    ) -> StreamedResponse:
+        chat_message_req = CreateChatMessageRequest(
+            chat_session_id=chat_session_id,
+            parent_message_id=parent_message_id,
+            message=message,
+            file_descriptors=file_descriptors or [],
+            prompt_id=prompt_id,
+            search_doc_ids=search_doc_ids or [],
+            retrieval_options=retrieval_options,
+            query_override=query_override,
+            regenerate=regenerate,
+            llm_override=llm_override,
+            prompt_override=prompt_override,
+            alternate_assistant_id=alternate_assistant_id,
+            use_existing_user_message=use_existing_user_message,
+        )
+
+        response = requests.post(
+            f"{API_SERVER_URL}/chat/send-message",
+            json=chat_message_req.model_dump(),
+            headers=user_performing_action.headers
+            if user_performing_action
+            else GENERAL_HEADERS,
+            stream=True,
+        )
+
+        return ChatSessionManager.analyze_response(response)
+
+    @staticmethod
+    def get_answer_with_quote(
+        persona_id: int,
+        message: str,
+        user_performing_action: TestUser | None = None,
+    ) -> StreamedResponse:
+        direct_qa_request = DirectQARequest(
+            messages=[ThreadMessage(message=message)],
+            prompt_id=None,
+            persona_id=persona_id,
+        )
+
+        response = requests.post(
+            f"{API_SERVER_URL}/query/stream-answer-with-quote",
+            json=direct_qa_request.model_dump(),
+            headers=user_performing_action.headers
+            if user_performing_action
+            else GENERAL_HEADERS,
+            stream=True,
+        )
+        response.raise_for_status()
+
+        return ChatSessionManager.analyze_response(response)
+
+    @staticmethod
+    def analyze_response(response: Response) -> StreamedResponse:
+        response_data = [
+            json.loads(line.decode("utf-8")) for line in response.iter_lines() if line
+        ]
+
+        analyzed = StreamedResponse()
+
+        for data in response_data:
+            if "rephrased_query" in data:
+                analyzed.rephrased_query = data["rephrased_query"]
+            elif "tool_name" in data:
+                analyzed.tool_name = data["tool_name"]
+                analyzed.tool_result = (
+                    data.get("tool_result")
+                    if analyzed.tool_name == "run_search"
+                    else None
+                )
+            elif "relevance_summaries" in data:
+                analyzed.relevance_summaries = data["relevance_summaries"]
+            elif "answer_piece" in data and data["answer_piece"]:
+                analyzed.full_message += data["answer_piece"]
+
+        return analyzed
+
+    @staticmethod
+    def get_chat_history(
+        chat_session: TestChatSession,
+        user_performing_action: TestUser | None = None,
+    ) -> list[TestChatMessage]:
+        response = requests.get(
+            f"{API_SERVER_URL}/chat/history/{chat_session.id}",
+            headers=user_performing_action.headers
+            if user_performing_action
+            else GENERAL_HEADERS,
+        )
+        response.raise_for_status()
+
+        return [
+            TestChatMessage(
+                id=msg["id"],
+                chat_session_id=chat_session.id,
+                parent_message_id=msg.get("parent_message_id"),
+                message=msg["message"],
+                response=msg.get("response", ""),
+            )
+            for msg in response.json()
+        ]
diff --git a/backend/tests/integration/common_utils/test_models.py b/backend/tests/integration/common_utils/test_models.py
index 04db0851e..2d8744327 100644
--- a/backend/tests/integration/common_utils/test_models.py
+++ b/backend/tests/integration/common_utils/test_models.py
@@ -118,3 +118,28 @@ class TestPersona(BaseModel):
     llm_model_version_override: str | None
     users: list[str]
     groups: list[int]
+
+
+#
+class TestChatSession(BaseModel):
+    id: int
+    persona_id: int
+    description: str
+
+
+class TestChatMessage(BaseModel):
+    id: str | None = None
+    chat_session_id: int
+    parent_message_id: str | None
+    message: str
+    response: str
+
+
+class StreamedResponse(BaseModel):
+    full_message: str = ""
+    rephrased_query: str | None = None
+    tool_name: str | None = None
+    top_documents: list[dict[str, Any]] | None = None
+    relevance_summaries: list[dict[str, Any]] | None = None
+    tool_result: Any | None = None
+    user: str | None = None
diff --git a/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py b/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py
new file mode 100644
index 000000000..1b8a4c790
--- /dev/null
+++ b/backend/tests/integration/tests/streaming_endpoints/test_answer_stream.py
@@ -0,0 +1,25 @@
+from tests.integration.common_utils.llm import LLMProviderManager
+from tests.integration.common_utils.managers.chat import ChatSessionManager
+from tests.integration.common_utils.managers.user import UserManager
+from tests.integration.common_utils.test_models import TestUser
+
+
+def test_send_message_simple_with_history(reset: None) -> None:
+    admin_user: TestUser = UserManager.create(name="admin_user")
+    LLMProviderManager.create(user_performing_action=admin_user)
+
+    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
+
+    response = ChatSessionManager.get_answer_with_quote(
+        persona_id=test_chat_session.persona_id,
+        message="Hello, this is a test.",
+        user_performing_action=admin_user,
+    )
+
+    assert (
+        response.tool_name is not None
+    ), "Tool name should be specified (always search)"
+    assert (
+        response.relevance_summaries is not None
+    ), "Relevance summaries should be present for all search streams"
+    assert len(response.full_message) > 0, "Response message should not be empty"
diff --git a/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py b/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py
new file mode 100644
index 000000000..4346e1848
--- /dev/null
+++ b/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py
@@ -0,0 +1,19 @@
+from tests.integration.common_utils.llm import LLMProviderManager
+from tests.integration.common_utils.managers.chat import ChatSessionManager
+from tests.integration.common_utils.managers.user import UserManager
+from tests.integration.common_utils.test_models import TestUser
+
+
+def test_send_message_simple_with_history(reset: None) -> None:
+    admin_user: TestUser = UserManager.create(name="admin_user")
+    LLMProviderManager.create(user_performing_action=admin_user)
+
+    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
+
+    response = ChatSessionManager.send_message(
+        chat_session_id=test_chat_session.id,
+        message="this is a test message",
+        user_performing_action=admin_user,
+    )
+
+    assert len(response.full_message) > 0
diff --git a/web/src/app/admin/configuration/search/page.tsx b/web/src/app/admin/configuration/search/page.tsx
index b2abebae7..6f76eeb7e 100644
--- a/web/src/app/admin/configuration/search/page.tsx
+++ b/web/src/app/admin/configuration/search/page.tsx
@@ -24,7 +24,10 @@ export interface EmbeddingDetails {
 import { EmbeddingIcon } from "@/components/icons/icons";
 import Link from "next/link";
 
-import { SavedSearchSettings } from "../../embeddings/interfaces";
+import {
+  getCurrentModelCopy,
+  SavedSearchSettings,
+} from "../../embeddings/interfaces";
 import UpgradingPage from "./UpgradingPage";
 import { useContext } from "react";
 import { SettingsContext } from "@/components/settings/SettingsProvider";
@@ -75,20 +78,6 @@ function Main() {
   }
 
   const currentModelName = currentEmeddingModel?.model_name;
-  const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap(
-    (provider) =>
-      provider.embedding_models.map((model) => ({
-        ...model,
-        provider_type: provider.provider_type,
-        model_name: model.model_name, // Ensure model_name is set for consistency
-      }))
-  );
-
-  const currentModel: CloudEmbeddingModel | HostedEmbeddingModel =
-    AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
-    AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find(
-      (model) => model.model_name === currentEmeddingModel.model_name
-    )!;
 
   return (
@@ -102,8 +91,8 @@ function Main() {
         )}
 
       Embedding Model
-      {currentModel ? (
-
+      {currentEmeddingModel ? (
+
       ) : (
         Choose your Embedding Model
       )}
diff --git a/web/src/app/admin/embeddings/interfaces.ts b/web/src/app/admin/embeddings/interfaces.ts
index c3e0395af..2fc328eab 100644
--- a/web/src/app/admin/embeddings/interfaces.ts
+++ b/web/src/app/admin/embeddings/interfaces.ts
@@ -1,4 +1,10 @@
-import { EmbeddingProvider } from "@/components/embedding/interfaces";
+import {
+  AVAILABLE_CLOUD_PROVIDERS,
+  AVAILABLE_MODELS,
+  CloudEmbeddingModel,
+  EmbeddingProvider,
+  HostedEmbeddingModel,
+} from "@/components/embedding/interfaces";
 
 // This is a slightly differnte interface than used in the backend
 // but is always used in conjunction with `AdvancedSearchConfiguration`
@@ -92,3 +98,24 @@ export const rerankingModels: RerankingModel[] = [
     link: "https://docs.cohere.com/docs/rerank",
   },
 ];
+
+export const getCurrentModelCopy = (
+  currentModelName: string
+): CloudEmbeddingModel | HostedEmbeddingModel | null => {
+  const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap(
+    (provider) =>
+      provider.embedding_models.map((model) => ({
+        ...model,
+        provider_type: provider.provider_type,
+        model_name: model.model_name,
+      }))
+  );
+
+  return (
+    AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
+    AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find(
+      (model) => model.model_name === currentModelName
+    ) ||
+    null
+  );
+};
diff --git a/web/src/components/embedding/ModelSelector.tsx b/web/src/components/embedding/ModelSelector.tsx
index 93be374d1..9b9455c07 100644
--- a/web/src/components/embedding/ModelSelector.tsx
+++ b/web/src/components/embedding/ModelSelector.tsx
@@ -1,15 +1,13 @@
+import { getCurrentModelCopy } from "@/app/admin/embeddings/interfaces";
 import {
-  MicrosoftIcon,
-  NomicIcon,
-  OpenSourceIcon,
-} from "@/components/icons/icons";
-import {
+  AVAILABLE_CLOUD_PROVIDERS,
+  AVAILABLE_MODELS,
   EmbeddingModelDescriptor,
   getIconForRerankType,
   getTitleForRerankType,
   HostedEmbeddingModel,
 } from "./interfaces";
-import { FiExternalLink, FiStar } from "react-icons/fi";
+import { FiExternalLink } from "react-icons/fi";
 
 export function ModelPreview({
   model,
@@ -18,15 +16,17 @@ export function ModelPreview({
   model: EmbeddingModelDescriptor;
   display?: boolean;
 }) {
+  const currentModelCopy = getCurrentModelCopy(model.model_name);
+
   return (
       {model.model_name}
-        {model.description
-          ? model.description
-          : "Custom model—no description is available."}
+        {model.description ||
+          currentModelCopy?.description ||
+          "Custom model—no description is available."}
   );
 }
 
@@ -41,6 +41,8 @@ export function ModelOption({
   onSelect?: (model: HostedEmbeddingModel) => void;
   selected: boolean;
 }) {
+  const currentModelCopy = getCurrentModelCopy(model.model_name);
+
   return (

-        {model.description || "Custom model—no description is available."}
+        {model.description ||
+          currentModelCopy?.description ||
+          "Custom model—no description is available."}

         {model.isDefault ? "Default" : "Self-hosted"}
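
Usage note: a minimal sketch of how the new ChatSessionManager helpers compose
in an integration test, assuming the same `reset` fixture and manager setup
used in the tests above; the test name and message text here are hypothetical
and illustrative, not part of the patch.

    from tests.integration.common_utils.llm import LLMProviderManager
    from tests.integration.common_utils.managers.chat import ChatSessionManager
    from tests.integration.common_utils.managers.user import UserManager


    def test_chat_history_roundtrip(reset: None) -> None:
        # Hypothetical example: create a user, an LLM provider, and a session.
        admin_user = UserManager.create(name="admin_user")
        LLMProviderManager.create(user_performing_action=admin_user)
        session = ChatSessionManager.create(user_performing_action=admin_user)

        # send_message streams /chat/send-message and returns a StreamedResponse
        # whose full_message accumulates the "answer_piece" chunks.
        response = ChatSessionManager.send_message(
            chat_session_id=session.id,
            message="this is a test message",
            user_performing_action=admin_user,
        )
        assert len(response.full_message) > 0

        # get_chat_history fetches /chat/history/{id} as TestChatMessage models.
        history = ChatSessionManager.get_chat_history(
            chat_session=session,
            user_performing_action=admin_user,
        )
        assert len(history) > 0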