mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-29 11:12:02 +01:00
Test stream + Update Copy (#2317)
* update copy + conditional ordering * answer stream checks * update * add basic tests for chat streams * slightly simplify * fix typing * quick typing updates + nits
This commit is contained in:
parent
3ff2ba7ee4
commit
e2c37d6847
160
backend/tests/integration/common_utils/managers/chat.py
Normal file
160
backend/tests/integration/common_utils/managers/chat.py
Normal file
@ -0,0 +1,160 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
from requests.models import Response
|
||||
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import StreamedResponse
|
||||
from tests.integration.common_utils.test_models import TestChatMessage
|
||||
from tests.integration.common_utils.test_models import TestChatSession
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
class ChatSessionManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
persona_id: int = -1,
|
||||
description: str = "Test chat session",
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> TestChatSession:
|
||||
chat_session_creation_req = ChatSessionCreationRequest(
|
||||
persona_id=persona_id, description=description
|
||||
)
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/create-chat-session",
|
||||
json=chat_session_creation_req.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
chat_session_id = response.json()["chat_session_id"]
|
||||
return TestChatSession(
|
||||
id=chat_session_id, persona_id=persona_id, description=description
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def send_message(
|
||||
chat_session_id: int,
|
||||
message: str,
|
||||
parent_message_id: int | None = None,
|
||||
user_performing_action: TestUser | None = None,
|
||||
file_descriptors: list[FileDescriptor] = [],
|
||||
prompt_id: int | None = None,
|
||||
search_doc_ids: list[int] | None = None,
|
||||
retrieval_options: RetrievalDetails | None = None,
|
||||
query_override: str | None = None,
|
||||
regenerate: bool | None = None,
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
alternate_assistant_id: int | None = None,
|
||||
use_existing_user_message: bool = False,
|
||||
) -> StreamedResponse:
|
||||
chat_message_req = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
message=message,
|
||||
file_descriptors=file_descriptors or [],
|
||||
prompt_id=prompt_id,
|
||||
search_doc_ids=search_doc_ids or [],
|
||||
retrieval_options=retrieval_options,
|
||||
query_override=query_override,
|
||||
regenerate=regenerate,
|
||||
llm_override=llm_override,
|
||||
prompt_override=prompt_override,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message",
|
||||
json=chat_message_req.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
return ChatSessionManager.analyze_response(response)
|
||||
|
||||
@staticmethod
|
||||
def get_answer_with_quote(
|
||||
persona_id: int,
|
||||
message: str,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> StreamedResponse:
|
||||
direct_qa_request = DirectQARequest(
|
||||
messages=[ThreadMessage(message=message)],
|
||||
prompt_id=None,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/query/stream-answer-with-quote",
|
||||
json=direct_qa_request.model_dump(),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
stream=True,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return ChatSessionManager.analyze_response(response)
|
||||
|
||||
@staticmethod
|
||||
def analyze_response(response: Response) -> StreamedResponse:
|
||||
response_data = [
|
||||
json.loads(line.decode("utf-8")) for line in response.iter_lines() if line
|
||||
]
|
||||
|
||||
analyzed = StreamedResponse()
|
||||
|
||||
for data in response_data:
|
||||
if "rephrased_query" in data:
|
||||
analyzed.rephrased_query = data["rephrased_query"]
|
||||
elif "tool_name" in data:
|
||||
analyzed.tool_name = data["tool_name"]
|
||||
analyzed.tool_result = (
|
||||
data.get("tool_result")
|
||||
if analyzed.tool_name == "run_search"
|
||||
else None
|
||||
)
|
||||
elif "relevance_summaries" in data:
|
||||
analyzed.relevance_summaries = data["relevance_summaries"]
|
||||
elif "answer_piece" in data and data["answer_piece"]:
|
||||
analyzed.full_message += data["answer_piece"]
|
||||
|
||||
return analyzed
|
||||
|
||||
@staticmethod
|
||||
def get_chat_history(
|
||||
chat_session: TestChatSession,
|
||||
user_performing_action: TestUser | None = None,
|
||||
) -> list[TestChatMessage]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/history/{chat_session.id}",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return [
|
||||
TestChatMessage(
|
||||
id=msg["id"],
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message_id=msg.get("parent_message_id"),
|
||||
message=msg["message"],
|
||||
response=msg.get("response", ""),
|
||||
)
|
||||
for msg in response.json()
|
||||
]
|
@ -118,3 +118,28 @@ class TestPersona(BaseModel):
|
||||
llm_model_version_override: str | None
|
||||
users: list[str]
|
||||
groups: list[int]
|
||||
|
||||
|
||||
#
|
||||
class TestChatSession(BaseModel):
|
||||
id: int
|
||||
persona_id: int
|
||||
description: str
|
||||
|
||||
|
||||
class TestChatMessage(BaseModel):
|
||||
id: str | None = None
|
||||
chat_session_id: int
|
||||
parent_message_id: str | None
|
||||
message: str
|
||||
response: str
|
||||
|
||||
|
||||
class StreamedResponse(BaseModel):
|
||||
full_message: str = ""
|
||||
rephrased_query: str | None = None
|
||||
tool_name: str | None = None
|
||||
top_documents: list[dict[str, Any]] | None = None
|
||||
relevance_summaries: list[dict[str, Any]] | None = None
|
||||
tool_result: Any | None = None
|
||||
user: str | None = None
|
||||
|
@ -0,0 +1,25 @@
|
||||
from tests.integration.common_utils.llm import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
def test_send_message_simple_with_history(reset: None) -> None:
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
|
||||
|
||||
response = ChatSessionManager.get_answer_with_quote(
|
||||
persona_id=test_chat_session.persona_id,
|
||||
message="Hello, this is a test.",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert (
|
||||
response.tool_name is not None
|
||||
), "Tool name should be specified (always search)"
|
||||
assert (
|
||||
response.relevance_summaries is not None
|
||||
), "Relevance summaries should be present for all search streams"
|
||||
assert len(response.full_message) > 0, "Response message should not be empty"
|
@ -0,0 +1,19 @@
|
||||
from tests.integration.common_utils.llm import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import TestUser
|
||||
|
||||
|
||||
def test_send_message_simple_with_history(reset: None) -> None:
|
||||
admin_user: TestUser = UserManager.create(name="admin_user")
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
|
||||
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=test_chat_session.id,
|
||||
message="this is a test message",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert len(response.full_message) > 0
|
@ -24,7 +24,10 @@ export interface EmbeddingDetails {
|
||||
import { EmbeddingIcon } from "@/components/icons/icons";
|
||||
|
||||
import Link from "next/link";
|
||||
import { SavedSearchSettings } from "../../embeddings/interfaces";
|
||||
import {
|
||||
getCurrentModelCopy,
|
||||
SavedSearchSettings,
|
||||
} from "../../embeddings/interfaces";
|
||||
import UpgradingPage from "./UpgradingPage";
|
||||
import { useContext } from "react";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
@ -75,20 +78,6 @@ function Main() {
|
||||
}
|
||||
|
||||
const currentModelName = currentEmeddingModel?.model_name;
|
||||
const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap(
|
||||
(provider) =>
|
||||
provider.embedding_models.map((model) => ({
|
||||
...model,
|
||||
provider_type: provider.provider_type,
|
||||
model_name: model.model_name, // Ensure model_name is set for consistency
|
||||
}))
|
||||
);
|
||||
|
||||
const currentModel: CloudEmbeddingModel | HostedEmbeddingModel =
|
||||
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
|
||||
AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find(
|
||||
(model) => model.model_name === currentEmeddingModel.model_name
|
||||
)!;
|
||||
|
||||
return (
|
||||
<div className="h-screen">
|
||||
@ -102,8 +91,8 @@ function Main() {
|
||||
)}
|
||||
<Title className="mb-6 mt-8 !text-2xl">Embedding Model</Title>
|
||||
|
||||
{currentModel ? (
|
||||
<ModelPreview model={currentModel} display />
|
||||
{currentEmeddingModel ? (
|
||||
<ModelPreview model={currentEmeddingModel} display />
|
||||
) : (
|
||||
<Title className="mt-8 mb-4">Choose your Embedding Model</Title>
|
||||
)}
|
||||
|
@ -1,4 +1,10 @@
|
||||
import { EmbeddingProvider } from "@/components/embedding/interfaces";
|
||||
import {
|
||||
AVAILABLE_CLOUD_PROVIDERS,
|
||||
AVAILABLE_MODELS,
|
||||
CloudEmbeddingModel,
|
||||
EmbeddingProvider,
|
||||
HostedEmbeddingModel,
|
||||
} from "@/components/embedding/interfaces";
|
||||
|
||||
// This is a slightly differnte interface than used in the backend
|
||||
// but is always used in conjunction with `AdvancedSearchConfiguration`
|
||||
@ -92,3 +98,24 @@ export const rerankingModels: RerankingModel[] = [
|
||||
link: "https://docs.cohere.com/docs/rerank",
|
||||
},
|
||||
];
|
||||
|
||||
export const getCurrentModelCopy = (
|
||||
currentModelName: string
|
||||
): CloudEmbeddingModel | HostedEmbeddingModel | null => {
|
||||
const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap(
|
||||
(provider) =>
|
||||
provider.embedding_models.map((model) => ({
|
||||
...model,
|
||||
provider_type: provider.provider_type,
|
||||
model_name: model.model_name,
|
||||
}))
|
||||
);
|
||||
|
||||
return (
|
||||
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
|
||||
AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find(
|
||||
(model) => model.model_name === currentModelName
|
||||
) ||
|
||||
null
|
||||
);
|
||||
};
|
||||
|
@ -1,15 +1,13 @@
|
||||
import { getCurrentModelCopy } from "@/app/admin/embeddings/interfaces";
|
||||
import {
|
||||
MicrosoftIcon,
|
||||
NomicIcon,
|
||||
OpenSourceIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import {
|
||||
AVAILABLE_CLOUD_PROVIDERS,
|
||||
AVAILABLE_MODELS,
|
||||
EmbeddingModelDescriptor,
|
||||
getIconForRerankType,
|
||||
getTitleForRerankType,
|
||||
HostedEmbeddingModel,
|
||||
} from "./interfaces";
|
||||
import { FiExternalLink, FiStar } from "react-icons/fi";
|
||||
import { FiExternalLink } from "react-icons/fi";
|
||||
|
||||
export function ModelPreview({
|
||||
model,
|
||||
@ -18,15 +16,17 @@ export function ModelPreview({
|
||||
model: EmbeddingModelDescriptor;
|
||||
display?: boolean;
|
||||
}) {
|
||||
const currentModelCopy = getCurrentModelCopy(model.model_name);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`border border-border rounded shadow-md ${display ? "bg-inverted rounded-lg p-4" : "bg-hover-light p-2"} w-96 flex flex-col`}
|
||||
>
|
||||
<div className="font-bold text-lg flex">{model.model_name}</div>
|
||||
<div className="text-sm mt-1 mx-1">
|
||||
{model.description
|
||||
? model.description
|
||||
: "Custom model—no description is available."}
|
||||
{model.description ||
|
||||
currentModelCopy?.description ||
|
||||
"Custom model—no description is available."}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
@ -41,6 +41,8 @@ export function ModelOption({
|
||||
onSelect?: (model: HostedEmbeddingModel) => void;
|
||||
selected: boolean;
|
||||
}) {
|
||||
const currentModelCopy = getCurrentModelCopy(model.model_name);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`p-4 w-96 border rounded-lg transition-all duration-200 ${
|
||||
@ -65,7 +67,9 @@ export function ModelOption({
|
||||
)}
|
||||
</div>
|
||||
<p className="text-sm k text-gray-600 text-left mb-2">
|
||||
{model.description || "Custom model—no description is available."}
|
||||
{model.description ||
|
||||
currentModelCopy?.description ||
|
||||
"Custom model—no description is available."}
|
||||
</p>
|
||||
<div className="text-xs text-gray-500">
|
||||
{model.isDefault ? "Default" : "Self-hosted"}
|
||||
|
Loading…
x
Reference in New Issue
Block a user