import re from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from sqlalchemy.orm import Session from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest from ee.onyx.server.query_and_chat.models import ( BasicCreateChatMessageWithHistoryRequest, ) from ee.onyx.server.query_and_chat.models import ChatBasicResponse from ee.onyx.server.query_and_chat.models import SimpleDoc from onyx.auth.users import current_user from onyx.chat.chat_utils import combine_message_thread from onyx.chat.chat_utils import create_chat_chain from onyx.chat.models import AllCitations from onyx.chat.models import FinalUsedContextDocsResponse from onyx.chat.models import LlmDoc from onyx.chat.models import LLMRelevanceFilterResponse from onyx.chat.models import OnyxAnswerPiece from onyx.chat.models import QADocsResponse from onyx.chat.models import StreamingError from onyx.chat.process_message import ChatPacketStream from onyx.chat.process_message import stream_chat_message_objects from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from onyx.configs.constants import MessageType from onyx.context.search.models import OptionalSearchSetting from onyx.context.search.models import RetrievalDetails from onyx.context.search.models import SavedSearchDoc from onyx.db.chat import create_chat_session from onyx.db.chat import create_new_chat_message from onyx.db.chat import get_or_create_root_message from onyx.db.engine import get_session from onyx.db.models import User from onyx.llm.factory import get_llms_for_persona from onyx.llm.utils import get_max_input_tokens from onyx.natural_language_processing.utils import get_tokenizer from onyx.secondary_llm_flows.query_expansion import thread_based_query_rephrase from onyx.server.query_and_chat.models import ChatMessageDetail from onyx.server.query_and_chat.models import CreateChatMessageRequest from onyx.utils.logger import setup_logger logger = setup_logger() router = APIRouter(prefix="/chat") def _translate_doc_response_to_simple_doc( doc_response: QADocsResponse, ) -> list[SimpleDoc]: return [ SimpleDoc( id=doc.document_id, semantic_identifier=doc.semantic_identifier, link=doc.link, blurb=doc.blurb, match_highlights=[ highlight for highlight in doc.match_highlights if highlight ], source_type=doc.source_type, metadata=doc.metadata, ) for doc in doc_response.top_documents ] def _get_final_context_doc_indices( final_context_docs: list[LlmDoc] | None, top_docs: list[SavedSearchDoc] | None, ) -> list[int] | None: """ this function returns a list of indices of the simple search docs that were actually fed to the LLM. """ if final_context_docs is None or top_docs is None: return None final_context_doc_ids = {doc.document_id for doc in final_context_docs} return [ i for i, doc in enumerate(top_docs) if doc.document_id in final_context_doc_ids ] def _convert_packet_stream_to_response( packets: ChatPacketStream, ) -> ChatBasicResponse: response = ChatBasicResponse() final_context_docs: list[LlmDoc] = [] answer = "" for packet in packets: if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece: answer += packet.answer_piece elif isinstance(packet, QADocsResponse): response.top_documents = packet.top_documents # TODO: deprecate `simple_search_docs` response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) elif isinstance(packet, StreamingError): response.error_msg = packet.error elif isinstance(packet, ChatMessageDetail): response.message_id = packet.message_id elif isinstance(packet, LLMRelevanceFilterResponse): response.llm_selected_doc_indices = packet.llm_selected_doc_indices # TODO: deprecate `llm_chunks_indices` response.llm_chunks_indices = packet.llm_selected_doc_indices elif isinstance(packet, FinalUsedContextDocsResponse): final_context_docs = packet.final_context_docs elif isinstance(packet, AllCitations): response.cited_documents = { citation.citation_num: citation.document_id for citation in packet.citations } response.final_context_doc_indices = _get_final_context_doc_indices( final_context_docs, response.top_documents ) response.answer = answer if answer: response.answer_citationless = remove_answer_citations(answer) return response def remove_answer_citations(answer: str) -> str: pattern = r"\s*\[\[\d+\]\]\(http[s]?://[^\s]+\)" return re.sub(pattern, "", answer) @router.post("/send-message-simple-api") def handle_simplified_chat_message( chat_message_req: BasicCreateChatMessageRequest, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> ChatBasicResponse: """This is a Non-Streaming version that only gives back a minimal set of information""" logger.notice(f"Received new simple api chat message: {chat_message_req.message}") if not chat_message_req.message: raise HTTPException(status_code=400, detail="Empty chat message is invalid") try: parent_message, _ = create_chat_chain( chat_session_id=chat_message_req.chat_session_id, db_session=db_session ) except Exception: parent_message = get_or_create_root_message( chat_session_id=chat_message_req.chat_session_id, db_session=db_session ) if ( chat_message_req.retrieval_options is None and chat_message_req.search_doc_ids is None ): retrieval_options: RetrievalDetails | None = RetrievalDetails( run_search=OptionalSearchSetting.ALWAYS, real_time=False, ) else: retrieval_options = chat_message_req.retrieval_options full_chat_msg_info = CreateChatMessageRequest( chat_session_id=chat_message_req.chat_session_id, parent_message_id=parent_message.id, message=chat_message_req.message, file_descriptors=[], prompt_id=None, search_doc_ids=chat_message_req.search_doc_ids, retrieval_options=retrieval_options, # Simple API does not support reranking, hide complexity from user rerank_settings=None, query_override=chat_message_req.query_override, # Currently only applies to search flow not chat chunks_above=0, chunks_below=0, full_doc=chat_message_req.full_doc, structured_response_format=chat_message_req.structured_response_format, ) packets = stream_chat_message_objects( new_msg_req=full_chat_msg_info, user=user, db_session=db_session, enforce_chat_session_id_for_search_docs=False, ) return _convert_packet_stream_to_response(packets) @router.post("/send-message-simple-with-history") def handle_send_message_simple_with_history( req: BasicCreateChatMessageWithHistoryRequest, user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> ChatBasicResponse: """This is a Non-Streaming version that only gives back a minimal set of information. takes in chat history maintained by the caller and does query rephrasing similar to answer-with-quote""" if len(req.messages) == 0: raise HTTPException(status_code=400, detail="Messages cannot be zero length") # This is a sanity check to make sure the chat history is valid # It must start with a user message and alternate beteen user and assistant expected_role = MessageType.USER for msg in req.messages: if not msg.message: raise HTTPException( status_code=400, detail="One or more chat messages were empty" ) if msg.role != expected_role: raise HTTPException( status_code=400, detail="Message roles must start and end with MessageType.USER and alternate in-between.", ) if expected_role == MessageType.USER: expected_role = MessageType.ASSISTANT else: expected_role = MessageType.USER query = req.messages[-1].message msg_history = req.messages[:-1] logger.notice(f"Received new simple with history chat message: {query}") user_id = user.id if user is not None else None chat_session = create_chat_session( db_session=db_session, description="handle_send_message_simple_with_history", user_id=user_id, persona_id=req.persona_id, ) llm, _ = get_llms_for_persona(persona=chat_session.persona) llm_tokenizer = get_tokenizer( model_name=llm.config.model_name, provider_type=llm.config.model_provider, ) input_tokens = get_max_input_tokens( model_name=llm.config.model_name, model_provider=llm.config.model_provider ) max_history_tokens = int(input_tokens * CHAT_TARGET_CHUNK_PERCENTAGE) # Every chat Session begins with an empty root message root_message = get_or_create_root_message( chat_session_id=chat_session.id, db_session=db_session ) chat_message = root_message for msg in msg_history: chat_message = create_new_chat_message( chat_session_id=chat_session.id, parent_message=chat_message, prompt_id=req.prompt_id, message=msg.message, token_count=len(llm_tokenizer.encode(msg.message)), message_type=msg.role, db_session=db_session, commit=False, ) db_session.commit() history_str = combine_message_thread( messages=msg_history, max_tokens=max_history_tokens, llm_tokenizer=llm_tokenizer, ) rephrased_query = req.query_override or thread_based_query_rephrase( user_query=query, history_str=history_str, ) if req.retrieval_options is None and req.search_doc_ids is None: retrieval_options: RetrievalDetails | None = RetrievalDetails( run_search=OptionalSearchSetting.ALWAYS, real_time=False, ) else: retrieval_options = req.retrieval_options full_chat_msg_info = CreateChatMessageRequest( chat_session_id=chat_session.id, parent_message_id=chat_message.id, message=query, file_descriptors=[], prompt_id=req.prompt_id, search_doc_ids=req.search_doc_ids, retrieval_options=retrieval_options, # Simple API does not support reranking, hide complexity from user rerank_settings=None, query_override=rephrased_query, chunks_above=0, chunks_below=0, full_doc=req.full_doc, structured_response_format=req.structured_response_format, ) packets = stream_chat_message_objects( new_msg_req=full_chat_msg_info, user=user, db_session=db_session, enforce_chat_session_id_for_search_docs=False, ) return _convert_packet_stream_to_response(packets)