AgentPromptConfig in Answer class

This commit is contained in:
Evan Lohn
2025-01-29 10:54:32 -08:00
parent efa32a8c04
commit 1a22af4f27
5 changed files with 97 additions and 96 deletions

View File

@@ -1,7 +1,7 @@
from collections import defaultdict
from collections.abc import Callable
from uuid import UUID
from langchain.schema.messages import BaseMessage
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
@@ -13,17 +13,18 @@ from onyx.chat.models import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
@@ -33,15 +34,14 @@ logger = setup_logger()
class Answer:
def __init__(
self,
question: str,
prompt_builder: AnswerPromptBuilder,
answer_style_config: AnswerStyleConfig,
llm: LLM,
prompt_config: PromptConfig,
fast_llm: LLM,
force_use_tool: ForceUseTool,
agent_search_config: AgentSearchConfig,
# must be the same length as `docs`. If None, all docs are considered "relevant"
message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None,
search_request: SearchRequest,
chat_session_id: UUID,
current_agent_message_id: int,
# newly passed in files to include as part of this question
# TODO THIS NEEDS TO BE HANDLED
latest_query_files: list[InMemoryChatFile] | None = None,
@@ -52,28 +52,18 @@ class Answer:
skip_explicit_tool_calling: bool = False,
skip_gen_ai_answer_generation: bool = False,
is_connected: Callable[[], bool] | None = None,
fast_llm: LLM | None = None,
db_session: Session | None = None,
use_agentic_search: bool = False,
) -> None:
if single_message_history and message_history:
raise ValueError(
"Cannot provide both `message_history` and `single_message_history`"
)
self.question = question
self.is_connected: Callable[[], bool] | None = is_connected
self.latest_query_files = latest_query_files or []
self.tools = tools or []
self.force_use_tool = force_use_tool
self.message_history = message_history or []
# used for QA flow where we only want to send a single message
self.single_message_history = single_message_history
self.answer_style_config = answer_style_config
self.prompt_config = prompt_config
self.llm = llm
self.fast_llm = fast_llm
@@ -82,8 +72,6 @@ class Answer:
model_name=llm.config.model_name,
)
self._final_prompt: list[BaseMessage] | None = None
self._streamed_output: list[str] | None = None
self._processed_stream: (list[AnswerPacket] | None) = None
@@ -97,6 +85,42 @@ class Answer:
and not skip_explicit_tool_calling
)
search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
search_tool: SearchTool | None = None
if len(search_tools) > 1:
# TODO: handle multiple search tools
logger.warning("Multiple search tools found, using first one")
search_tool = search_tools[0]
elif len(search_tools) == 1:
search_tool = search_tools[0]
else:
logger.warning("No search tool found")
if use_agentic_search:
raise ValueError("No search tool found, cannot use agentic search")
using_tool_calling_llm = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
agent_search_config = AgentSearchConfig(
search_request=search_request,
primary_llm=llm,
fast_llm=fast_llm,
search_tool=search_tool,
force_use_tool=force_use_tool,
use_agentic_search=use_agentic_search,
chat_session_id=chat_session_id,
message_id=current_agent_message_id,
use_persistence=True,
allow_refinement=True,
db_session=db_session,
prompt_builder=prompt_builder,
tools=tools,
using_tool_calling_llm=using_tool_calling_llm,
files=latest_query_files,
structured_response_format=answer_style_config.structured_response_format,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
)
self.agent_search_config = agent_search_config
self.db_session = db_session

View File

@@ -284,7 +284,7 @@ class AnswerStyleConfig(BaseModel):
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed
into the `Answer` object."""
into the `PromptBuilder` object."""
system_prompt: str
task_prompt: str

View File

@@ -8,7 +8,6 @@ from typing import cast
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
@@ -134,7 +133,6 @@ from onyx.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -712,6 +710,7 @@ def stream_chat_message_objects(
for tool_list in tool_dict.values():
tools.extend(tool_list)
# TODO: unify message history with single message history
message_history = [
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]
@@ -739,26 +738,7 @@ def stream_chat_message_objects(
else None
),
)
# TODO: Since we're deleting the current main path in Answer,
# we should construct this unconditionally inside Answer instead
# Leaving it here for the time being to avoid breaking changes
search_tools = [tool for tool in tools if isinstance(tool, SearchTool)]
search_tool: SearchTool | None = None
if len(search_tools) > 1:
# TODO: handle multiple search tools
logger.warning("Multiple search tools found, using first one")
search_tool = search_tools[0]
elif len(search_tools) == 1:
search_tool = search_tools[0]
else:
logger.warning("No search tool found")
if new_msg_req.use_agentic_search:
raise ValueError("No search tool found, cannot use agentic search")
using_tool_calling_llm = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
force_use_tool = _get_force_search_settings(new_msg_req, tools)
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
@@ -776,35 +756,12 @@ def stream_chat_message_objects(
)
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
agent_search_config = AgentSearchConfig(
search_request=search_request,
primary_llm=llm,
fast_llm=fast_llm,
search_tool=search_tool,
force_use_tool=force_use_tool,
use_agentic_search=new_msg_req.use_agentic_search,
chat_session_id=chat_session_id,
message_id=reserved_message_id,
use_persistence=True,
allow_refinement=True,
db_session=db_session,
prompt_builder=prompt_builder,
tools=tools,
using_tool_calling_llm=using_tool_calling_llm,
files=latest_query_files,
structured_response_format=new_msg_req.structured_response_format,
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
)
# TODO: add previous messages, answer style config, tools, etc.
# LLM prompt building, response capturing, etc.
answer = Answer(
prompt_builder=prompt_builder,
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
llm=(
llm
or get_main_llm_from_tuple(
@@ -818,12 +775,13 @@ def stream_chat_message_objects(
)
),
fast_llm=fast_llm,
message_history=message_history,
tools=tools,
force_use_tool=force_use_tool,
single_message_history=single_message_history,
agent_search_config=agent_search_config,
search_request=search_request,
chat_session_id=chat_session_id,
current_agent_message_id=reserved_message_id,
tools=tools,
db_session=db_session,
use_agentic_search=new_msg_req.use_agentic_search,
)
# reference_db_search_docs = None