AgentSearchConfig in Answer class

Evan Lohn
2025-01-29 10:54:32 -08:00
parent efa32a8c04
commit 1a22af4f27
5 changed files with 97 additions and 96 deletions

View File

@@ -1,7 +1,7 @@
 from collections import defaultdict
 from collections.abc import Callable
+from uuid import UUID

-from langchain.schema.messages import BaseMessage
 from sqlalchemy.orm import Session

 from onyx.agents.agent_search.models import AgentSearchConfig
@@ -13,17 +13,18 @@ from onyx.chat.models import AnswerStream
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationInfo
 from onyx.chat.models import OnyxAnswerPiece
-from onyx.chat.models import PromptConfig
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
 from onyx.configs.constants import BASIC_KEY
+from onyx.context.search.models import SearchRequest
 from onyx.file_store.utils import InMemoryChatFile
 from onyx.llm.interfaces import LLM
-from onyx.llm.models import PreviousMessage
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.tools.force import ForceUseTool
 from onyx.tools.tool import Tool
+from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.utils import explicit_tool_calling_supported
 from onyx.utils.logger import setup_logger
@@ -33,15 +34,14 @@ logger = setup_logger()
 class Answer:
     def __init__(
         self,
-        question: str,
+        prompt_builder: AnswerPromptBuilder,
         answer_style_config: AnswerStyleConfig,
         llm: LLM,
-        prompt_config: PromptConfig,
+        fast_llm: LLM,
         force_use_tool: ForceUseTool,
-        agent_search_config: AgentSearchConfig,
-        # must be the same length as `docs`. If None, all docs are considered "relevant"
-        message_history: list[PreviousMessage] | None = None,
-        single_message_history: str | None = None,
+        search_request: SearchRequest,
+        chat_session_id: UUID,
+        current_agent_message_id: int,
         # newly passed in files to include as part of this question
         # TODO THIS NEEDS TO BE HANDLED
         latest_query_files: list[InMemoryChatFile] | None = None,
@@ -52,28 +52,18 @@ class Answer:
         skip_explicit_tool_calling: bool = False,
         skip_gen_ai_answer_generation: bool = False,
         is_connected: Callable[[], bool] | None = None,
-        fast_llm: LLM | None = None,
         db_session: Session | None = None,
+        use_agentic_search: bool = False,
     ) -> None:
-        if single_message_history and message_history:
-            raise ValueError(
-                "Cannot provide both `message_history` and `single_message_history`"
-            )
-        self.question = question
         self.is_connected: Callable[[], bool] | None = is_connected

         self.latest_query_files = latest_query_files or []

         self.tools = tools or []
         self.force_use_tool = force_use_tool
-        self.message_history = message_history or []
         # used for QA flow where we only want to send a single message
-        self.single_message_history = single_message_history

         self.answer_style_config = answer_style_config
-        self.prompt_config = prompt_config

         self.llm = llm
         self.fast_llm = fast_llm
@@ -82,8 +72,6 @@ class Answer:
             model_name=llm.config.model_name,
         )

-        self._final_prompt: list[BaseMessage] | None = None

         self._streamed_output: list[str] | None = None
         self._processed_stream: (list[AnswerPacket] | None) = None
@@ -97,6 +85,42 @@ class Answer:
             and not skip_explicit_tool_calling
         )

+        search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
+        search_tool: SearchTool | None = None
+
+        if len(search_tools) > 1:
+            # TODO: handle multiple search tools
+            logger.warning("Multiple search tools found, using first one")
+            search_tool = search_tools[0]
+        elif len(search_tools) == 1:
+            search_tool = search_tools[0]
+        else:
+            logger.warning("No search tool found")
+            if use_agentic_search:
+                raise ValueError("No search tool found, cannot use agentic search")
+
+        using_tool_calling_llm = explicit_tool_calling_supported(
+            llm.config.model_provider, llm.config.model_name
+        )
+
+        agent_search_config = AgentSearchConfig(
+            search_request=search_request,
+            primary_llm=llm,
+            fast_llm=fast_llm,
+            search_tool=search_tool,
+            force_use_tool=force_use_tool,
+            use_agentic_search=use_agentic_search,
+            chat_session_id=chat_session_id,
+            message_id=current_agent_message_id,
+            use_persistence=True,
+            allow_refinement=True,
+            db_session=db_session,
+            prompt_builder=prompt_builder,
+            tools=tools,
+            using_tool_calling_llm=using_tool_calling_llm,
+            files=latest_query_files,
+            structured_response_format=answer_style_config.structured_response_format,
+            skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
+        )
         self.agent_search_config = agent_search_config
         self.db_session = db_session
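
Note: the net effect of the hunks above is that `Answer` now assembles its own `AgentSearchConfig` from raw inputs instead of accepting a prebuilt one. A minimal sketch of the new call shape, assuming `prompt_builder`, `answer_style_config`, `llm`, `fast_llm`, `tools`, and `db_session` are already set up by the caller (the concrete values here are illustrative placeholders, not taken from this diff):

    from uuid import uuid4

    from onyx.chat.answer import Answer
    from onyx.context.search.models import SearchRequest
    from onyx.tools.force import ForceUseTool

    answer = Answer(
        prompt_builder=prompt_builder,
        answer_style_config=answer_style_config,
        llm=llm,
        fast_llm=fast_llm,  # now required; no longer defaults to None
        force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
        search_request=SearchRequest(query="example question"),
        chat_session_id=uuid4(),
        current_agent_message_id=0,
        tools=tools,
        db_session=db_session,
        # True raises ValueError when `tools` contains no SearchTool
        use_agentic_search=False,
    )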

View File

@@ -284,7 +284,7 @@ class AnswerStyleConfig(BaseModel):
 class PromptConfig(BaseModel):
     """Final representation of the Prompt configuration passed
-    into the `Answer` object."""
+    into the `PromptBuilder` object."""

     system_prompt: str
     task_prompt: str
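
Note: only the docstring changes here, reflecting that `PromptConfig` now feeds the prompt builder rather than `Answer`. A sketch of that flow, reusing the builder helpers exercised in the test hunks below (keyword arguments are copied from those hunks; `prompt_config` is assumed to already exist):

    system_message = default_build_system_message(prompt_config)
    user_message = default_build_user_message(
        user_query="example question",
        prompt_config=prompt_config,
        files=[],
        single_message_history=None,
    )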

View File

@@ -8,7 +8,6 @@ from typing import cast
 from sqlalchemy.orm import Session

-from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.chat.answer import Answer
 from onyx.chat.chat_utils import create_chat_chain
 from onyx.chat.chat_utils import create_temporary_persona
@@ -134,7 +133,6 @@ from onyx.tools.tool_implementations.search.search_tool import (
     SECTION_RELEVANCE_LIST_ID,
 )
 from onyx.tools.tool_runner import ToolCallFinalResult
-from onyx.tools.utils import explicit_tool_calling_supported
 from onyx.utils.logger import setup_logger
 from onyx.utils.long_term_log import LongTermLogger
 from onyx.utils.telemetry import mt_cloud_telemetry
@@ -712,6 +710,7 @@ def stream_chat_message_objects(
         for tool_list in tool_dict.values():
             tools.extend(tool_list)

+        # TODO: unify message history with single message history
         message_history = [
             PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
         ]
@@ -739,26 +738,7 @@ def stream_chat_message_objects(
                 else None
             ),
         )

-        # TODO: Since we're deleting the current main path in Answer,
-        # we should construct this unconditionally inside Answer instead
-        # Leaving it here for the time being to avoid breaking changes
-        search_tools = [tool for tool in tools if isinstance(tool, SearchTool)]
-        search_tool: SearchTool | None = None
-        if len(search_tools) > 1:
-            # TODO: handle multiple search tools
-            logger.warning("Multiple search tools found, using first one")
-            search_tool = search_tools[0]
-        elif len(search_tools) == 1:
-            search_tool = search_tools[0]
-        else:
-            logger.warning("No search tool found")
-            if new_msg_req.use_agentic_search:
-                raise ValueError("No search tool found, cannot use agentic search")
-        using_tool_calling_llm = explicit_tool_calling_supported(
-            llm.config.model_provider, llm.config.model_name
-        )
-
         force_use_tool = _get_force_search_settings(new_msg_req, tools)
         prompt_builder = AnswerPromptBuilder(
             user_message=default_build_user_message(
@@ -776,35 +756,12 @@ def stream_chat_message_objects(
         )
         prompt_builder.update_system_prompt(default_build_system_message(prompt_config))

-        agent_search_config = AgentSearchConfig(
-            search_request=search_request,
-            primary_llm=llm,
-            fast_llm=fast_llm,
-            search_tool=search_tool,
-            force_use_tool=force_use_tool,
-            use_agentic_search=new_msg_req.use_agentic_search,
-            chat_session_id=chat_session_id,
-            message_id=reserved_message_id,
-            use_persistence=True,
-            allow_refinement=True,
-            db_session=db_session,
-            prompt_builder=prompt_builder,
-            tools=tools,
-            using_tool_calling_llm=using_tool_calling_llm,
-            files=latest_query_files,
-            structured_response_format=new_msg_req.structured_response_format,
-            skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
-        )
-
+        # TODO: add previous messages, answer style config, tools, etc.
         # LLM prompt building, response capturing, etc.
         answer = Answer(
+            prompt_builder=prompt_builder,
             is_connected=is_connected,
-            question=final_msg.message,
             latest_query_files=latest_query_files,
             answer_style_config=answer_style_config,
-            prompt_config=prompt_config,
             llm=(
                 llm
                 or get_main_llm_from_tuple(
@@ -818,12 +775,13 @@ def stream_chat_message_objects(
                 )
             ),
             fast_llm=fast_llm,
-            message_history=message_history,
-            tools=tools,
             force_use_tool=force_use_tool,
-            single_message_history=single_message_history,
-            agent_search_config=agent_search_config,
+            search_request=search_request,
+            chat_session_id=chat_session_id,
+            current_agent_message_id=reserved_message_id,
+            tools=tools,
             db_session=db_session,
+            use_agentic_search=new_msg_req.use_agentic_search,
         )

         # reference_db_search_docs = None
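
Note: the block deleted above (search-tool selection plus `AgentSearchConfig` construction) moves essentially verbatim into `Answer.__init__` in the first file. As a standalone illustration of the selection policy, a hypothetical helper; the commit inlines this logic and defines no such function:

    import logging

    from onyx.tools.tool import Tool
    from onyx.tools.tool_implementations.search.search_tool import SearchTool

    logger = logging.getLogger(__name__)

    def pick_search_tool(tools: list[Tool], use_agentic_search: bool) -> SearchTool | None:
        # Mirrors the inlined policy: the first SearchTool wins; having no
        # search tool is only an error when agentic search was requested.
        search_tools = [tool for tool in tools if isinstance(tool, SearchTool)]
        if not search_tools:
            logger.warning("No search tool found")
            if use_agentic_search:
                raise ValueError("No search tool found, cannot use agentic search")
            return None
        if len(search_tools) > 1:
            # TODO in the source: handle multiple search tools
            logger.warning("Multiple search tools found, using first one")
        return search_tools[0]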

View File

@@ -2,6 +2,7 @@ import json
 from typing import cast
 from unittest.mock import MagicMock
 from unittest.mock import Mock
+from uuid import UUID

 import pytest
 from langchain_core.messages import AIMessageChunk
@@ -11,7 +12,6 @@ from langchain_core.messages import SystemMessage
 from langchain_core.messages import ToolCall
 from langchain_core.messages import ToolCallChunk

-from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.chat.answer import Answer
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import CitationInfo
@@ -21,6 +21,10 @@ from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
+from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
+from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
+from onyx.context.search.models import SearchRequest
 from onyx.llm.interfaces import LLM
 from onyx.tools.force import ForceUseTool
 from onyx.tools.models import ToolCallFinalResult
@@ -39,15 +43,28 @@ def answer_instance(
     mock_llm: LLM,
     answer_style_config: AnswerStyleConfig,
     prompt_config: PromptConfig,
-    agent_search_config: AgentSearchConfig,
 ) -> Answer:
     return Answer(
-        question=QUERY,
+        prompt_builder=AnswerPromptBuilder(
+            user_message=default_build_user_message(
+                user_query=QUERY,
+                prompt_config=prompt_config,
+                files=[],
+                single_message_history=None,
+            ),
+            system_message=default_build_system_message(prompt_config),
+            message_history=[],
+            llm_config=mock_llm.config,
+            raw_user_query=QUERY,
+            raw_user_uploaded_files=[],
+        ),
         answer_style_config=answer_style_config,
         llm=mock_llm,
-        prompt_config=prompt_config,
+        fast_llm=mock_llm,
         force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
-        agent_search_config=agent_search_config,
+        search_request=SearchRequest(query=QUERY),
+        chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
+        current_agent_message_id=0,
     )
@@ -57,6 +74,8 @@ def test_basic_answer(answer_instance: Answer) -> None:
         AIMessageChunk(content="This is a "),
         AIMessageChunk(content="mock answer."),
     ]
+    answer_instance.agent_search_config.fast_llm = mock_llm
+    answer_instance.agent_search_config.primary_llm = mock_llm

     output = list(answer_instance.processed_streamed_output)
     assert len(output) == 2
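
Note: because `Answer` now builds `AgentSearchConfig` internally, tests no longer receive the config as a fixture; they reach into the constructed object instead, as the hunk above does. A hypothetical test sketch under that assumption, using the fixtures from this file:

    def test_llms_are_swappable(answer_instance: Answer, mock_llm: LLM) -> None:
        # The config is created in Answer.__init__, so mocks are installed on
        # the resulting attribute rather than passed to the constructor.
        answer_instance.agent_search_config.primary_llm = mock_llm
        answer_instance.agent_search_config.fast_llm = mock_llm
        assert answer_instance.agent_search_config.primary_llm is mock_llm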

View File

@@ -1,13 +1,16 @@
 from typing import Any
 from unittest.mock import Mock
+from uuid import UUID

 import pytest
+from langchain_core.messages import HumanMessage
 from pytest_mock import MockerFixture

-from onyx.agents.agent_search.models import AgentSearchConfig
 from onyx.chat.answer import Answer
 from onyx.chat.models import AnswerStyleConfig
 from onyx.chat.models import PromptConfig
+from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
+from onyx.context.search.models import SearchRequest
 from onyx.tools.force import ForceUseTool
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from tests.regression.answer_quality.run_qa import _process_and_write_query_results
@@ -30,7 +33,6 @@ def test_skip_gen_ai_answer_generation_flag(
     config: dict[str, Any],
     mock_search_tool: SearchTool,
     answer_style_config: AnswerStyleConfig,
-    agent_search_config: AgentSearchConfig,
     prompt_config: PromptConfig,
 ) -> None:
     question = config["question"]
@@ -42,30 +44,28 @@ def test_skip_gen_ai_answer_generation_flag(
     mock_llm.stream = Mock()
     mock_llm.stream.return_value = [Mock()]

-    agent_search_config.primary_llm = mock_llm
-    agent_search_config.fast_llm = mock_llm
-    agent_search_config.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
-    agent_search_config.search_tool = mock_search_tool
-    agent_search_config.using_tool_calling_llm = False
-    agent_search_config.tools = [mock_search_tool]
     answer = Answer(
-        question=question,
         answer_style_config=answer_style_config,
-        prompt_config=prompt_config,
         llm=mock_llm,
-        single_message_history="history",
+        fast_llm=mock_llm,
         tools=[mock_search_tool],
-        force_use_tool=(
-            ForceUseTool(
+        force_use_tool=ForceUseTool(
             tool_name=mock_search_tool.name,
             args={"query": question},
             force_use=True,
-            )
         ),
         skip_explicit_tool_calling=True,
         skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
-        agent_search_config=agent_search_config,
+        search_request=SearchRequest(query=question),
+        prompt_builder=AnswerPromptBuilder(
+            user_message=HumanMessage(content=question),
+            message_history=[],
+            llm_config=mock_llm.config,
+            raw_user_query=question,
+            raw_user_uploaded_files=[],
+        ),
+        chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
+        current_agent_message_id=0,
     )
     results = list(answer.processed_streamed_output)
     for res in results: