Test stream + Update Copy (#2317)

* update copy + conditional ordering

* answer stream checks

* update

* add basic tests for chat streams

* slightly simplify

* fix typing

* quick typing updates + nits
This commit is contained in:
pablodanswer 2024-09-15 12:40:48 -07:00 committed by GitHub
parent 3ff2ba7ee4
commit e2c37d6847
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 277 additions and 28 deletions

View File

@ -0,0 +1,160 @@
import json
import requests
from requests.models import Response
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.models import RetrievalDetails
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import StreamedResponse
from tests.integration.common_utils.test_models import TestChatMessage
from tests.integration.common_utils.test_models import TestChatSession
from tests.integration.common_utils.test_models import TestUser
class ChatSessionManager:
@staticmethod
def create(
persona_id: int = -1,
description: str = "Test chat session",
user_performing_action: TestUser | None = None,
) -> TestChatSession:
chat_session_creation_req = ChatSessionCreationRequest(
persona_id=persona_id, description=description
)
response = requests.post(
f"{API_SERVER_URL}/chat/create-chat-session",
json=chat_session_creation_req.model_dump(),
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
chat_session_id = response.json()["chat_session_id"]
return TestChatSession(
id=chat_session_id, persona_id=persona_id, description=description
)
@staticmethod
def send_message(
chat_session_id: int,
message: str,
parent_message_id: int | None = None,
user_performing_action: TestUser | None = None,
file_descriptors: list[FileDescriptor] = [],
prompt_id: int | None = None,
search_doc_ids: list[int] | None = None,
retrieval_options: RetrievalDetails | None = None,
query_override: str | None = None,
regenerate: bool | None = None,
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
alternate_assistant_id: int | None = None,
use_existing_user_message: bool = False,
) -> StreamedResponse:
chat_message_req = CreateChatMessageRequest(
chat_session_id=chat_session_id,
parent_message_id=parent_message_id,
message=message,
file_descriptors=file_descriptors or [],
prompt_id=prompt_id,
search_doc_ids=search_doc_ids or [],
retrieval_options=retrieval_options,
query_override=query_override,
regenerate=regenerate,
llm_override=llm_override,
prompt_override=prompt_override,
alternate_assistant_id=alternate_assistant_id,
use_existing_user_message=use_existing_user_message,
)
response = requests.post(
f"{API_SERVER_URL}/chat/send-message",
json=chat_message_req.model_dump(),
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
stream=True,
)
return ChatSessionManager.analyze_response(response)
@staticmethod
def get_answer_with_quote(
persona_id: int,
message: str,
user_performing_action: TestUser | None = None,
) -> StreamedResponse:
direct_qa_request = DirectQARequest(
messages=[ThreadMessage(message=message)],
prompt_id=None,
persona_id=persona_id,
)
response = requests.post(
f"{API_SERVER_URL}/query/stream-answer-with-quote",
json=direct_qa_request.model_dump(),
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
stream=True,
)
response.raise_for_status()
return ChatSessionManager.analyze_response(response)
@staticmethod
def analyze_response(response: Response) -> StreamedResponse:
response_data = [
json.loads(line.decode("utf-8")) for line in response.iter_lines() if line
]
analyzed = StreamedResponse()
for data in response_data:
if "rephrased_query" in data:
analyzed.rephrased_query = data["rephrased_query"]
elif "tool_name" in data:
analyzed.tool_name = data["tool_name"]
analyzed.tool_result = (
data.get("tool_result")
if analyzed.tool_name == "run_search"
else None
)
elif "relevance_summaries" in data:
analyzed.relevance_summaries = data["relevance_summaries"]
elif "answer_piece" in data and data["answer_piece"]:
analyzed.full_message += data["answer_piece"]
return analyzed
@staticmethod
def get_chat_history(
chat_session: TestChatSession,
user_performing_action: TestUser | None = None,
) -> list[TestChatMessage]:
response = requests.get(
f"{API_SERVER_URL}/chat/history/{chat_session.id}",
headers=user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS,
)
response.raise_for_status()
return [
TestChatMessage(
id=msg["id"],
chat_session_id=chat_session.id,
parent_message_id=msg.get("parent_message_id"),
message=msg["message"],
response=msg.get("response", ""),
)
for msg in response.json()
]

View File

@ -118,3 +118,28 @@ class TestPersona(BaseModel):
llm_model_version_override: str | None
users: list[str]
groups: list[int]
#
class TestChatSession(BaseModel):
id: int
persona_id: int
description: str
class TestChatMessage(BaseModel):
id: str | None = None
chat_session_id: int
parent_message_id: str | None
message: str
response: str
class StreamedResponse(BaseModel):
full_message: str = ""
rephrased_query: str | None = None
tool_name: str | None = None
top_documents: list[dict[str, Any]] | None = None
relevance_summaries: list[dict[str, Any]] | None = None
tool_result: Any | None = None
user: str | None = None

View File

@ -0,0 +1,25 @@
from tests.integration.common_utils.llm import LLMProviderManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import TestUser
def test_send_message_simple_with_history(reset: None) -> None:
admin_user: TestUser = UserManager.create(name="admin_user")
LLMProviderManager.create(user_performing_action=admin_user)
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
response = ChatSessionManager.get_answer_with_quote(
persona_id=test_chat_session.persona_id,
message="Hello, this is a test.",
user_performing_action=admin_user,
)
assert (
response.tool_name is not None
), "Tool name should be specified (always search)"
assert (
response.relevance_summaries is not None
), "Relevance summaries should be present for all search streams"
assert len(response.full_message) > 0, "Response message should not be empty"

View File

@ -0,0 +1,19 @@
from tests.integration.common_utils.llm import LLMProviderManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import TestUser
def test_send_message_simple_with_history(reset: None) -> None:
admin_user: TestUser = UserManager.create(name="admin_user")
LLMProviderManager.create(user_performing_action=admin_user)
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
response = ChatSessionManager.send_message(
chat_session_id=test_chat_session.id,
message="this is a test message",
user_performing_action=admin_user,
)
assert len(response.full_message) > 0

View File

@ -24,7 +24,10 @@ export interface EmbeddingDetails {
import { EmbeddingIcon } from "@/components/icons/icons";
import Link from "next/link";
import { SavedSearchSettings } from "../../embeddings/interfaces";
import {
getCurrentModelCopy,
SavedSearchSettings,
} from "../../embeddings/interfaces";
import UpgradingPage from "./UpgradingPage";
import { useContext } from "react";
import { SettingsContext } from "@/components/settings/SettingsProvider";
@ -75,20 +78,6 @@ function Main() {
}
const currentModelName = currentEmeddingModel?.model_name;
const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap(
(provider) =>
provider.embedding_models.map((model) => ({
...model,
provider_type: provider.provider_type,
model_name: model.model_name, // Ensure model_name is set for consistency
}))
);
const currentModel: CloudEmbeddingModel | HostedEmbeddingModel =
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find(
(model) => model.model_name === currentEmeddingModel.model_name
)!;
return (
<div className="h-screen">
@ -102,8 +91,8 @@ function Main() {
)}
<Title className="mb-6 mt-8 !text-2xl">Embedding Model</Title>
{currentModel ? (
<ModelPreview model={currentModel} display />
{currentEmeddingModel ? (
<ModelPreview model={currentEmeddingModel} display />
) : (
<Title className="mt-8 mb-4">Choose your Embedding Model</Title>
)}

View File

@ -1,4 +1,10 @@
import { EmbeddingProvider } from "@/components/embedding/interfaces";
import {
AVAILABLE_CLOUD_PROVIDERS,
AVAILABLE_MODELS,
CloudEmbeddingModel,
EmbeddingProvider,
HostedEmbeddingModel,
} from "@/components/embedding/interfaces";
// This is a slightly differnte interface than used in the backend
// but is always used in conjunction with `AdvancedSearchConfiguration`
@ -92,3 +98,24 @@ export const rerankingModels: RerankingModel[] = [
link: "https://docs.cohere.com/docs/rerank",
},
];
export const getCurrentModelCopy = (
currentModelName: string
): CloudEmbeddingModel | HostedEmbeddingModel | null => {
const AVAILABLE_CLOUD_PROVIDERS_FLATTENED = AVAILABLE_CLOUD_PROVIDERS.flatMap(
(provider) =>
provider.embedding_models.map((model) => ({
...model,
provider_type: provider.provider_type,
model_name: model.model_name,
}))
);
return (
AVAILABLE_MODELS.find((model) => model.model_name === currentModelName) ||
AVAILABLE_CLOUD_PROVIDERS_FLATTENED.find(
(model) => model.model_name === currentModelName
) ||
null
);
};

View File

@ -1,15 +1,13 @@
import { getCurrentModelCopy } from "@/app/admin/embeddings/interfaces";
import {
MicrosoftIcon,
NomicIcon,
OpenSourceIcon,
} from "@/components/icons/icons";
import {
AVAILABLE_CLOUD_PROVIDERS,
AVAILABLE_MODELS,
EmbeddingModelDescriptor,
getIconForRerankType,
getTitleForRerankType,
HostedEmbeddingModel,
} from "./interfaces";
import { FiExternalLink, FiStar } from "react-icons/fi";
import { FiExternalLink } from "react-icons/fi";
export function ModelPreview({
model,
@ -18,15 +16,17 @@ export function ModelPreview({
model: EmbeddingModelDescriptor;
display?: boolean;
}) {
const currentModelCopy = getCurrentModelCopy(model.model_name);
return (
<div
className={`border border-border rounded shadow-md ${display ? "bg-inverted rounded-lg p-4" : "bg-hover-light p-2"} w-96 flex flex-col`}
>
<div className="font-bold text-lg flex">{model.model_name}</div>
<div className="text-sm mt-1 mx-1">
{model.description
? model.description
: "Custom model—no description is available."}
{model.description ||
currentModelCopy?.description ||
"Custom model—no description is available."}
</div>
</div>
);
@ -41,6 +41,8 @@ export function ModelOption({
onSelect?: (model: HostedEmbeddingModel) => void;
selected: boolean;
}) {
const currentModelCopy = getCurrentModelCopy(model.model_name);
return (
<div
className={`p-4 w-96 border rounded-lg transition-all duration-200 ${
@ -65,7 +67,9 @@ export function ModelOption({
)}
</div>
<p className="text-sm k text-gray-600 text-left mb-2">
{model.description || "Custom model—no description is available."}
{model.description ||
currentModelCopy?.description ||
"Custom model—no description is available."}
</p>
<div className="text-xs text-gray-500">
{model.isDefault ? "Default" : "Self-hosted"}