from collections import defaultdict
from collections.abc import Callable
from uuid import uuid4

from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from sqlalchemy.orm import Session

from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import (
    CitationResponseHandler,
)
from onyx.chat.stream_processing.utils import (
    map_document_id_order,
)
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.configs.constants import BASIC_KEY
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger

logger = setup_logger()


class Answer:
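    """Orchestrates answering a single user question.

    Builds the prompt, optionally runs tools (forced or chosen by the LLM),
    streams the LLM / agent-search response, and exposes the processed results
    via `processed_streamed_output`, `llm_answer`, and `citations`.

    Rough usage sketch (constructor arguments elided; the exact configs depend
    on the caller):

        answer = Answer(question=..., answer_style_config=..., llm=..., ...)
        for packet in answer.processed_streamed_output:
            ...  # stream packets back to the client as they arrive
        full_answer = answer.llm_answer
        basic_citations = answer.citations
    """
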
    def __init__(
        self,
        question: str,
        answer_style_config: AnswerStyleConfig,
        llm: LLM,
        prompt_config: PromptConfig,
        force_use_tool: ForceUseTool,
        pro_search_config: AgentSearchConfig,
        # must be the same length as `docs`. If None, all docs are considered "relevant"
        message_history: list[PreviousMessage] | None = None,
        single_message_history: str | None = None,
        # newly passed in files to include as part of this question
        # TODO THIS NEEDS TO BE HANDLED
        latest_query_files: list[InMemoryChatFile] | None = None,
        tools: list[Tool] | None = None,
        # NOTE: for native tool-calling, this is only supported by OpenAI at the moment,
        # but that is all we support anyway
        # if set to True, then never use the LLM's provided tool-calling functionality
        skip_explicit_tool_calling: bool = False,
        # Returns the full document sections text from the search tool
        return_contexts: bool = False,
        skip_gen_ai_answer_generation: bool = False,
        is_connected: Callable[[], bool] | None = None,
        fast_llm: LLM | None = None,
        db_session: Session | None = None,
    ) -> None:
        if single_message_history and message_history:
            raise ValueError(
                "Cannot provide both `message_history` and `single_message_history`"
            )

        self.question = question
        self.is_connected: Callable[[], bool] | None = is_connected

        self.latest_query_files = latest_query_files or []

        self.tools = tools or []
        self.force_use_tool = force_use_tool

        self.message_history = message_history or []
        # used for QA flow where we only want to send a single message
        self.single_message_history = single_message_history

        self.answer_style_config = answer_style_config
        self.prompt_config = prompt_config

        self.llm = llm
        self.fast_llm = fast_llm
        self.llm_tokenizer = get_tokenizer(
            provider_type=llm.config.model_provider,
            model_name=llm.config.model_name,
        )

        self._final_prompt: list[BaseMessage] | None = None

        self._streamed_output: list[str] | None = None
        self._processed_stream: list[AnswerPacket] | None = None

        self._return_contexts = return_contexts
        self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
        self._is_cancelled = False

        self.using_tool_calling_llm = (
            explicit_tool_calling_supported(
                self.llm.config.model_provider, self.llm.config.model_name
            )
            and not skip_explicit_tool_calling
        )

        self.pro_search_config = pro_search_config
        self.db_session = db_session

    def _get_tools_list(self) -> list[Tool]:
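        """Return the tools available for this answer.

        If a specific tool is being force-used, narrow the list to just that
        tool; otherwise return all configured tools.
        """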
        if not self.force_use_tool.force_use:
            return self.tools

        tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name)

        args_str = (
            f" with args='{self.force_use_tool.args}'"
            if self.force_use_tool.args
            else ""
        )
        logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
        return [tool]

    # TODO: delete the function and move the full body to processed_streamed_output
    def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
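        """Yield answer packets for the latest `LLMCall`.

        Runs a pre-selected tool call up front when one is forced or chosen for
        a non-tool-calling LLM, then dispatches to either the agentic-search
        graph or the basic graph and streams the resulting packets.
        """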
        current_llm_call = llm_calls[-1]

        tool, tool_args = None, None
        # handle the case where no decision has to be made; we simply run the tool
        if (
            current_llm_call.force_use_tool.force_use
            and current_llm_call.force_use_tool.args is not None
        ):
            tool_name, tool_args = (
                current_llm_call.force_use_tool.tool_name,
                current_llm_call.force_use_tool.args,
            )
            tool = get_tool_by_name(current_llm_call.tools, tool_name)

        # special pre-logic for non-tool calling LLM case
        elif not self.using_tool_calling_llm and current_llm_call.tools:
            chosen_tool_and_args = (
                ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
                    current_llm_call, self.llm
                )
            )
            if chosen_tool_and_args:
                tool, tool_args = chosen_tool_and_args

        if tool and tool_args:
            dummy_tool_call_chunk = AIMessageChunk(content="")
            dummy_tool_call_chunk.tool_calls = [
                ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
            ]

            response_handler_manager = LLMResponseHandlerManager(
                ToolResponseHandler([tool]), None, self.is_cancelled
            )
            yield from response_handler_manager.handle_llm_response(
                iter([dummy_tool_call_chunk])
            )

            tmp_call = response_handler_manager.next_llm_call(current_llm_call)
            if tmp_call is None:
                return  # no more LLM calls to process
            current_llm_call = tmp_call

        # if we're skipping gen ai answer generation, we should break
        # out unless we're forcing a tool call. If we don't, we might generate an
        # answer, which is a no-no!
        if (
            self.skip_gen_ai_answer_generation
            and not current_llm_call.force_use_tool.force_use
        ):
            return

        # set up "handlers" to listen to the LLM response stream and
        # feed back the processed results + handle tool call requests
        # + figure out what the next LLM call should be
        tool_call_handler = ToolResponseHandler(current_llm_call.tools)

        final_search_results, displayed_search_results = SearchTool.get_search_result(
            current_llm_call
        ) or ([], [])

        # NEXT: we still want to handle the LLM response stream, but now it needs to:
        # 1. handle the tool call requests
        # 2. feed back the processed results
        # 3. handle the citations

        answer_handler = CitationResponseHandler(
            context_docs=final_search_results,
            final_doc_id_to_rank_map=map_document_id_order(final_search_results),
            display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
        )

        # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
        # In the future, we'll want to handle citations and tool calls in the langgraph graph.
        response_handler_manager = LLMResponseHandlerManager(
            tool_call_handler, answer_handler, self.is_cancelled
        )

        # In langgraph, whether we do the basic thing (call llm stream) or pro search
        # is based on a flag in the pro search config

        if self.pro_search_config.use_agentic_search:
            if self.pro_search_config.search_request is None:
                raise ValueError("Search request must be provided for pro search")

            if self.db_session is None:
                raise ValueError("db_session must be provided for pro search")
            if self.fast_llm is None:
                raise ValueError("fast_llm must be provided for pro search")

            stream = run_main_graph(
                config=self.pro_search_config,
            )
        else:
            stream = run_basic_graph(
                config=self.pro_search_config,
                last_llm_call=current_llm_call,
                response_handler_manager=response_handler_manager,
            )

        processed_stream = []
        for packet in stream:
            processed_stream.append(packet)
            yield packet
        self._processed_stream = processed_stream
        return
        # DEBUG: good breakpoint
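        # NOTE: the code below is unreachable because of the `return` above; it
        # appears to be the pre-langgraph streaming path, left in place during
        # the langgraph migration.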
        stream = self.llm.stream(
            # For tool calling LLMs, we want to insert the task prompt as part of this flow;
            # this is because the LLM may choose not to call any tools and just generate
            # the answer, in which case the task prompt is still needed.
            prompt=current_llm_call.prompt_builder.build(),
            tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
            tool_choice=(
                "required"
                if current_llm_call.tools and current_llm_call.force_use_tool.force_use
                else None
            ),
            structured_response_format=self.answer_style_config.structured_response_format,
        )
        yield from response_handler_manager.handle_llm_response(stream)

        new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
        if new_llm_call:
            yield from self._get_response(llm_calls + [new_llm_call])

    @property
    def processed_streamed_output(self) -> AnswerStream:
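        """Yield the processed answer stream.

        On first access, builds the prompt and initial `LLMCall` and streams
        packets from `_get_response`; afterwards, replays the cached stream.
        """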
        if self._processed_stream is not None:
            yield from self._processed_stream
            return

        prompt_builder = AnswerPromptBuilder(
            user_message=default_build_user_message(
                user_query=self.question,
                prompt_config=self.prompt_config,
                files=self.latest_query_files,
                single_message_history=self.single_message_history,
            ),
            message_history=self.message_history,
            llm_config=self.llm.config,
            raw_user_query=self.question,
            raw_user_uploaded_files=self.latest_query_files or [],
            single_message_history=self.single_message_history,
        )
        prompt_builder.update_system_prompt(
            default_build_system_message(self.prompt_config)
        )
        llm_call = LLMCall(
            prompt_builder=prompt_builder,
            tools=self._get_tools_list(),
            force_use_tool=self.force_use_tool,
            files=self.latest_query_files,
            tool_call_info=[],
            using_tool_calling_llm=self.using_tool_calling_llm,
        )

        processed_stream = []
        for processed_packet in self._get_response([llm_call]):
            processed_stream.append(processed_packet)
            yield processed_packet

        self._processed_stream = processed_stream

    @property
    def llm_answer(self) -> str:
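        """Concatenate streamed answer pieces into a single string.

        Covers the basic answer flow plus level-0 agent answers, which are
        treated like basic-flow answers.
        """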
        answer = ""
        for packet in self.processed_streamed_output:
            # handle basic answer flow, plus level 0 agent answer flow
            # since level 0 is the first answer the user sees and therefore the
            # child message of the user message in the db (so it is handled
            # like a basic flow answer)
            if (isinstance(packet, OnyxAnswerPiece) and packet.answer_piece) or (
                isinstance(packet, AgentAnswerPiece)
                and packet.answer_piece
                and packet.answer_type == "agent_level_answer"
                and packet.level == 0
            ):
                answer += packet.answer_piece

        return answer

    def llm_answer_by_level(self) -> dict[int, str]:
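        """Group streamed answer pieces by agent level.

        Basic-flow pieces are stored under `BASIC_KEY[0]`.
        """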
        answer_by_level: dict[int, str] = defaultdict(str)
        for packet in self.processed_streamed_output:
            if (
                isinstance(packet, AgentAnswerPiece)
                and packet.answer_piece
                and packet.answer_type == "agent_level_answer"
            ):
                answer_by_level[packet.level] += packet.answer_piece
            elif isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
                answer_by_level[BASIC_KEY[0]] += packet.answer_piece
        return answer_by_level

    @property
    def citations(self) -> list[CitationInfo]:
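        """Return citations from the basic flow (packets with no agent level)."""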
        citations: list[CitationInfo] = []
        for packet in self.processed_streamed_output:
            if isinstance(packet, CitationInfo) and packet.level is None:
                citations.append(packet)

        return citations

    def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]:
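        """Group citations by `(level, level_question_nr)`.

        Basic-flow citations (no level) are stored under `BASIC_KEY`.
        """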
        citations_by_subquestion: dict[
            tuple[int, int], list[CitationInfo]
        ] = defaultdict(list)
        for packet in self.processed_streamed_output:
            if isinstance(packet, CitationInfo):
                if packet.level_question_nr is not None and packet.level is not None:
                    citations_by_subquestion[
                        (packet.level, packet.level_question_nr)
                    ].append(packet)
                elif packet.level is None:
                    citations_by_subquestion[BASIC_KEY].append(packet)
        return citations_by_subquestion

    def is_cancelled(self) -> bool:
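        """Return True once the answer stream has been cancelled.

        Uses the `is_connected` callback (when provided) to detect client
        disconnects and caches the result.
        """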
        if self._is_cancelled:
            return True

        if self.is_connected is not None:
            if not self.is_connected():
                logger.debug("Answer stream has been cancelled")
            self._is_cancelled = not self.is_connected()

        return self._is_cancelled