import re

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AllCitations
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.constants import MessageType
from onyx.context.search.models import OptionalSearchSetting
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.utils import get_max_input_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.utils.logger import setup_logger


logger = setup_logger()

router = APIRouter(prefix="/chat")


def _translate_doc_response_to_simple_doc(
doc_response: QADocsResponse,
) -> list[SimpleDoc]:
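    """Map the retrieved top documents onto the legacy SimpleDoc shape."""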
return [
SimpleDoc(
id=doc.document_id,
semantic_identifier=doc.semantic_identifier,
link=doc.link,
blurb=doc.blurb,
match_highlights=[
highlight for highlight in doc.match_highlights if highlight
],
source_type=doc.source_type,
metadata=doc.metadata,
)
for doc in doc_response.top_documents
    ]


def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
top_docs: list[SavedSearchDoc] | None,
) -> list[int] | None:
"""
this function returns a list of indices of the simple search docs
that were actually fed to the LLM.
"""
if final_context_docs is None or top_docs is None:
return None
final_context_doc_ids = {doc.document_id for doc in final_context_docs}
return [
i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids
    ]


def _convert_packet_stream_to_response(
packets: ChatPacketStream,
) -> ChatBasicResponse:
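    """Collapse the streamed chat packets into a single non-streaming response."""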
response = ChatBasicResponse()
final_context_docs: list[LlmDoc] = []
answer = ""
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
elif isinstance(packet, QADocsResponse):
response.top_documents = packet.top_documents
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
response.message_id = packet.message_id
elif isinstance(packet, LLMRelevanceFilterResponse):
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
# TODO: deprecate `llm_chunks_indices`
response.llm_chunks_indices = packet.llm_selected_doc_indices
elif isinstance(packet, FinalUsedContextDocsResponse):
final_context_docs = packet.final_context_docs
elif isinstance(packet, AllCitations):
response.cited_documents = {
citation.citation_num: citation.document_id
for citation in packet.citations
}
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)
    return response


def remove_answer_citations(answer: str) -> str:
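    # Matches inline citation markers of the form `[[1]](https://...)`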
pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)"
    return re.sub(pattern, "", answer)


@router.post("/send-message-simple-api")
def handle_simplified_chat_message(
chat_message_req: BasicCreateChatMessageRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatBasicResponse:
"""This is a Non-Streaming version that only gives back a minimal set of information"""
logger.notice(f"Received new simple api chat message: {chat_message_req.message}")
if not chat_message_req.message:
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
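    # Continue from the session's latest message if a chain exists;
    # otherwise fall back to the session's root message.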
try:
parent_message, _ = create_chat_chain(
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
)
except Exception:
parent_message = get_or_create_root_message(
chat_session_id=chat_message_req.chat_session_id, db_session=db_session
)
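    # Default to always running search when the caller specifies neither
    # retrieval options nor explicit search doc ids.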
if (
chat_message_req.retrieval_options is None
and chat_message_req.search_doc_ids is None
):
retrieval_options: RetrievalDetails | None = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
)
else:
retrieval_options = chat_message_req.retrieval_options
full_chat_msg_info = CreateChatMessageRequest(
chat_session_id=chat_message_req.chat_session_id,
parent_message_id=parent_message.id,
message=chat_message_req.message,
file_descriptors=[],
prompt_id=None,
search_doc_ids=chat_message_req.search_doc_ids,
retrieval_options=retrieval_options,
# Simple API does not support reranking, hide complexity from user
rerank_settings=None,
query_override=chat_message_req.query_override,
# Currently only applies to search flow not chat
chunks_above=0,
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
)
packets = stream_chat_message_objects(
new_msg_req=full_chat_msg_info,
user=user,
db_session=db_session,
enforce_chat_session_id_for_search_docs=False,
)
    return _convert_packet_stream_to_response(packets)


@router.post("/send-message-simple-with-history")
def handle_send_message_simple_with_history(
req: BasicCreateChatMessageWithHistoryRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatBasicResponse:
"""This is a Non-Streaming version that only gives back a minimal set of information.
takes in chat history maintained by the caller
and does query rephrasing similar to answer-with-quote"""
if len(req.messages) == 0:
raise HTTPException(status_code=400, detail="Messages cannot be zero length")
    # This is a sanity check to make sure the chat history is valid
    # It must start with a user message and alternate between user and assistant
expected_role = MessageType.USER
for msg in req.messages:
if not msg.message:
raise HTTPException(
status_code=400, detail="One or more chat messages were empty"
)
if msg.role != expected_role:
raise HTTPException(
status_code=400,
detail="Message roles must start and end with MessageType.USER and alternate in-between.",
)
if expected_role == MessageType.USER:
expected_role = MessageType.ASSISTANT
else:
expected_role = MessageType.USER
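    # The last message is the new query; everything before it is history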
query = req.messages[-1].message
msg_history = req.messages[:-1]
logger.notice(f"Received new simple with history chat message: {query}")
user_id = user.id if user is not None else None
chat_session = create_chat_session(
db_session=db_session,
description="handle_send_message_simple_with_history",
user_id=user_id,
persona_id=req.persona_id,
)
llm, _ = get_llms_for_persona(persona=chat_session.persona)
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name, model_provider=llm.config.model_provider
)
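    # Cap the history at a fraction of the model's max input tokens so that
    # retrieved context still fits in the prompt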
max_history_tokens = int(input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE)
# Every chat Session begins with an empty root message
root_message = get_or_create_root_message(
chat_session_id=chat_session.id, db_session=db_session
)
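    # Replay the caller-provided history into the new session as a linked
    # chain of messages, committing once at the end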
chat_message = root_message
for msg in msg_history:
chat_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=chat_message,
prompt_id=req.prompt_id,
message=msg.message,
token_count=len(llm_tokenizer.encode(msg.message)),
message_type=msg.role,
db_session=db_session,
commit=False,
)
db_session.commit()
history_str = combine_message_thread(
messages=msg_history,
max_tokens=max_history_tokens,
llm_tokenizer=llm_tokenizer,
)
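    # Rephrase the latest query using the conversation history unless the
    # caller explicitly overrides it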
rephrased_query = req.query_override or thread_based_query_rephrase(
user_query=query,
history_str=history_str,
)
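    # Same search-by-default behavior as the simple API endpoint above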
if req.retrieval_options is None and req.search_doc_ids is None:
retrieval_options: RetrievalDetails | None = RetrievalDetails(
run_search=OptionalSearchSetting.ALWAYS,
real_time=False,
)
else:
retrieval_options = req.retrieval_options
full_chat_msg_info = CreateChatMessageRequest(
chat_session_id=chat_session.id,
parent_message_id=chat_message.id,
message=query,
file_descriptors=[],
prompt_id=req.prompt_id,
search_doc_ids=req.search_doc_ids,
retrieval_options=retrieval_options,
# Simple API does not support reranking, hide complexity from user
rerank_settings=None,
query_override=rephrased_query,
chunks_above=0,
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
)
packets = stream_chat_message_objects(
new_msg_req=full_chat_msg_info,
user=user,
db_session=db_session,
enforce_chat_session_id_for_search_docs=False,
)
return _convert_packet_stream_to_response(packets)