from collections import defaultdict
from collections.abc import Callable
from uuid import UUID

from sqlalchemy.orm import Session

from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import gpu_status_request
from onyx.utils.logger import setup_logger

logger = setup_logger()

BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])


class Answer:
    def __init__(
        self,
        prompt_builder: AnswerPromptBuilder,
        answer_style_config: AnswerStyleConfig,
        llm: LLM,
        fast_llm: LLM,
        force_use_tool: ForceUseTool,
        search_request: SearchRequest,
        chat_session_id: UUID,
        current_agent_message_id: int,
        db_session: Session,
        # newly passed-in files to include as part of this question
        # TODO THIS NEEDS TO BE HANDLED
        latest_query_files: list[InMemoryChatFile] | None = None,
        tools: list[Tool] | None = None,
        # NOTE: native tool-calling is currently only supported by OpenAI,
        # but those are the only models we support anyway. If set to True,
        # never use the LLM's provided tool-calling functionality.
        skip_explicit_tool_calling: bool = False,
        skip_gen_ai_answer_generation: bool = False,
        is_connected: Callable[[], bool] | None = None,
        use_agentic_search: bool = False,
    ) -> None:
        self.is_connected: Callable[[], bool] | None = is_connected
        self._processed_stream: list[AnswerPacket] | None = None
        self._is_cancelled = False

        search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
        search_tool: SearchTool | None = None

        if len(search_tools) > 1:
            # TODO: handle multiple search tools
            raise ValueError("Multiple search tools found")
        elif len(search_tools) == 1:
            search_tool = search_tools[0]

        using_tool_calling_llm = (
            explicit_tool_calling_supported(
                llm.config.model_provider, llm.config.model_name
            )
            and not skip_explicit_tool_calling
        )

        rerank_settings = search_request.rerank_settings

        using_cloud_reranking = (
            rerank_settings is not None
            and rerank_settings.rerank_provider_type is not None
        )
        # agent reranking requires either a GPU (per gpu_status_request) or a
        # cloud rerank provider
        allow_agent_reranking = gpu_status_request() or using_cloud_reranking

        # TODO: this is a hack to force the query to be used for the search tool;
        # it should be removed once we fully unify graph inputs (i.e.
        # remove SearchQuery entirely)
        if (
            force_use_tool.force_use
            and search_tool
            and force_use_tool.args
            and force_use_tool.tool_name == search_tool.name
            and QUERY_FIELD in force_use_tool.args
        ):
            search_request.query = force_use_tool.args[QUERY_FIELD]

        self.graph_inputs = GraphInputs(
            search_request=search_request,
            prompt_builder=prompt_builder,
            files=latest_query_files,
            structured_response_format=answer_style_config.structured_response_format,
        )
        self.graph_tooling = GraphTooling(
            primary_llm=llm,
            fast_llm=fast_llm,
            search_tool=search_tool,
            tools=tools or [],
            force_use_tool=force_use_tool,
            using_tool_calling_llm=using_tool_calling_llm,
        )
        self.graph_persistence = GraphPersistence(
            db_session=db_session,
            chat_session_id=chat_session_id,
            message_id=current_agent_message_id,
        )
        self.search_behavior_config = GraphSearchConfig(
            use_agentic_search=use_agentic_search,
            skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
            allow_refinement=True,
            allow_agent_reranking=allow_agent_reranking,
        )
        self.graph_config = GraphConfig(
            inputs=self.graph_inputs,
            tooling=self.graph_tooling,
            persistence=self.graph_persistence,
            behavior=self.search_behavior_config,
        )

    @property
    def processed_streamed_output(self) -> AnswerStream:
        if self._processed_stream is not None:
            yield from self._processed_stream
            return

        run_langgraph = (
            run_main_graph
            if self.graph_config.behavior.use_agentic_search
            else run_basic_graph
        )
        stream = run_langgraph(self.graph_config)

        processed_stream = []
        for packet in stream:
            if self.is_cancelled():
                packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
                yield packet
                break
            processed_stream.append(packet)
            yield packet

        self._processed_stream = processed_stream

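    # Note on the property above: packets are cached in self._processed_stream
    # once the stream is fully consumed, so a second iteration replays the
    # cached packets instead of re-running the graph. Illustrative sketch only
    # (`answer` and `render` are assumed placeholders, not part of this module):
    #
    #     for packet in answer.processed_streamed_output:
    #         render(packet)  # first pass: graph runs, packets stream live
    #     for packet in answer.processed_streamed_output:
    #         render(packet)  # second pass: replayed from the cache
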
    @property
    def llm_answer(self) -> str:
        answer = ""
        for packet in self.processed_streamed_output:
            # handle the basic answer flow, plus the level 0 agent answer flow:
            # level 0 is the first answer the user sees and is therefore the
            # child message of the user message in the db, so it is handled
            # like a basic-flow answer
            if (isinstance(packet, OnyxAnswerPiece) and packet.answer_piece) or (
                isinstance(packet, AgentAnswerPiece)
                and packet.answer_piece
                and packet.answer_type == "agent_level_answer"
                and packet.level == 0
            ):
                answer += packet.answer_piece

        return answer

    def llm_answer_by_level(self) -> dict[int, str]:
        answer_by_level: dict[int, str] = defaultdict(str)
        for packet in self.processed_streamed_output:
            if (
                isinstance(packet, AgentAnswerPiece)
                and packet.answer_piece
                and packet.answer_type == "agent_level_answer"
            ):
                assert packet.level is not None
                answer_by_level[packet.level] += packet.answer_piece
            elif isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
                answer_by_level[BASIC_KEY[0]] += packet.answer_piece
        return answer_by_level

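    # Illustrative shape of the mapping above (levels and text are assumptions,
    # not real output): agent answer pieces accumulate under their level, while
    # basic-flow pieces accumulate under BASIC_KEY[0], e.g.
    #     {0: "initial agent answer...", 1: "refined agent answer..."}
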
    @property
    def citations(self) -> list[CitationInfo]:
        citations: list[CitationInfo] = []
        for packet in self.processed_streamed_output:
            if isinstance(packet, CitationInfo) and packet.level is None:
                citations.append(packet)

        return citations

    def citations_by_subquestion(self) -> dict[SubQuestionKey, list[CitationInfo]]:
        citations_by_subquestion: dict[
            SubQuestionKey, list[CitationInfo]
        ] = defaultdict(list)
        for packet in self.processed_streamed_output:
            if isinstance(packet, CitationInfo):
                if packet.level_question_num is not None and packet.level is not None:
                    citations_by_subquestion[
                        SubQuestionKey(
                            level=packet.level, question_num=packet.level_question_num
                        )
                    ].append(packet)
                elif packet.level is None:
                    # citations without a level belong to the basic answer flow
                    citations_by_subquestion[BASIC_SQ_KEY].append(packet)
        return citations_by_subquestion

    def is_cancelled(self) -> bool:
        if self._is_cancelled:
            return True

        if self.is_connected is not None:
            if not self.is_connected():
                logger.debug("Answer stream has been cancelled")
                # cache the result so later checks don't re-poll the connection
                self._is_cancelled = True

        return self._is_cancelled
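

# Illustrative end-to-end usage (a hedged sketch -- every argument value below
# is an assumption; in the real application these come from the chat backend):
#
#     answer = Answer(
#         prompt_builder=prompt_builder,
#         answer_style_config=answer_style_config,
#         llm=llm,
#         fast_llm=fast_llm,
#         force_use_tool=force_use_tool,
#         search_request=search_request,
#         chat_session_id=chat_session_id,
#         current_agent_message_id=message_id,
#         db_session=db_session,
#         tools=[search_tool],
#         use_agentic_search=True,
#     )
#     for packet in answer.processed_streamed_output:
#         ...  # forward packets to the client as they arrive
#     full_text = answer.llm_answer
#     citations = answer.citations_by_subquestion()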