mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
AgentPromptConfig in Answer class
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
@@ -13,17 +13,18 @@ from onyx.chat.models import AnswerStream
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -33,15 +34,14 @@ logger = setup_logger()
|
||||
class Answer:
|
||||
def __init__(
|
||||
self,
|
||||
question: str,
|
||||
prompt_builder: AnswerPromptBuilder,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
llm: LLM,
|
||||
prompt_config: PromptConfig,
|
||||
fast_llm: LLM,
|
||||
force_use_tool: ForceUseTool,
|
||||
agent_search_config: AgentSearchConfig,
|
||||
# must be the same length as `docs`. If None, all docs are considered "relevant"
|
||||
message_history: list[PreviousMessage] | None = None,
|
||||
single_message_history: str | None = None,
|
||||
search_request: SearchRequest,
|
||||
chat_session_id: UUID,
|
||||
current_agent_message_id: int,
|
||||
# newly passed in files to include as part of this question
|
||||
# TODO THIS NEEDS TO BE HANDLED
|
||||
latest_query_files: list[InMemoryChatFile] | None = None,
|
||||
@@ -52,28 +52,18 @@ class Answer:
|
||||
skip_explicit_tool_calling: bool = False,
|
||||
skip_gen_ai_answer_generation: bool = False,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
fast_llm: LLM | None = None,
|
||||
db_session: Session | None = None,
|
||||
use_agentic_search: bool = False,
|
||||
) -> None:
|
||||
if single_message_history and message_history:
|
||||
raise ValueError(
|
||||
"Cannot provide both `message_history` and `single_message_history`"
|
||||
)
|
||||
|
||||
self.question = question
|
||||
self.is_connected: Callable[[], bool] | None = is_connected
|
||||
|
||||
self.latest_query_files = latest_query_files or []
|
||||
|
||||
self.tools = tools or []
|
||||
self.force_use_tool = force_use_tool
|
||||
|
||||
self.message_history = message_history or []
|
||||
# used for QA flow where we only want to send a single message
|
||||
self.single_message_history = single_message_history
|
||||
|
||||
self.answer_style_config = answer_style_config
|
||||
self.prompt_config = prompt_config
|
||||
|
||||
self.llm = llm
|
||||
self.fast_llm = fast_llm
|
||||
@@ -82,8 +72,6 @@ class Answer:
|
||||
model_name=llm.config.model_name,
|
||||
)
|
||||
|
||||
self._final_prompt: list[BaseMessage] | None = None
|
||||
|
||||
self._streamed_output: list[str] | None = None
|
||||
self._processed_stream: (list[AnswerPacket] | None) = None
|
||||
|
||||
@@ -97,6 +85,42 @@ class Answer:
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
|
||||
search_tool: SearchTool | None = None
|
||||
|
||||
if len(search_tools) > 1:
|
||||
# TODO: handle multiple search tools
|
||||
logger.warning("Multiple search tools found, using first one")
|
||||
search_tool = search_tools[0]
|
||||
elif len(search_tools) == 1:
|
||||
search_tool = search_tools[0]
|
||||
else:
|
||||
logger.warning("No search tool found")
|
||||
if use_agentic_search:
|
||||
raise ValueError("No search tool found, cannot use agentic search")
|
||||
|
||||
using_tool_calling_llm = explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
agent_search_config = AgentSearchConfig(
|
||||
search_request=search_request,
|
||||
primary_llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool=search_tool,
|
||||
force_use_tool=force_use_tool,
|
||||
use_agentic_search=use_agentic_search,
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=current_agent_message_id,
|
||||
use_persistence=True,
|
||||
allow_refinement=True,
|
||||
db_session=db_session,
|
||||
prompt_builder=prompt_builder,
|
||||
tools=tools,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
files=latest_query_files,
|
||||
structured_response_format=answer_style_config.structured_response_format,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
)
|
||||
self.agent_search_config = agent_search_config
|
||||
self.db_session = db_session
|
||||
|
||||
|
@@ -284,7 +284,7 @@ class AnswerStyleConfig(BaseModel):
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
"""Final representation of the Prompt configuration passed
|
||||
into the `Answer` object."""
|
||||
into the `PromptBuilder` object."""
|
||||
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
|
@@ -8,7 +8,6 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.chat.answer import Answer
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import create_temporary_persona
|
||||
@@ -134,7 +133,6 @@ from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from onyx.tools.tool_runner import ToolCallFinalResult
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
@@ -712,6 +710,7 @@ def stream_chat_message_objects(
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# TODO: unify message history with single message history
|
||||
message_history = [
|
||||
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
|
||||
]
|
||||
@@ -739,26 +738,7 @@ def stream_chat_message_objects(
|
||||
else None
|
||||
),
|
||||
)
|
||||
# TODO: Since we're deleting the current main path in Answer,
|
||||
# we should construct this unconditionally inside Answer instead
|
||||
# Leaving it here for the time being to avoid breaking changes
|
||||
search_tools = [tool for tool in tools if isinstance(tool, SearchTool)]
|
||||
search_tool: SearchTool | None = None
|
||||
|
||||
if len(search_tools) > 1:
|
||||
# TODO: handle multiple search tools
|
||||
logger.warning("Multiple search tools found, using first one")
|
||||
search_tool = search_tools[0]
|
||||
elif len(search_tools) == 1:
|
||||
search_tool = search_tools[0]
|
||||
else:
|
||||
logger.warning("No search tool found")
|
||||
if new_msg_req.use_agentic_search:
|
||||
raise ValueError("No search tool found, cannot use agentic search")
|
||||
|
||||
using_tool_calling_llm = explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
force_use_tool = _get_force_search_settings(new_msg_req, tools)
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
@@ -776,35 +756,12 @@ def stream_chat_message_objects(
|
||||
)
|
||||
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
|
||||
|
||||
agent_search_config = AgentSearchConfig(
|
||||
search_request=search_request,
|
||||
primary_llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool=search_tool,
|
||||
force_use_tool=force_use_tool,
|
||||
use_agentic_search=new_msg_req.use_agentic_search,
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=reserved_message_id,
|
||||
use_persistence=True,
|
||||
allow_refinement=True,
|
||||
db_session=db_session,
|
||||
prompt_builder=prompt_builder,
|
||||
tools=tools,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
files=latest_query_files,
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
# TODO: add previous messages, answer style config, tools, etc.
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
prompt_builder=prompt_builder,
|
||||
is_connected=is_connected,
|
||||
question=final_msg.message,
|
||||
latest_query_files=latest_query_files,
|
||||
answer_style_config=answer_style_config,
|
||||
prompt_config=prompt_config,
|
||||
llm=(
|
||||
llm
|
||||
or get_main_llm_from_tuple(
|
||||
@@ -818,12 +775,13 @@ def stream_chat_message_objects(
|
||||
)
|
||||
),
|
||||
fast_llm=fast_llm,
|
||||
message_history=message_history,
|
||||
tools=tools,
|
||||
force_use_tool=force_use_tool,
|
||||
single_message_history=single_message_history,
|
||||
agent_search_config=agent_search_config,
|
||||
search_request=search_request,
|
||||
chat_session_id=chat_session_id,
|
||||
current_agent_message_id=reserved_message_id,
|
||||
tools=tools,
|
||||
db_session=db_session,
|
||||
use_agentic_search=new_msg_req.use_agentic_search,
|
||||
)
|
||||
|
||||
# reference_db_search_docs = None
|
||||
|
@@ -2,6 +2,7 @@ import json
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import Mock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
@@ -11,7 +12,6 @@ from langchain_core.messages import SystemMessage
|
||||
from langchain_core.messages import ToolCall
|
||||
from langchain_core.messages import ToolCallChunk
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.chat.answer import Answer
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationInfo
|
||||
@@ -21,6 +21,10 @@ from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
@@ -39,15 +43,28 @@ def answer_instance(
|
||||
mock_llm: LLM,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
prompt_config: PromptConfig,
|
||||
agent_search_config: AgentSearchConfig,
|
||||
) -> Answer:
|
||||
return Answer(
|
||||
question=QUERY,
|
||||
prompt_builder=AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=QUERY,
|
||||
prompt_config=prompt_config,
|
||||
files=[],
|
||||
single_message_history=None,
|
||||
),
|
||||
system_message=default_build_system_message(prompt_config),
|
||||
message_history=[],
|
||||
llm_config=mock_llm.config,
|
||||
raw_user_query=QUERY,
|
||||
raw_user_uploaded_files=[],
|
||||
),
|
||||
answer_style_config=answer_style_config,
|
||||
llm=mock_llm,
|
||||
prompt_config=prompt_config,
|
||||
fast_llm=mock_llm,
|
||||
force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
|
||||
agent_search_config=agent_search_config,
|
||||
search_request=SearchRequest(query=QUERY),
|
||||
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
|
||||
current_agent_message_id=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -57,6 +74,8 @@ def test_basic_answer(answer_instance: Answer) -> None:
|
||||
AIMessageChunk(content="This is a "),
|
||||
AIMessageChunk(content="mock answer."),
|
||||
]
|
||||
answer_instance.agent_search_config.fast_llm = mock_llm
|
||||
answer_instance.agent_search_config.primary_llm = mock_llm
|
||||
|
||||
output = list(answer_instance.processed_streamed_output)
|
||||
assert len(output) == 2
|
||||
|
@@ -1,13 +1,16 @@
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.chat.answer import Answer
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from tests.regression.answer_quality.run_qa import _process_and_write_query_results
|
||||
@@ -30,7 +33,6 @@ def test_skip_gen_ai_answer_generation_flag(
|
||||
config: dict[str, Any],
|
||||
mock_search_tool: SearchTool,
|
||||
answer_style_config: AnswerStyleConfig,
|
||||
agent_search_config: AgentSearchConfig,
|
||||
prompt_config: PromptConfig,
|
||||
) -> None:
|
||||
question = config["question"]
|
||||
@@ -42,30 +44,28 @@ def test_skip_gen_ai_answer_generation_flag(
|
||||
mock_llm.stream = Mock()
|
||||
mock_llm.stream.return_value = [Mock()]
|
||||
|
||||
agent_search_config.primary_llm = mock_llm
|
||||
agent_search_config.fast_llm = mock_llm
|
||||
agent_search_config.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
agent_search_config.search_tool = mock_search_tool
|
||||
agent_search_config.using_tool_calling_llm = False
|
||||
agent_search_config.tools = [mock_search_tool]
|
||||
|
||||
answer = Answer(
|
||||
question=question,
|
||||
answer_style_config=answer_style_config,
|
||||
prompt_config=prompt_config,
|
||||
llm=mock_llm,
|
||||
single_message_history="history",
|
||||
fast_llm=mock_llm,
|
||||
tools=[mock_search_tool],
|
||||
force_use_tool=(
|
||||
ForceUseTool(
|
||||
force_use_tool=ForceUseTool(
|
||||
tool_name=mock_search_tool.name,
|
||||
args={"query": question},
|
||||
force_use=True,
|
||||
)
|
||||
),
|
||||
skip_explicit_tool_calling=True,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
agent_search_config=agent_search_config,
|
||||
search_request=SearchRequest(query=question),
|
||||
prompt_builder=AnswerPromptBuilder(
|
||||
user_message=HumanMessage(content=question),
|
||||
message_history=[],
|
||||
llm_config=mock_llm.config,
|
||||
raw_user_query=question,
|
||||
raw_user_uploaded_files=[],
|
||||
),
|
||||
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
|
||||
current_agent_message_id=0,
|
||||
)
|
||||
results = list(answer.processed_streamed_output)
|
||||
for res in results:
|
||||
|
Reference in New Issue
Block a user