mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-13 13:20:15 +02:00
* rename classes and ignore deprecation warnings we mostly don't have control over * copy pytest.ini * ignore CryptographyDeprecationWarning * fully qualify the warning
161 lines
5.9 KiB
Python
161 lines
5.9 KiB
Python
import json
|
|
|
|
import requests
|
|
from requests.models import Response
|
|
|
|
from danswer.file_store.models import FileDescriptor
|
|
from danswer.llm.override_models import LLMOverride
|
|
from danswer.llm.override_models import PromptOverride
|
|
from danswer.one_shot_answer.models import DirectQARequest
|
|
from danswer.one_shot_answer.models import ThreadMessage
|
|
from danswer.search.models import RetrievalDetails
|
|
from danswer.server.query_and_chat.models import ChatSessionCreationRequest
|
|
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
|
from tests.integration.common_utils.constants import API_SERVER_URL
|
|
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
|
from tests.integration.common_utils.test_models import DATestChatMessage
|
|
from tests.integration.common_utils.test_models import DATestChatSession
|
|
from tests.integration.common_utils.test_models import DATestUser
|
|
from tests.integration.common_utils.test_models import StreamedResponse
|
|
|
|
|
|
class ChatSessionManager:
|
|
@staticmethod
|
|
def create(
|
|
persona_id: int = -1,
|
|
description: str = "Test chat session",
|
|
user_performing_action: DATestUser | None = None,
|
|
) -> DATestChatSession:
|
|
chat_session_creation_req = ChatSessionCreationRequest(
|
|
persona_id=persona_id, description=description
|
|
)
|
|
response = requests.post(
|
|
f"{API_SERVER_URL}/chat/create-chat-session",
|
|
json=chat_session_creation_req.model_dump(),
|
|
headers=user_performing_action.headers
|
|
if user_performing_action
|
|
else GENERAL_HEADERS,
|
|
)
|
|
response.raise_for_status()
|
|
chat_session_id = response.json()["chat_session_id"]
|
|
return DATestChatSession(
|
|
id=chat_session_id, persona_id=persona_id, description=description
|
|
)
|
|
|
|
@staticmethod
|
|
def send_message(
|
|
chat_session_id: int,
|
|
message: str,
|
|
parent_message_id: int | None = None,
|
|
user_performing_action: DATestUser | None = None,
|
|
file_descriptors: list[FileDescriptor] = [],
|
|
prompt_id: int | None = None,
|
|
search_doc_ids: list[int] | None = None,
|
|
retrieval_options: RetrievalDetails | None = None,
|
|
query_override: str | None = None,
|
|
regenerate: bool | None = None,
|
|
llm_override: LLMOverride | None = None,
|
|
prompt_override: PromptOverride | None = None,
|
|
alternate_assistant_id: int | None = None,
|
|
use_existing_user_message: bool = False,
|
|
) -> StreamedResponse:
|
|
chat_message_req = CreateChatMessageRequest(
|
|
chat_session_id=chat_session_id,
|
|
parent_message_id=parent_message_id,
|
|
message=message,
|
|
file_descriptors=file_descriptors or [],
|
|
prompt_id=prompt_id,
|
|
search_doc_ids=search_doc_ids or [],
|
|
retrieval_options=retrieval_options,
|
|
query_override=query_override,
|
|
regenerate=regenerate,
|
|
llm_override=llm_override,
|
|
prompt_override=prompt_override,
|
|
alternate_assistant_id=alternate_assistant_id,
|
|
use_existing_user_message=use_existing_user_message,
|
|
)
|
|
|
|
response = requests.post(
|
|
f"{API_SERVER_URL}/chat/send-message",
|
|
json=chat_message_req.model_dump(),
|
|
headers=user_performing_action.headers
|
|
if user_performing_action
|
|
else GENERAL_HEADERS,
|
|
stream=True,
|
|
)
|
|
|
|
return ChatSessionManager.analyze_response(response)
|
|
|
|
@staticmethod
|
|
def get_answer_with_quote(
|
|
persona_id: int,
|
|
message: str,
|
|
user_performing_action: DATestUser | None = None,
|
|
) -> StreamedResponse:
|
|
direct_qa_request = DirectQARequest(
|
|
messages=[ThreadMessage(message=message)],
|
|
prompt_id=None,
|
|
persona_id=persona_id,
|
|
)
|
|
|
|
response = requests.post(
|
|
f"{API_SERVER_URL}/query/stream-answer-with-quote",
|
|
json=direct_qa_request.model_dump(),
|
|
headers=user_performing_action.headers
|
|
if user_performing_action
|
|
else GENERAL_HEADERS,
|
|
stream=True,
|
|
)
|
|
response.raise_for_status()
|
|
|
|
return ChatSessionManager.analyze_response(response)
|
|
|
|
@staticmethod
|
|
def analyze_response(response: Response) -> StreamedResponse:
|
|
response_data = [
|
|
json.loads(line.decode("utf-8")) for line in response.iter_lines() if line
|
|
]
|
|
|
|
analyzed = StreamedResponse()
|
|
|
|
for data in response_data:
|
|
if "rephrased_query" in data:
|
|
analyzed.rephrased_query = data["rephrased_query"]
|
|
elif "tool_name" in data:
|
|
analyzed.tool_name = data["tool_name"]
|
|
analyzed.tool_result = (
|
|
data.get("tool_result")
|
|
if analyzed.tool_name == "run_search"
|
|
else None
|
|
)
|
|
elif "relevance_summaries" in data:
|
|
analyzed.relevance_summaries = data["relevance_summaries"]
|
|
elif "answer_piece" in data and data["answer_piece"]:
|
|
analyzed.full_message += data["answer_piece"]
|
|
|
|
return analyzed
|
|
|
|
@staticmethod
|
|
def get_chat_history(
|
|
chat_session: DATestChatSession,
|
|
user_performing_action: DATestUser | None = None,
|
|
) -> list[DATestChatMessage]:
|
|
response = requests.get(
|
|
f"{API_SERVER_URL}/chat/history/{chat_session.id}",
|
|
headers=user_performing_action.headers
|
|
if user_performing_action
|
|
else GENERAL_HEADERS,
|
|
)
|
|
response.raise_for_status()
|
|
|
|
return [
|
|
DATestChatMessage(
|
|
id=msg["id"],
|
|
chat_session_id=chat_session.id,
|
|
parent_message_id=msg.get("parent_message_id"),
|
|
message=msg["message"],
|
|
response=msg.get("response", ""),
|
|
)
|
|
for msg in response.json()
|
|
]
|