mirror of https://github.com/danswer-ai/danswer.git
AgentPromptConfig in Answer class
@@ -1,7 +1,7 @@
 from collections import defaultdict
 from collections.abc import Callable
+from uuid import UUID

 from langchain.schema.messages import BaseMessage
 from sqlalchemy.orm import Session

 from onyx.agents.agent_search.models import AgentSearchConfig
@@ -13,17 +13,18 @@ from onyx.chat.models import AnswerStream
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import OnyxAnswerPiece
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.configs.constants import BASIC_KEY
+from onyx.context.search.models import SearchRequest
 from onyx.file_store.utils import InMemoryChatFile
 from onyx.llm.interfaces import LLM
 from onyx.llm.models import PreviousMessage
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.tools.force import ForceUseTool
 from onyx.tools.tool import Tool
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.utils import explicit_tool_calling_supported
 from onyx.utils.logger import setup_logger

@@ -33,15 +34,14 @@ logger = setup_logger()
 class Answer:
     def __init__(
         self,
         question: str,
+        prompt_builder: AnswerPromptBuilder,
         answer_style_config: AnswerStyleConfig,
         llm: LLM,
         prompt_config: PromptConfig,
+        fast_llm: LLM,
         force_use_tool: ForceUseTool,
-        agent_search_config: AgentSearchConfig,
-        # must be the same length as `docs`. If None, all docs are considered "relevant"
-        message_history: list[PreviousMessage] | None = None,
-        single_message_history: str | None = None,
+        search_request: SearchRequest,
+        chat_session_id: UUID,
+        current_agent_message_id: int,
         # newly passed in files to include as part of this question
         # TODO THIS NEEDS TO BE HANDLED
         latest_query_files: list[InMemoryChatFile] | None = None,
@@ -52,28 +52,18 @@ class Answer:
         skip_explicit_tool_calling: bool = False,
         skip_gen_ai_answer_generation: bool = False,
         is_connected: Callable[[], bool] | None = None,
-        fast_llm: LLM | None = None,
         db_session: Session | None = None,
+        use_agentic_search: bool = False,
     ) -> None:
-        if single_message_history and message_history:
-            raise ValueError(
-                "Cannot provide both `message_history` and `single_message_history`"
-            )
-
         self.question = question
         self.is_connected: Callable[[], bool] | None = is_connected

         self.latest_query_files = latest_query_files or []

         self.tools = tools or []
         self.force_use_tool = force_use_tool
-
-        self.message_history = message_history or []
-        # used for QA flow where we only want to send a single message
-        self.single_message_history = single_message_history

         self.answer_style_config = answer_style_config
         self.prompt_config = prompt_config

         self.llm = llm
         self.fast_llm = fast_llm
@@ -82,8 +72,6 @@ class Answer:
             model_name=llm.config.model_name,
         )

-        self._final_prompt: list[BaseMessage] | None = None
-
-        self._streamed_output: list[str] | None = None
         self._processed_stream: (list[AnswerPacket] | None) = None
@@ -97,6 +85,42 @@ class Answer:
             and not skip_explicit_tool_calling
         )

+        search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
+        search_tool: SearchTool | None = None
+
+        if len(search_tools) > 1:
+            # TODO: handle multiple search tools
+            logger.warning("Multiple search tools found, using first one")
+            search_tool = search_tools[0]
+        elif len(search_tools) == 1:
+            search_tool = search_tools[0]
+        else:
+            logger.warning("No search tool found")
+            if use_agentic_search:
+                raise ValueError("No search tool found, cannot use agentic search")
+
+        using_tool_calling_llm = explicit_tool_calling_supported(
+            llm.config.model_provider, llm.config.model_name
+        )
+        agent_search_config = AgentSearchConfig(
+            search_request=search_request,
+            primary_llm=llm,
+            fast_llm=fast_llm,
+            search_tool=search_tool,
+            force_use_tool=force_use_tool,
+            use_agentic_search=use_agentic_search,
+            chat_session_id=chat_session_id,
+            message_id=current_agent_message_id,
+            use_persistence=True,
+            allow_refinement=True,
+            db_session=db_session,
+            prompt_builder=prompt_builder,
+            tools=tools,
+            using_tool_calling_llm=using_tool_calling_llm,
+            files=latest_query_files,
+            structured_response_format=answer_style_config.structured_response_format,
+            skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
+        )
+        self.agent_search_config = agent_search_config
+        self.db_session = db_session
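
Taken together, these hunks (evidently onyx/chat/answer.py) move the construction of AgentSearchConfig from the call site into Answer.__init__: instead of accepting a ready-made agent_search_config, the constructor now takes the underlying pieces (prompt_builder, a required fast_llm, search_request, chat_session_id, current_agent_message_id, use_agentic_search) and assembles the config itself. A minimal, self-contained sketch of that pattern, using toy stand-ins rather than the real onyx types (ToyAnswer, ToySearchConfig, and their fields are illustrative only):

from dataclasses import dataclass


@dataclass
class ToySearchConfig:
    # stand-in for AgentSearchConfig; fields are illustrative
    question: str
    session_id: int
    use_persistence: bool = True
    allow_refinement: bool = True


class ToyAnswer:
    def __init__(self, question: str, session_id: int) -> None:
        # Before: every caller built the config and passed it in.
        # After: the constructor owns construction, so defaults such as
        # use_persistence and allow_refinement are decided in one place.
        self.config = ToySearchConfig(question=question, session_id=session_id)


answer = ToyAnswer("why is the sky blue?", session_id=42)
assert answer.config.use_persistence and answer.config.allow_refinement

The call-site payoff shows up in the third file of this commit, where the mirrored construction block is deleted.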
@@ -284,7 +284,7 @@ class AnswerStyleConfig(BaseModel):

 class PromptConfig(BaseModel):
     """Final representation of the Prompt configuration passed
-    into the `Answer` object."""
+    into the `PromptBuilder` object."""

     system_prompt: str
     task_prompt: str
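
The second file in the diff (the module defining these models, by all appearances onyx/chat/models.py) gets a one-line docstring fix: PromptConfig is consumed by the prompt builder, no longer handed into Answer directly. A sketch of the model as it appears in this hunk — the real class may carry more fields than the two shown:

from pydantic import BaseModel


class PromptConfig(BaseModel):
    """Final representation of the Prompt configuration passed
    into the `PromptBuilder` object."""

    system_prompt: str
    task_prompt: str


# hypothetical values, purely to show construction
config = PromptConfig(
    system_prompt="You are a helpful assistant.",
    task_prompt="Answer using only the retrieved context.",
)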
@@ -8,7 +8,6 @@ from typing import cast

 from sqlalchemy.orm import Session

-from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.chat.answer import Answer
 from onyx.chat.chat_utils import create_chat_chain
 from onyx.chat.chat_utils import create_temporary_persona
@@ -134,7 +133,6 @@ from onyx.tools.tool_implementations.search.search_tool import (
     SECTION_RELEVANCE_LIST_ID,
 )
 from onyx.tools.tool_runner import ToolCallFinalResult
-from onyx.tools.utils import explicit_tool_calling_supported
 from onyx.utils.logger import setup_logger
 from onyx.utils.long_term_log import LongTermLogger
 from onyx.utils.telemetry import mt_cloud_telemetry
@@ -712,6 +710,7 @@ def stream_chat_message_objects(
     for tool_list in tool_dict.values():
         tools.extend(tool_list)

+    # TODO: unify message history with single message history
     message_history = [
         PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
     ]
@@ -739,26 +738,7 @@ def stream_chat_message_objects(
             else None
         ),
     )
-    # TODO: Since we're deleting the current main path in Answer,
-    # we should construct this unconditionally inside Answer instead
-    # Leaving it here for the time being to avoid breaking changes
-    search_tools = [tool for tool in tools if isinstance(tool, SearchTool)]
-    search_tool: SearchTool | None = None
-
-    if len(search_tools) > 1:
-        # TODO: handle multiple search tools
-        logger.warning("Multiple search tools found, using first one")
-        search_tool = search_tools[0]
-    elif len(search_tools) == 1:
-        search_tool = search_tools[0]
-    else:
-        logger.warning("No search tool found")
-        if new_msg_req.use_agentic_search:
-            raise ValueError("No search tool found, cannot use agentic search")
-
-    using_tool_calling_llm = explicit_tool_calling_supported(
-        llm.config.model_provider, llm.config.model_name
-    )
     force_use_tool = _get_force_search_settings(new_msg_req, tools)
     prompt_builder = AnswerPromptBuilder(
         user_message=default_build_user_message(
@@ -776,35 +756,12 @@ def stream_chat_message_objects(
     )
     prompt_builder.update_system_prompt(default_build_system_message(prompt_config))

-    agent_search_config = AgentSearchConfig(
-        search_request=search_request,
-        primary_llm=llm,
-        fast_llm=fast_llm,
-        search_tool=search_tool,
-        force_use_tool=force_use_tool,
-        use_agentic_search=new_msg_req.use_agentic_search,
-        chat_session_id=chat_session_id,
-        message_id=reserved_message_id,
-        use_persistence=True,
-        allow_refinement=True,
-        db_session=db_session,
-        prompt_builder=prompt_builder,
-        tools=tools,
-        using_tool_calling_llm=using_tool_calling_llm,
-        files=latest_query_files,
-        structured_response_format=new_msg_req.structured_response_format,
-        skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
-    )
-
-    # TODO: add previous messages, answer style config, tools, etc.
-
     # LLM prompt building, response capturing, etc.
     answer = Answer(
+        prompt_builder=prompt_builder,
         is_connected=is_connected,
         question=final_msg.message,
         latest_query_files=latest_query_files,
         answer_style_config=answer_style_config,
         prompt_config=prompt_config,
         llm=(
             llm
             or get_main_llm_from_tuple(
@@ -818,12 +775,13 @@ def stream_chat_message_objects(
             )
         ),
         fast_llm=fast_llm,
-        message_history=message_history,
-        tools=tools,
         force_use_tool=force_use_tool,
-        single_message_history=single_message_history,
-        agent_search_config=agent_search_config,
+        search_request=search_request,
+        chat_session_id=chat_session_id,
+        current_agent_message_id=reserved_message_id,
+        tools=tools,
         db_session=db_session,
+        use_agentic_search=new_msg_req.use_agentic_search,
     )

     # reference_db_search_docs = None
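
The process_message.py changes are the mirror image of the answer.py ones: the search-tool selection block and the AgentSearchConfig construction are deleted here because Answer.__init__ now performs both (the removed TODO comment had announced exactly this move), and the call site instead passes the raw inputs (search_request, chat_session_id, reserved_message_id, use_agentic_search). The selection logic itself is unchanged and now lives in one place. As a standalone illustration, here is that same logic extracted into a hypothetical helper — pick_search_tool is not an onyx function, and SearchTool below is a toy stand-in:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


class SearchTool:
    # toy stand-in for onyx's SearchTool
    pass


def pick_search_tool(tools: list[object], use_agentic_search: bool) -> SearchTool | None:
    # Mirrors the selection rules this commit moves into Answer.__init__.
    search_tools = [t for t in tools if isinstance(t, SearchTool)]
    if len(search_tools) > 1:
        # the source leaves a TODO: handle multiple search tools
        logger.warning("Multiple search tools found, using first one")
        return search_tools[0]
    if len(search_tools) == 1:
        return search_tools[0]
    logger.warning("No search tool found")
    if use_agentic_search:
        raise ValueError("No search tool found, cannot use agentic search")
    return None


assert isinstance(pick_search_tool([SearchTool()], use_agentic_search=True), SearchTool)
assert pick_search_tool([], use_agentic_search=False) is None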