danswer/backend/danswer/chat/process_message.py
2024-07-26 16:43:15 -07:00

755 lines
31 KiB
Python

from collections.abc import Callable
from collections.abc import Iterator
from functools import partial
from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_db_search_doc_by_id
from danswer.db.chat import get_doc_query_identifiers_from_model
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
from danswer.file_store.utils import load_all_chat_files
from danswer.file_store.utils import save_files_from_urls
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
from danswer.search.utils import drop_llm_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
from danswer.tools.internet_search.internet_search_tool import (
internet_search_response_to_search_docs,
)
from danswer.tools.internet_search.internet_search_tool import InternetSearchResponse
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
def translate_citations(
citations_list: list[CitationInfo], db_docs: list[DbSearchDoc]
) -> dict[int, int]:
"""Always cites the first instance of the document_id, assumes the db_docs
are sorted in the order displayed in the UI"""
doc_id_to_saved_doc_id_map: dict[str, int] = {}
for db_doc in db_docs:
if db_doc.document_id not in doc_id_to_saved_doc_id_map:
doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id
citation_to_saved_doc_id_map: dict[int, int] = {}
for citation in citations_list:
if citation.citation_num not in citation_to_saved_doc_id_map:
citation_to_saved_doc_id_map[
citation.citation_num
] = doc_id_to_saved_doc_id_map[citation.document_id]
return citation_to_saved_doc_id_map
def _handle_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
selected_search_docs: list[DbSearchDoc] | None,
dedupe_docs: bool = False,
) -> tuple[QADocsResponse, list[DbSearchDoc], list[int] | None]:
response_sumary = cast(SearchResponseSummary, packet.response)
dropped_inds = None
if not selected_search_docs:
top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections)
deduped_docs = top_docs
if dedupe_docs:
deduped_docs, dropped_inds = dedupe_documents(top_docs)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in deduped_docs
]
else:
reference_db_search_docs = selected_search_docs
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
return (
QADocsResponse(
rephrased_query=response_sumary.rephrased_query,
top_documents=response_docs,
predicted_flow=response_sumary.predicted_flow,
predicted_search=response_sumary.predicted_search,
applied_source_filters=response_sumary.final_filters.source_type,
applied_time_cutoff=response_sumary.final_filters.time_cutoff,
recency_bias_multiplier=response_sumary.recency_bias_multiplier,
),
reference_db_search_docs,
dropped_inds,
)
def _handle_internet_search_tool_response_summary(
packet: ToolResponse,
db_session: Session,
) -> tuple[QADocsResponse, list[DbSearchDoc]]:
internet_search_response = cast(InternetSearchResponse, packet.response)
server_search_docs = internet_search_response_to_search_docs(
internet_search_response
)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=doc, db_session=db_session)
for doc in server_search_docs
]
response_docs = [
translate_db_search_doc_to_server_search_doc(db_search_doc)
for db_search_doc in reference_db_search_docs
]
return (
QADocsResponse(
rephrased_query=internet_search_response.revised_query,
top_documents=response_docs,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=SearchType.HYBRID,
applied_source_filters=[],
applied_time_cutoff=None,
recency_bias_multiplier=1.0,
),
reference_db_search_docs,
)
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
) -> ForceUseTool:
internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
)
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Currently, the internet search tool does not support query override
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override and tool_name == SearchTool._NAME
else None
)
if new_msg_req.file_descriptors:
# If user has uploaded files they're using, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
should_force_search = any(
[
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
DISABLE_LLM_CHOOSE_SEARCH,
]
)
if should_force_search:
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
ChatPacket = (
StreamingError
| QADocsResponse
| LLMRelevanceFilterResponse
| ChatMessageDetail
| DanswerAnswerPiece
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
)
ChatPacketStream = Iterator[ChatPacket]
def stream_chat_message_objects(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
# Needed to translate persona num_chunks to tokens to the LLM
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
# For flow with search, don't include as many chunks as possible since we need to leave space
# for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
# user message (e.g. this can only be used for the chat-seeding flow).
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
try:
user_id = user.id if user is not None else None
chat_session = get_chat_session_by_id(
chat_session_id=new_msg_req.chat_session_id,
user_id=user_id,
db_session=db_session,
)
message_text = new_msg_req.message
chat_session_id = new_msg_req.chat_session_id
parent_id = new_msg_req.parent_message_id
reference_doc_ids = new_msg_req.search_doc_ids
retrieval_options = new_msg_req.retrieval_options
alternate_assistant_id = new_msg_req.alternate_assistant_id
# use alternate persona if alternative assistant id is passed in
if alternate_assistant_id is not None:
persona = get_persona_by_id(
alternate_assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
else:
persona = chat_session.persona
prompt_id = new_msg_req.prompt_id
if prompt_id is None and persona.prompts:
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
if reference_doc_ids is None and retrieval_options is None:
raise RuntimeError(
"Must specify a set of documents for chat or specify search options"
)
try:
llm, fast_llm = get_llms_for_persona(
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
)
except GenAIDisabledException:
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
llm_tokenizer = get_default_llm_tokenizer()
llm_tokenizer_encode_func = cast(
Callable[[str], list[int]], llm_tokenizer.encode
)
embedding_model = get_current_db_embedding_model(db_session)
document_index = get_default_document_index(
primary_index_name=embedding_model.index_name, secondary_index_name=None
)
# Every chat Session begins with an empty root message
root_message = get_or_create_root_message(
chat_session_id=chat_session_id, db_session=db_session
)
if parent_id is not None:
parent_message = get_chat_message(
chat_message_id=parent_id,
user_id=user_id,
db_session=db_session,
)
else:
parent_message = root_message
user_message = None
if not use_existing_user_message:
# Create new message at the right place in the tree and update the parent's child pointer
# Don't commit yet until we verify the chat message chain
user_message = create_new_chat_message(
chat_session_id=chat_session_id,
parent_message=parent_message,
prompt_id=prompt_id,
message=message_text,
token_count=len(llm_tokenizer_encode_func(message_text)),
message_type=MessageType.USER,
files=None, # Need to attach later for optimization to only load files once in parallel
db_session=db_session,
commit=False,
)
# re-create linear history of messages
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if final_msg.id != user_message.id:
db_session.rollback()
raise RuntimeError(
"The new message was not on the mainline. "
"Be sure to update the chat pointers before calling this."
)
# NOTE: do not commit user message - it will be committed when the
# assistant message is successfully generated
else:
# re-create linear history of messages
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
# leads to worst search quality
if not history_msgs:
new_msg_req.query_override = (
new_msg_req.query_override or new_msg_req.message
)
# load all files needed for this chat chain in memory
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
)
latest_query_files = [
file
for file in files
if file.file_id in [f["id"] for f in new_msg_req.file_descriptors]
]
if user_message:
attach_files_to_chat_message(
chat_message=user_message,
files=[
new_file.to_file_descriptor() for new_file in latest_query_files
],
db_session=db_session,
commit=False,
)
selected_db_search_docs = None
selected_sections: list[InferenceSection] | None = None
if reference_doc_ids:
identifier_tuples = get_doc_query_identifiers_from_model(
search_doc_ids=reference_doc_ids,
chat_session=chat_session,
user_id=user_id,
db_session=db_session,
)
# Generates full documents currently
# May extend to use sections instead in the future
selected_sections = inference_sections_from_ids(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
document_pruning_config = DocumentPruningConfig(
is_manually_selected_docs=True
)
# In case the search doc is deleted, just don't include it
# though this should never happen
db_search_docs_or_none = [
get_db_search_doc_by_id(doc_id=doc_id, db_session=db_session)
for doc_id in reference_doc_ids
]
selected_db_search_docs = [
db_sd for db_sd in db_search_docs_or_none if db_sd
]
else:
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
persona.num_chunks
if persona.num_chunks is not None
else default_num_chunks
),
max_window_percentage=max_document_percentage,
use_sections=new_msg_req.chunks_above > 0
or new_msg_req.chunks_below > 0,
)
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=final_msg,
prompt_id=prompt_id,
# message=,
# rephrased_query=,
# token_count=,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
# error=,
# reference_docs=,
db_session=db_session,
commit=False,
)
if not final_msg.prompt:
raise RuntimeError("No Prompt found")
prompt_config = (
PromptConfig.from_model(
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
)
if not persona
else PromptConfig.from_model(persona.prompts[0])
)
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = llm.config
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name=openai_provider.default_model_name,
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(api_key=bing_api_key)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema(
db_tool_model.openapi_schema
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
# LLM prompt building, response capturing, etc.
answer = Answer(
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=AnswerStyleConfig(
citation_config=CitationConfig(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
),
prompt_config=prompt_config,
llm=(
llm
or get_main_llm_from_tuple(
get_llms_for_persona(
persona=persona,
llm_override=(
new_msg_req.llm_override or chat_session.llm_override
),
additional_headers=litellm_additional_headers,
)
)
),
message_history=[
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
],
tools=tools,
force_use_tool=_get_force_search_settings(new_msg_req, tools),
)
reference_db_search_docs = None
qa_docs_response = None
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
if reference_db_search_docs is not None and dropped_indices:
chunk_indices = drop_llm_indices(
llm_indices=chunk_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
relevant_chunk_indices=chunk_indices
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
)
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
except Exception as e:
logger.exception("Failed to process chat message")
# Don't leak the API key
error_msg = str(e)
if llm.config.api_key and llm.config.api_key.lower() in error_msg.lower():
error_msg = (
f"LLM failed to respond. Invalid API "
f"key error from '{llm.config.model_provider}'."
)
yield StreamingError(error=error_msg)
# Cancel the transaction so that no messages are saved
db_session.rollback()
return
# Post-LLM answer processing
try:
db_citations = None
if reference_db_search_docs:
db_citations = translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_calls=[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else [],
)
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
except Exception as e:
logger.exception(e)
# Frontend will erase whatever answer and show this instead
yield StreamingError(error="Failed to parse LLM output")
@log_generator_function_time()
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
objects = stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
)
for obj in objects:
yield get_json_line(obj.dict())