Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-22 05:50:57 +02:00)
add send-message-simple-with-history endpoint to avoid… (#2101)
* add send-message-simple-with-history endpoint to support Ramp; avoids bad JSON output in models and lets the client pass history in instead of maintaining it in our own session
* slightly better error checking
* addressing code review
* reject on any empty message
* update test naming
This commit is contained in:
parent c7e5b11c63
commit e517f47a89
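
For context, a minimal client-side sketch of calling the new endpoint. This is not part of the commit: it assumes a local deployment at http://localhost:8080 with auth handled elsewhere and a persona_id that exists in that instance; the field names follow the request model added below.

# A minimal client-side sketch, not part of this commit. Assumes a local
# deployment at http://localhost:8080, auth handled elsewhere, and a
# persona_id that exists in that instance.
import requests

payload = {
    "persona_id": 1,  # hypothetical id
    "prompt_id": None,
    "messages": [
        # Caller-maintained history; the last element is the new query.
        {"message": "What sorts of questions can you answer for me?"},
        {"message": "Questions about your indexed documents, I can answer.", "role": "assistant"},
        {"message": "What is solution validation research used for?"},
    ],
}

response = requests.post(
    "http://localhost:8080/chat/send-message-simple-with-history", json=payload
)
response.raise_for_status()
print(response.json()["answer"])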
@@ -8,18 +8,31 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.process_message import stream_chat_message_objects
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.utils import get_max_input_tokens
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.utils.logger import setup_logger
from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.danswer.server.query_and_chat.models import (
    BasicCreateChatMessageWithHistoryRequest,
)
from ee.danswer.server.query_and_chat.models import ChatBasicResponse
from ee.danswer.server.query_and_chat.models import SimpleDoc
@@ -122,3 +135,131 @@ def handle_simplified_chat_message(
        response.answer_citationless = remove_answer_citations(answer)

    return response


@router.post("/send-message-simple-with-history")
def handle_send_message_simple_with_history(
    req: BasicCreateChatMessageWithHistoryRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatBasicResponse:
    """Non-streaming endpoint that returns only a minimal set of information.
    Takes in chat history maintained by the caller and does query rephrasing
    similar to answer-with-quote."""

    if len(req.messages) == 0:
        raise HTTPException(status_code=400, detail="Messages cannot be zero length")

    # Validate the caller-supplied history: no empty messages, and roles must
    # alternate user/assistant starting from a user message.
    expected_role = MessageType.USER
    for msg in req.messages:
        if not msg.message:
            raise HTTPException(
                status_code=400, detail="One or more chat messages were empty"
            )

        if msg.role != expected_role:
            raise HTTPException(
                status_code=400,
                detail="Message roles must start and end with MessageType.USER "
                "and alternate in between.",
            )
        if expected_role == MessageType.USER:
            expected_role = MessageType.ASSISTANT
        else:
            expected_role = MessageType.USER

    # The last message is the new query; everything before it is history.
    query = req.messages[-1].message
    msg_history = req.messages[:-1]

    logger.info(f"Received new simple with history chat message: {query}")

    user_id = user.id if user is not None else None
    chat_session = create_chat_session(
        db_session=db_session,
        description="handle_send_message_simple_with_history",
        user_id=user_id,
        persona_id=req.persona_id,
        one_shot=False,
    )

    llm, _ = get_llms_for_persona(persona=chat_session.persona)

    llm_tokenizer = get_tokenizer(
        model_name=llm.config.model_name,
        provider_type=llm.config.model_provider,
    )

    input_tokens = get_max_input_tokens(
        model_name=llm.config.model_name, model_provider=llm.config.model_provider
    )
    max_history_tokens = int(input_tokens * DANSWER_BOT_TARGET_CHUNK_PERCENTAGE)

    # Every chat session begins with an empty root message
    root_message = get_or_create_root_message(
        chat_session_id=chat_session.id, db_session=db_session
    )

    # Persist the caller-supplied history as a chain of chat messages.
    chat_message = root_message
    for msg in msg_history:
        chat_message = create_new_chat_message(
            chat_session_id=chat_session.id,
            parent_message=chat_message,
            prompt_id=req.prompt_id,
            message=msg.message,
            token_count=len(llm_tokenizer.encode(msg.message)),
            message_type=msg.role,
            db_session=db_session,
            commit=False,
        )
    db_session.commit()

    history_str = combine_message_thread(
        messages=msg_history,
        max_tokens=max_history_tokens,
        llm_tokenizer=llm_tokenizer,
    )

    rephrased_query = req.query_override or thread_based_query_rephrase(
        user_query=query,
        history_str=history_str,
    )

    full_chat_msg_info = CreateChatMessageRequest(
        chat_session_id=chat_session.id,
        parent_message_id=chat_message.id,
        message=rephrased_query,
        file_descriptors=[],
        prompt_id=req.prompt_id,
        search_doc_ids=None,
        retrieval_options=req.retrieval_options,
        query_override=rephrased_query,
        chunks_above=req.chunks_above,
        chunks_below=req.chunks_below,
        full_doc=req.full_doc,
    )

    packets = stream_chat_message_objects(
        new_msg_req=full_chat_msg_info,
        user=user,
        db_session=db_session,
    )

    response = ChatBasicResponse()

    # Collapse the streamed packets into a single non-streaming response.
    answer = ""
    for packet in packets:
        if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
            answer += packet.answer_piece
        elif isinstance(packet, QADocsResponse):
            response.simple_search_docs = translate_doc_response_to_simple_doc(packet)
        elif isinstance(packet, StreamingError):
            response.error_msg = packet.error
        elif isinstance(packet, ChatMessageDetail):
            response.message_id = packet.message_id
        elif isinstance(packet, LLMRelevanceFilterResponse):
            response.llm_chunks_indices = packet.relevant_chunk_indices

    response.answer = answer
    if answer:
        response.answer_citationless = remove_answer_citations(answer)

    return response
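The validation loop above encodes the endpoint's message contract: history starts with a user message, alternates user/assistant, and the final (user) message is treated as the new query. A standalone sketch of that intended contract follows; the helper and the simplified stand-in for danswer.configs.constants.MessageType are hypothetical, not part of this commit.

# Hypothetical helper illustrating the intended contract: roles start and end
# with USER and alternate in between. The enum is a simplified stand-in for
# danswer.configs.constants.MessageType.
from enum import Enum


class MessageType(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"


def roles_are_valid(roles: list[MessageType]) -> bool:
    if not roles or roles[0] != MessageType.USER or roles[-1] != MessageType.USER:
        return False
    # With only two role values and a USER start, "adjacent roles differ"
    # is exactly user/assistant alternation.
    return all(prev != curr for prev, curr in zip(roles, roles[1:]))


assert roles_are_valid([MessageType.USER])
assert roles_are_valid([MessageType.USER, MessageType.ASSISTANT, MessageType.USER])
assert not roles_are_valid([MessageType.USER, MessageType.USER])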
@@ -1,6 +1,8 @@
from pydantic import BaseModel
from pydantic import Field

from danswer.configs.constants import DocumentSource
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext

@@ -45,6 +47,16 @@ class BasicCreateChatMessageRequest(ChunkContext):
    search_doc_ids: list[int] | None = None


class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
    # Last element is the new query. All previous elements are historical context.
    messages: list[ThreadMessage]
    prompt_id: int | None
    persona_id: int
    retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
    query_override: str | None = None
    skip_rerank: bool | None = None


class SimpleDoc(BaseModel):
    id: str
    semantic_identifier: str

@@ -61,3 +73,4 @@ class ChatBasicResponse(BaseModel):
    simple_search_docs: list[SimpleDoc] | None = None
    error_msg: str | None = None
    message_id: int | None = None
    llm_chunks_indices: list[int] | None = None
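A short sketch of constructing the new request model directly. This is illustrative only: it assumes a Danswer dev environment where these imports resolve, and that ThreadMessage's role defaults to a user message, which is what the tests below appear to rely on when they omit "role".

# A sketch, not part of this commit; assumes a dev environment where these
# imports resolve. ThreadMessage's role appears to default to USER, matching
# the tests below, which omit "role" on user messages.
from danswer.one_shot_answer.models import ThreadMessage
from ee.danswer.server.query_and_chat.models import (
    BasicCreateChatMessageWithHistoryRequest,
)

req = BasicCreateChatMessageWithHistoryRequest(
    messages=[ThreadMessage(message="What is solution validation research used for?")],
    prompt_id=None,
    persona_id=1,  # hypothetical id
)
# retrieval_options is filled in by default_factory=RetrievalDetails, so a
# minimal request only needs messages, prompt_id, and persona_id.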
backend/tests/api/test_api.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import os
from collections.abc import Generator
from typing import Any

import pytest
from fastapi.testclient import TestClient

from danswer.main import fetch_versioned_implementation
from danswer.utils.logger import setup_logger

logger = setup_logger()


@pytest.fixture(scope="function")
def client() -> Generator[TestClient, Any, None]:
    # Set environment variables
    os.environ["ENABLE_PAID_ENTERPRISE_EDITION_FEATURES"] = "True"

    # Initialize TestClient with the FastAPI app
    app = fetch_versioned_implementation(
        module="danswer.main", attribute="get_application"
    )()
    client = TestClient(app)
    yield client


@pytest.mark.skip(
    reason="enable when we have a testing environment with preloaded data"
)
def test_handle_simplified_chat_message(client: TestClient) -> None:
    req: dict[str, Any] = {}

    req["persona_id"] = 0
    req["description"] = "pytest"
    response = client.post("/chat/create-chat-session", json=req)
    chat_session_id = response.json()["chat_session_id"]

    req = {}
    req["chat_session_id"] = chat_session_id
    req["message"] = "hello"

    response = client.post("/chat/send-message-simple-api", json=req)
    assert response.status_code == 200


@pytest.mark.skip(
    reason="enable when we have a testing environment with preloaded data"
)
def test_handle_send_message_simple_with_history(client: TestClient) -> None:
    req: dict[str, Any] = {}
    messages = []
    messages.append({"message": "What sorts of questions can you answer for me?"})
    # messages.append({"message":
    #     "I'd be happy to assist you with a wide range of questions related to Ramp's expense management platform. "
    #     "I can help with topics such as:\n\n"
    #     "1. Setting up and managing your Ramp account\n"
    #     "2. Using Ramp cards and making purchases\n"
    #     "3. Submitting and reviewing expenses\n"
    #     "4. Understanding Ramp's features and benefits\n"
    #     "5. Navigating the Ramp dashboard and mobile app\n"
    #     "6. Managing team spending and budgets\n"
    #     "7. Integrating Ramp with accounting software\n"
    #     "8. Troubleshooting common issues\n\n"
    #     "Feel free to ask any specific questions you have about using Ramp, "
    #     "and I'll do my best to provide clear and helpful answers. "
    #     "Is there a particular area you'd like to know more about?",
    #     "role": "assistant"})
    # req["prompt_id"] = 9
    # req["persona_id"] = 6

    # Yoda
    req["persona_id"] = 1
    req["prompt_id"] = 4
    messages.append(
        {
            "message": "Answer questions for you, I can. "
            "About many topics, knowledge I have. "
            "But specific to documents provided, limited my responses are. "
            "Ask you may about:\n\n"
            "- User interviews and building trust with participants\n"
            "- Designing effective surveys and survey questions \n"
            "- Product analysis approaches\n"
            "- Recruiting participants for research\n"
            "- Discussion guides for user interviews\n"
            "- Types of survey questions\n\n"
            "More there may be, but focus on these areas, the given context does. "
            "Specific questions you have, ask you should. Guide you I will, as best I can.",
            "role": "assistant",
        }
    )
    # messages.append({"message": "Where can I pilot a survey?"})

    # messages.append({"message": "How many data points should I collect to validate my solution?"})
    messages.append({"message": "What is solution validation research used for?"})

    req["messages"] = messages

    response = client.post("/chat/send-message-simple-with-history", json=req)
    assert response.status_code == 200

    resp_json = response.json()

    # persona must have LLM relevance enabled for this to pass
    assert len(resp_json["llm_chunks_indices"]) > 0
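To exercise these tests once a preloaded test environment exists (and the skip markers are removed), pytest can be invoked programmatically as well as from the CLI; a sketch, assuming it runs from the repository root:

# A sketch for running the new tests programmatically; assumes execution from
# the repository root and that the @pytest.mark.skip markers were removed.
import pytest

# pytest.main accepts the same arguments as the pytest command line.
pytest.main(["backend/tests/api/test_api.py", "-v"])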