add send-message-simple-with-history endpoint to avoid bad json output in models (#2101)

* add send-message-simple-with-history endpoint to support ramp. avoids bad json output in models and allows client to pass history in instead of maintaining it in our own session

* slightly better error checking

* addressing code review

* reject on any empty message

* update test naming
This commit is contained in:
rkuo-danswer 2024-08-11 20:33:52 -07:00 committed by GitHub
parent c7e5b11c63
commit e517f47a89
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 258 additions and 0 deletions

View File

@ -8,18 +8,31 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_user from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError from danswer.chat.models import StreamingError
from danswer.chat.process_message import stream_chat_message_objects from danswer.chat.process_message import stream_chat_message_objects
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_or_create_root_message from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session from danswer.db.engine import get_session
from danswer.db.models import User from danswer.db.models import User
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.utils import get_max_input_tokens
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.one_shot_answer.qa_utils import combine_message_thread
from danswer.search.models import OptionalSearchSetting from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails from danswer.search.models import RetrievalDetails
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest from ee.danswer.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.danswer.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.danswer.server.query_and_chat.models import ChatBasicResponse from ee.danswer.server.query_and_chat.models import ChatBasicResponse
from ee.danswer.server.query_and_chat.models import SimpleDoc from ee.danswer.server.query_and_chat.models import SimpleDoc
@ -122,3 +135,131 @@ def handle_simplified_chat_message(
response.answer_citationless = remove_answer_citations(answer) response.answer_citationless = remove_answer_citations(answer)
return response return response
@router.post("/send-message-simple-with-history")
def handle_send_message_simple_with_history(
    req: BasicCreateChatMessageWithHistoryRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatBasicResponse:
    """Non-streaming chat endpoint that accepts caller-maintained history.

    The last element of ``req.messages`` is the new user query; all earlier
    elements are historical context. The query is rephrased based on that
    history (similar to answer-with-quote) unless ``req.query_override`` is
    provided.

    Raises:
        HTTPException: 400 if the message list is empty, any message is
            empty, or the roles do not start and end with USER while
            alternating USER/ASSISTANT in between.
    """
    if len(req.messages) == 0:
        raise HTTPException(status_code=400, detail="Messages cannot be zero length")

    # Validate that every message is non-empty and that roles alternate
    # USER, ASSISTANT, USER, ... starting with USER.
    expected_role = MessageType.USER
    for msg in req.messages:
        if not msg.message:
            raise HTTPException(
                status_code=400, detail="One or more chat messages were empty"
            )

        if msg.role != expected_role:
            raise HTTPException(
                status_code=400,
                detail="Message roles must start and end with MessageType.USER and alternate in-between.",
            )
        if expected_role == MessageType.USER:
            expected_role = MessageType.ASSISTANT
        else:
            expected_role = MessageType.USER

    # Bug fix: the alternation loop alone lets a thread END with an
    # ASSISTANT message (even-length thread), contradicting the error
    # message above — and the last message is consumed as the user query
    # below, so it must be a USER message.
    if req.messages[-1].role != MessageType.USER:
        raise HTTPException(
            status_code=400,
            detail="Message roles must start and end with MessageType.USER and alternate in-between.",
        )

    query = req.messages[-1].message
    msg_history = req.messages[:-1]

    logger.info(f"Received new simple with history chat message: {query}")

    user_id = user.id if user is not None else None
    chat_session = create_chat_session(
        db_session=db_session,
        description="handle_send_message_simple_with_history",
        user_id=user_id,
        persona_id=req.persona_id,
        one_shot=False,
    )

    llm, _ = get_llms_for_persona(persona=chat_session.persona)

    llm_tokenizer = get_tokenizer(
        model_name=llm.config.model_name,
        provider_type=llm.config.model_provider,
    )

    # Cap how much of the model's context window the history may consume.
    input_tokens = get_max_input_tokens(
        model_name=llm.config.model_name, model_provider=llm.config.model_provider
    )
    max_history_tokens = int(input_tokens * DANSWER_BOT_TARGET_CHUNK_PERCENTAGE)

    # Every chat Session begins with an empty root message
    root_message = get_or_create_root_message(
        chat_session_id=chat_session.id, db_session=db_session
    )

    # Persist the caller-provided history as a chain of messages rooted at
    # the session's root; commit once after the loop to avoid per-message
    # round trips.
    chat_message = root_message
    for msg in msg_history:
        chat_message = create_new_chat_message(
            chat_session_id=chat_session.id,
            parent_message=chat_message,
            prompt_id=req.prompt_id,
            message=msg.message,
            token_count=len(llm_tokenizer.encode(msg.message)),
            message_type=msg.role,
            db_session=db_session,
            commit=False,
        )
    db_session.commit()

    history_str = combine_message_thread(
        messages=msg_history,
        max_tokens=max_history_tokens,
        llm_tokenizer=llm_tokenizer,
    )

    rephrased_query = req.query_override or thread_based_query_rephrase(
        user_query=query,
        history_str=history_str,
    )

    full_chat_msg_info = CreateChatMessageRequest(
        chat_session_id=chat_session.id,
        parent_message_id=chat_message.id,
        message=rephrased_query,
        file_descriptors=[],
        prompt_id=req.prompt_id,
        search_doc_ids=None,
        retrieval_options=req.retrieval_options,
        # The query was already rephrased above, so skip rephrasing downstream.
        query_override=rephrased_query,
        chunks_above=req.chunks_above,
        chunks_below=req.chunks_below,
        full_doc=req.full_doc,
    )

    packets = stream_chat_message_objects(
        new_msg_req=full_chat_msg_info,
        user=user,
        db_session=db_session,
    )

    # Collapse the streamed packets into a single non-streaming response.
    response = ChatBasicResponse()

    answer = ""
    for packet in packets:
        if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
            answer += packet.answer_piece
        elif isinstance(packet, QADocsResponse):
            response.simple_search_docs = translate_doc_response_to_simple_doc(packet)
        elif isinstance(packet, StreamingError):
            response.error_msg = packet.error
        elif isinstance(packet, ChatMessageDetail):
            response.message_id = packet.message_id
        elif isinstance(packet, LLMRelevanceFilterResponse):
            response.llm_chunks_indices = packet.relevant_chunk_indices

    response.answer = answer
    if answer:
        response.answer_citationless = remove_answer_citations(answer)

    return response

View File

@ -1,6 +1,8 @@
from pydantic import BaseModel from pydantic import BaseModel
from pydantic import Field
from danswer.configs.constants import DocumentSource from danswer.configs.constants import DocumentSource
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.enums import LLMEvaluationType from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import SearchType from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext from danswer.search.models import ChunkContext
@ -45,6 +47,16 @@ class BasicCreateChatMessageRequest(ChunkContext):
search_doc_ids: list[int] | None = None search_doc_ids: list[int] | None = None
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
    """Request body for the send-message-simple-with-history endpoint:
    the caller supplies the full chat history instead of relying on a
    server-maintained session."""

    # Last element is the new query. All previous elements are historical context
    messages: list[ThreadMessage]
    # Prompt to use for the new message; None falls through to downstream defaults.
    prompt_id: int | None
    # Persona (assistant) the new chat session is created with.
    persona_id: int
    # Retrieval/search configuration; defaults to a fresh RetrievalDetails.
    retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
    # If set, used verbatim as the search query instead of rephrasing from history.
    query_override: str | None = None
    # NOTE(review): declared but not read by the visible endpoint code — verify usage.
    skip_rerank: bool | None = None
class SimpleDoc(BaseModel): class SimpleDoc(BaseModel):
id: str id: str
semantic_identifier: str semantic_identifier: str
@ -61,3 +73,4 @@ class ChatBasicResponse(BaseModel):
simple_search_docs: list[SimpleDoc] | None = None simple_search_docs: list[SimpleDoc] | None = None
error_msg: str | None = None error_msg: str | None = None
message_id: int | None = None message_id: int | None = None
llm_chunks_indices: list[int] | None = None

View File

@ -0,0 +1,104 @@
import os
from collections.abc import Generator
from typing import Any
import pytest
from fastapi.testclient import TestClient
from danswer.main import fetch_versioned_implementation
from danswer.utils.logger import setup_logger
logger = setup_logger()
@pytest.fixture(scope="function")
def client() -> Generator[TestClient, Any, None]:
    """Yield a TestClient for the versioned FastAPI app with EE features enabled."""
    # The EE flag must be in the environment before the app is built.
    os.environ["ENABLE_PAID_ENTERPRISE_EDITION_FEATURES"] = "True"

    # Resolve the (possibly enterprise-overridden) application factory and
    # instantiate the app.
    get_application = fetch_versioned_implementation(
        module="danswer.main", attribute="get_application"
    )
    yield TestClient(get_application())
@pytest.mark.skip(
    reason="enable when we have a testing environment with preloaded data"
)
def test_handle_simplified_chat_message(client: TestClient) -> None:
    """Create a chat session, then send one message via the simple API."""
    session_req: dict[str, Any] = {"persona_id": 0, "description": "pytest"}
    response = client.post("/chat/create-chat-session", json=session_req)
    chat_session_id = response.json()["chat_session_id"]

    message_req: dict[str, Any] = {
        "chat_session_id": chat_session_id,
        "message": "hello",
    }
    response = client.post("/chat/send-message-simple-api", json=message_req)
    assert response.status_code == 200
@pytest.mark.skip(
    reason="enable when we have a testing environment with preloaded data"
)
def test_handle_send_message_simple_with_history(client: TestClient) -> None:
    """Send a USER/ASSISTANT/USER thread to the with-history endpoint and
    check the response, including the LLM relevance chunk indices."""
    # Canned assistant turn; persona_id=1 / prompt_id=4 correspond to the
    # Yoda persona in the preloaded test data.
    assistant_reply = (
        "Answer questions for you, I can. "
        "About many topics, knowledge I have. "
        "But specific to documents provided, limited my responses are. "
        "Ask you may about:\n\n"
        "- User interviews and building trust with participants\n"
        "- Designing effective surveys and survey questions \n"
        "- Product analysis approaches\n"
        "- Recruiting participants for research\n"
        "- Discussion guides for user interviews\n"
        "- Types of survey questions\n\n"
        "More there may be, but focus on these areas, the given context does. "
        "Specific questions you have, ask you should. Guide you I will, as best I can."
    )

    req: dict[str, Any] = {
        "persona_id": 1,
        "prompt_id": 4,
        "messages": [
            {"message": "What sorts of questions can you answer for me?"},
            {"message": assistant_reply, "role": "assistant"},
            {"message": "What is solution validation research used for?"},
        ],
    }

    response = client.post("/chat/send-message-simple-with-history", json=req)
    assert response.status_code == 200

    resp_json = response.json()

    # persona must have LLM relevance enabled for this to pass
    assert len(resp_json["llm_chunks_indices"]) > 0