mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-17 11:30:58 +02:00
Test stream + Update Copy (#2317)
* update copy + conditional ordering * answer stream checks * update * add basic tests for chat streams * slightly simplify * fix typing * quick typing updates + nits
This commit is contained in:
parent
3ff2ba7ee4
commit
e2c37d6847
160
backend/tests/integration/common_utils/managers/chat.py
Normal file
160
backend/tests/integration/common_utils/managers/chat.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from requests.models import Response
|
||||||
|
|
||||||
|
from danswer.file_store.models import FileDescriptor
|
||||||
|
from danswer.llm.override_models import LLMOverride
|
||||||
|
from danswer.llm.override_models import PromptOverride
|
||||||
|
from danswer.one_shot_answer.models import DirectQARequest
|
||||||
|
from danswer.one_shot_answer.models import ThreadMessage
|
||||||
|
from danswer.search.models import RetrievalDetails
|
||||||
|
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
|
||||||
|
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||||
|
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||||
|
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||||
|
from tests.integration.common_utils.test_models import StreamedResponse
|
||||||
|
from tests.integration.common_utils.test_models import TestChatMessage
|
||||||
|
from tests.integration.common_utils.test_models import TestChatSession
|
||||||
|
from tests.integration.common_utils.test_models import TestUser
|
||||||
|
|
||||||
|
|
||||||
|
class ChatSessionManager:
    """Static helpers for exercising the chat APIs in integration tests.

    Wraps the create-session, send-message, one-shot QA, and history
    endpoints of the API server, and folds streamed JSON-lines responses
    into a ``StreamedResponse`` for easy assertions.
    """

    @staticmethod
    def create(
        persona_id: int = -1,
        description: str = "Test chat session",
        user_performing_action: TestUser | None = None,
    ) -> TestChatSession:
        """Create a chat session and return its test-model representation.

        Uses the acting user's auth headers when provided, otherwise the
        unauthenticated GENERAL_HEADERS. Raises for non-2xx responses.
        """
        chat_session_creation_req = ChatSessionCreationRequest(
            persona_id=persona_id, description=description
        )
        response = requests.post(
            f"{API_SERVER_URL}/chat/create-chat-session",
            json=chat_session_creation_req.model_dump(),
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        chat_session_id = response.json()["chat_session_id"]
        return TestChatSession(
            id=chat_session_id, persona_id=persona_id, description=description
        )

    @staticmethod
    def send_message(
        chat_session_id: int,
        message: str,
        parent_message_id: int | None = None,
        user_performing_action: TestUser | None = None,
        # FIX: was a mutable default argument (`= []`), which is shared
        # across calls; default to None and normalize below instead.
        # Backward-compatible: the request body already used
        # `file_descriptors or []`.
        file_descriptors: list[FileDescriptor] | None = None,
        prompt_id: int | None = None,
        search_doc_ids: list[int] | None = None,
        retrieval_options: RetrievalDetails | None = None,
        query_override: str | None = None,
        regenerate: bool | None = None,
        llm_override: LLMOverride | None = None,
        prompt_override: PromptOverride | None = None,
        alternate_assistant_id: int | None = None,
        use_existing_user_message: bool = False,
    ) -> StreamedResponse:
        """Send a message on an existing chat session and parse the stream.

        NOTE(review): unlike the other endpoints here, this does not call
        raise_for_status() — presumably so tests can inspect error
        streams; confirm before adding it.
        """
        chat_message_req = CreateChatMessageRequest(
            chat_session_id=chat_session_id,
            parent_message_id=parent_message_id,
            message=message,
            file_descriptors=file_descriptors or [],
            prompt_id=prompt_id,
            search_doc_ids=search_doc_ids or [],
            retrieval_options=retrieval_options,
            query_override=query_override,
            regenerate=regenerate,
            llm_override=llm_override,
            prompt_override=prompt_override,
            alternate_assistant_id=alternate_assistant_id,
            use_existing_user_message=use_existing_user_message,
        )

        response = requests.post(
            f"{API_SERVER_URL}/chat/send-message",
            json=chat_message_req.model_dump(),
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
            stream=True,
        )

        return ChatSessionManager.analyze_response(response)

    @staticmethod
    def get_answer_with_quote(
        persona_id: int,
        message: str,
        user_performing_action: TestUser | None = None,
    ) -> StreamedResponse:
        """Run a one-shot QA request and parse the streamed answer."""
        direct_qa_request = DirectQARequest(
            messages=[ThreadMessage(message=message)],
            prompt_id=None,
            persona_id=persona_id,
        )

        response = requests.post(
            f"{API_SERVER_URL}/query/stream-answer-with-quote",
            json=direct_qa_request.model_dump(),
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
            stream=True,
        )
        response.raise_for_status()

        return ChatSessionManager.analyze_response(response)

    @staticmethod
    def analyze_response(response: Response) -> StreamedResponse:
        """Fold a JSON-lines response stream into a StreamedResponse.

        Each non-empty line is decoded as a JSON packet; packets are
        dispatched on which key they carry. "answer_piece" packets are
        concatenated into full_message; tool_result is kept only for the
        "run_search" tool.
        """
        response_data = [
            json.loads(line.decode("utf-8")) for line in response.iter_lines() if line
        ]

        analyzed = StreamedResponse()

        for data in response_data:
            if "rephrased_query" in data:
                analyzed.rephrased_query = data["rephrased_query"]
            elif "tool_name" in data:
                analyzed.tool_name = data["tool_name"]
                analyzed.tool_result = (
                    data.get("tool_result")
                    if analyzed.tool_name == "run_search"
                    else None
                )
            elif "relevance_summaries" in data:
                analyzed.relevance_summaries = data["relevance_summaries"]
            elif "answer_piece" in data and data["answer_piece"]:
                analyzed.full_message += data["answer_piece"]

        return analyzed

    @staticmethod
    def get_chat_history(
        chat_session: TestChatSession,
        user_performing_action: TestUser | None = None,
    ) -> list[TestChatMessage]:
        """Fetch the full message history for a chat session.

        Raises for non-2xx responses; missing "response" fields default
        to the empty string.
        """
        response = requests.get(
            f"{API_SERVER_URL}/chat/history/{chat_session.id}",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()

        return [
            TestChatMessage(
                id=msg["id"],
                chat_session_id=chat_session.id,
                parent_message_id=msg.get("parent_message_id"),
                message=msg["message"],
                response=msg.get("response", ""),
            )
            for msg in response.json()
        ]
|
@ -118,3 +118,28 @@ class TestPersona(BaseModel):
|
|||||||
llm_model_version_override: str | None
|
llm_model_version_override: str | None
|
||||||
users: list[str]
|
users: list[str]
|
||||||
groups: list[int]
|
groups: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
class TestChatSession(BaseModel):
    """Lightweight representation of a chat session created during tests."""

    # Server-assigned chat session id (from the create-chat-session response).
    id: int
    # Persona the session was created with.
    persona_id: int
    description: str
||||||
|
|
||||||
|
|
||||||
|
class TestChatMessage(BaseModel):
    """One message (and its assistant response) from a chat session's history."""

    # NOTE(review): ids are strings here, but ChatSessionManager.send_message
    # takes parent_message_id as int | None — confirm the server's id type
    # and unify.
    id: str | None = None
    chat_session_id: int
    parent_message_id: str | None
    # The user-sent message text.
    message: str
    # The assistant's response text ("" when absent in the history payload).
    response: str
||||||
|
|
||||||
|
|
||||||
|
class StreamedResponse(BaseModel):
    """Accumulated view of a streamed chat/QA response.

    Populated by ChatSessionManager.analyze_response, which dispatches each
    JSON packet in the stream to one of these fields.
    """

    # Concatenation of all non-empty "answer_piece" packets.
    full_message: str = ""
    rephrased_query: str | None = None
    tool_name: str | None = None
    top_documents: list[dict[str, Any]] | None = None
    relevance_summaries: list[dict[str, Any]] | None = None
    # Kept only when tool_name == "run_search"; None otherwise.
    tool_result: Any | None = None
    user: str | None = None
||||||
|
@ -0,0 +1,25 @@
|
|||||||
|
from tests.integration.common_utils.llm import LLMProviderManager
|
||||||
|
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||||
|
from tests.integration.common_utils.managers.user import UserManager
|
||||||
|
from tests.integration.common_utils.test_models import TestUser
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_message_simple_with_history(reset: None) -> None:
    """Smoke-test the stream-answer-with-quote flow as an admin user.

    NOTE(review): despite the name, this exercises
    ChatSessionManager.get_answer_with_quote, not send_message — consider
    renaming for clarity.
    """
    admin_user: TestUser = UserManager.create(name="admin_user")
    # An LLM provider must exist before any answer can be generated.
    LLMProviderManager.create(user_performing_action=admin_user)

    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)

    response = ChatSessionManager.get_answer_with_quote(
        persona_id=test_chat_session.persona_id,
        message="Hello, this is a test.",
        user_performing_action=admin_user,
    )

    assert (
        response.tool_name is not None
    ), "Tool name should be specified (always search)"
    assert (
        response.relevance_summaries is not None
    ), "Relevance summaries should be present for all search streams"
    assert len(response.full_message) > 0, "Response message should not be empty"
|
@ -0,0 +1,19 @@
|
|||||||
|
from tests.integration.common_utils.llm import LLMProviderManager
|
||||||
|
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||||
|
from tests.integration.common_utils.managers.user import UserManager
|
||||||
|
from tests.integration.common_utils.test_models import TestUser
|
||||||
|
|
||||||
|
|
||||||
|
def test_send_message_simple_with_history(reset: None) -> None:
    """End-to-end smoke test: create a session, send a message, expect an answer."""
    admin_user: TestUser = UserManager.create(name="admin_user")
    # An LLM provider must exist before any answer can be generated.
    LLMProviderManager.create(user_performing_action=admin_user)

    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)

    response = ChatSessionManager.send_message(
        chat_session_id=test_chat_session.id,
        message="this is a test message",
        user_performing_action=admin_user,
    )

    # Any non-empty streamed answer counts as success for this smoke test.
    assert len(response.full_message) > 0
|
@ -24,7 +24,10 @@ export interface EmbeddingDetails {
|
|||||||
import { EmbeddingIcon } from "@/components/icons/icons";
|
import { EmbeddingIcon } from "@/components/icons/icons";
|
||||||
|
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { SavedSearchSettings } from "../../embeddings/interfaces";
|
import {
|
||||||
|
getCurrentModelCopy,
|
||||||
|
SavedSearchSettings,
|
||||||
|
} from "../../embeddings/interfaces";
|
||||||
import UpgradingPage from "./UpgradingPage";
|
import UpgradingPage from "./UpgradingPage";
|
||||||
import { useContext } from "react";
|
import { useContext } from "react";
|
||||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||||
@ -75,20 +78,6 @@ function Main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const currentModelName = currentEmeddingModel?.model_name;
|
const currentModelName = currentEmeddingModel?.model_name;
|
||||||
const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap(
|
|
||||||
(provider) =>
|
|
||||||
provider.embedding_models.map((model) => ({
|
|
||||||
...model,
|
|
||||||
provider_type: provider.provider_type,
|
|
||||||
model_name: model.model_name, // Ensure model_name is set for consistency
|
|
||||||
}))
|
|
||||||
);
|
|
||||||
|
|
||||||
const currentModel: CloudEmbeddingModel | HostedEmbeddingModel =
|
|
||||||
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
|
|
||||||
AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find(
|
|
||||||
(model) => model.model_name === currentEmeddingModel.model_name
|
|
||||||
)!;
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="h-screen">
|
<div className="h-screen">
|
||||||
@ -102,8 +91,8 @@ function Main() {
|
|||||||
)}
|
)}
|
||||||
<Title className="mb-6 mt-8 !text-2xl">Embedding Model</Title>
|
<Title className="mb-6 mt-8 !text-2xl">Embedding Model</Title>
|
||||||
|
|
||||||
{currentModel ? (
|
{currentEmeddingModel ? (
|
||||||
<ModelPreview model={currentModel} display />
|
<ModelPreview model={currentEmeddingModel} display />
|
||||||
) : (
|
) : (
|
||||||
<Title className="mt-8 mb-4">Choose your Embedding Model</Title>
|
<Title className="mt-8 mb-4">Choose your Embedding Model</Title>
|
||||||
)}
|
)}
|
||||||
|
@ -1,4 +1,10 @@
|
|||||||
import { EmbeddingProvider } from "@/components/embedding/interfaces";
|
import {
|
||||||
|
AVAILABLE_CLOUD_PROVIDERS,
|
||||||
|
AVAILABLE_MODELS,
|
||||||
|
CloudEmbeddingModel,
|
||||||
|
EmbeddingProvider,
|
||||||
|
HostedEmbeddingModel,
|
||||||
|
} from "@/components/embedding/interfaces";
|
||||||
|
|
||||||
// This is a slightly different interface than used in the backend
|
// This is a slightly different interface than used in the backend
|
||||||
// but is always used in conjunction with `AdvancedSearchConfiguration`
|
// but is always used in conjunction with `AdvancedSearchConfiguration`
|
||||||
@ -92,3 +98,24 @@ export const rerankingModels: RerankingModel[] = [
|
|||||||
link: "https://docs.cohere.com/docs/rerank",
|
link: "https://docs.cohere.com/docs/rerank",
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|
||||||
|
// Look up the descriptive "copy" (full model entry) for a model name,
// searching hosted models first, then all cloud-provider models.
// Returns null when the name matches neither list (e.g. a custom model).
export const getCurrentModelCopy = (
  currentModelName: string
): CloudEmbeddingModel | HostedEmbeddingModel | null => {
  // Flatten every cloud provider's models into one list, tagging each
  // entry with its provider_type so it carries the same fields as a
  // standalone model entry.
  const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap(
    (provider) =>
      provider.embedding_models.map((model) => ({
        ...model,
        provider_type: provider.provider_type,
        model_name: model.model_name,
      }))
  );

  return (
    AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
    AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find(
      (model) => model.model_name === currentModelName
    ) ||
    null
  );
};
||||||
|
@ -1,15 +1,13 @@
|
|||||||
|
import { getCurrentModelCopy } from "@/app/admin/embeddings/interfaces";
|
||||||
import {
|
import {
|
||||||
MicrosoftIcon,
|
AVAILABLE_CLOUD_PROVIDERS,
|
||||||
NomicIcon,
|
AVAILABLE_MODELS,
|
||||||
OpenSourceIcon,
|
|
||||||
} from "@/components/icons/icons";
|
|
||||||
import {
|
|
||||||
EmbeddingModelDescriptor,
|
EmbeddingModelDescriptor,
|
||||||
getIconForRerankType,
|
getIconForRerankType,
|
||||||
getTitleForRerankType,
|
getTitleForRerankType,
|
||||||
HostedEmbeddingModel,
|
HostedEmbeddingModel,
|
||||||
} from "./interfaces";
|
} from "./interfaces";
|
||||||
import { FiExternalLink, FiStar } from "react-icons/fi";
|
import { FiExternalLink } from "react-icons/fi";
|
||||||
|
|
||||||
export function ModelPreview({
|
export function ModelPreview({
|
||||||
model,
|
model,
|
||||||
@ -18,15 +16,17 @@ export function ModelPreview({
|
|||||||
model: EmbeddingModelDescriptor;
|
model: EmbeddingModelDescriptor;
|
||||||
display?: boolean;
|
display?: boolean;
|
||||||
}) {
|
}) {
|
||||||
|
const currentModelCopy = getCurrentModelCopy(model.model_name);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={`border border-border rounded shadow-md ${display ? "bg-inverted rounded-lg p-4" : "bg-hover-light p-2"} w-96 flex flex-col`}
|
className={`border border-border rounded shadow-md ${display ? "bg-inverted rounded-lg p-4" : "bg-hover-light p-2"} w-96 flex flex-col`}
|
||||||
>
|
>
|
||||||
<div className="font-bold text-lg flex">{model.model_name}</div>
|
<div className="font-bold text-lg flex">{model.model_name}</div>
|
||||||
<div className="text-sm mt-1 mx-1">
|
<div className="text-sm mt-1 mx-1">
|
||||||
{model.description
|
{model.description ||
|
||||||
? model.description
|
currentModelCopy?.description ||
|
||||||
: "Custom model—no description is available."}
|
"Custom model—no description is available."}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
@ -41,6 +41,8 @@ export function ModelOption({
|
|||||||
onSelect?: (model: HostedEmbeddingModel) => void;
|
onSelect?: (model: HostedEmbeddingModel) => void;
|
||||||
selected: boolean;
|
selected: boolean;
|
||||||
}) {
|
}) {
|
||||||
|
const currentModelCopy = getCurrentModelCopy(model.model_name);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={`p-4 w-96 border rounded-lg transition-all duration-200 ${
|
className={`p-4 w-96 border rounded-lg transition-all duration-200 ${
|
||||||
@ -65,7 +67,9 @@ export function ModelOption({
|
|||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<p className="text-sm k text-gray-600 text-left mb-2">
|
<p className="text-sm k text-gray-600 text-left mb-2">
|
||||||
{model.description || "Custom model—no description is available."}
|
{model.description ||
|
||||||
|
currentModelCopy?.description ||
|
||||||
|
"Custom model—no description is available."}
|
||||||
</p>
|
</p>
|
||||||
<div className="text-xs text-gray-500">
|
<div className="text-xs text-gray-500">
|
||||||
{model.isDefault ? "Default" : "Self-hosted"}
|
{model.isDefault ? "Default" : "Self-hosted"}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user