reworked config to have logical structure

This commit is contained in:
Evan Lohn
2025-01-31 15:37:47 -08:00
parent 8342168658
commit 118e8afbef
33 changed files with 296 additions and 426 deletions

View File

@@ -4,7 +4,11 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.models import GraphInputs
from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
@@ -16,12 +20,10 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@@ -57,35 +59,9 @@ class Answer:
use_agentic_persistence: bool = True,
) -> None:
self.is_connected: Callable[[], bool] | None = is_connected
self.latest_query_files = latest_query_files or []
self.tools = tools or []
self.force_use_tool = force_use_tool
# used for QA flow where we only want to send a single message
self.answer_style_config = answer_style_config
self.llm = llm
self.fast_llm = fast_llm
self.llm_tokenizer = get_tokenizer(
provider_type=llm.config.model_provider,
model_name=llm.config.model_name,
)
self._streamed_output: list[str] | None = None
self._processed_stream: (list[AnswerPacket] | None) = None
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False
self.using_tool_calling_llm = (
explicit_tool_calling_supported(
self.llm.config.model_provider, self.llm.config.model_name
)
and not skip_explicit_tool_calling
)
search_tools = [tool for tool in (tools or []) if isinstance(tool, SearchTool)]
search_tool: SearchTool | None = None
@@ -95,43 +71,46 @@ class Answer:
elif len(search_tools) == 1:
search_tool = search_tools[0]
using_tool_calling_llm = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
using_tool_calling_llm = (
explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
and not skip_explicit_tool_calling
)
self.agent_search_config = AgentSearchConfig(
self.graph_inputs = GraphInputs(
search_request=search_request,
prompt_builder=prompt_builder,
files=latest_query_files,
structured_response_format=answer_style_config.structured_response_format,
)
self.graph_tooling = GraphTooling(
primary_llm=llm,
fast_llm=fast_llm,
search_tool=search_tool,
tools=tools or [],
force_use_tool=force_use_tool,
use_agentic_search=use_agentic_search,
chat_session_id=chat_session_id,
message_id=current_agent_message_id,
use_agentic_persistence=use_agentic_persistence,
allow_refinement=True,
db_session=db_session,
prompt_builder=prompt_builder,
tools=tools,
using_tool_calling_llm=using_tool_calling_llm,
files=latest_query_files,
structured_response_format=answer_style_config.structured_response_format,
)
self.graph_persistence = None
if use_agentic_persistence:
assert db_session, "db_session must be provided for agentic persistence"
self.graph_persistence = GraphPersistence(
db_session=db_session,
chat_session_id=chat_session_id,
message_id=current_agent_message_id,
)
self.search_behavior_config = GraphSearchConfig(
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
allow_refinement=True,
)
self.db_session = db_session
def _get_tools_list(self) -> list[Tool]:
if not self.force_use_tool.force_use:
return self.tools
tool = get_tool_by_name(self.tools, self.force_use_tool.tool_name)
args_str = (
f" with args='{self.force_use_tool.args}'"
if self.force_use_tool.args
else ""
self.graph_config = GraphConfig(
inputs=self.graph_inputs,
tooling=self.graph_tooling,
persistence=self.graph_persistence,
behavior=self.search_behavior_config,
)
logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
return [tool]
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -141,11 +120,11 @@ class Answer:
run_langgraph = (
run_main_graph
if self.agent_search_config.use_agentic_search
if self.graph_config.behavior.use_agentic_search
else run_basic_graph
)
stream = run_langgraph(
self.agent_search_config,
self.graph_config,
)
processed_stream = []