first pass at dead code deletion

This commit is contained in:
Evan Lohn 2025-01-29 14:28:46 -08:00
parent 3d99ad7bc4
commit 6c7f8eaefb
12 changed files with 68 additions and 494 deletions

View File

@ -80,7 +80,7 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
agent_metrics=combined_agent_metrics,
)
if agent_a_config.use_persistence:
if agent_a_config.use_agentic_persistence:
# Persist the sub-answer in the database
db_session = agent_a_config.db_session
chat_session_id = agent_a_config.chat_session_id

View File

@ -1,4 +1,3 @@
from dataclasses import dataclass
from uuid import UUID
from pydantic import BaseModel
@ -15,8 +14,7 @@ from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
@dataclass
class AgentSearchConfig:
class AgentSearchConfig(BaseModel):
"""
Configuration for the Agent Search feature.
"""
@ -47,10 +45,10 @@ class AgentSearchConfig:
# The message ID of the user message that triggered the Pro Search
message_id: int | None = None
# Whether to persist data for the Pro Search (turned off for testing)
use_persistence: bool = True
# Whether to persist data for Agentic Search (turned off for testing)
use_agentic_persistence: bool = True
# The database session for the Pro Search
# The database session for Agentic Search
db_session: Session | None = None
# Whether to perform initial search to inform decomposition
@ -75,7 +73,7 @@ class AgentSearchConfig:
@model_validator(mode="after")
def validate_db_session(self) -> "AgentSearchConfig":
if self.use_persistence and self.db_session is None:
if self.use_agentic_persistence and self.db_session is None:
raise ValueError(
"db_session must be provided for pro search when using persistence"
)
@ -87,6 +85,9 @@ class AgentSearchConfig:
raise ValueError("search_tool must be provided for agentic search")
return self
class Config:
arbitrary_types_allowed = True
class AgentDocumentCitations(BaseModel):
document_id: str
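
A minimal sketch of the model change above, assuming only the fields visible in this diff: AgentSearchConfig moves from a @dataclass to a pydantic BaseModel, the use_persistence flag is renamed to use_agentic_persistence, and arbitrary_types_allowed is needed because fields such as the SQLAlchemy Session are not pydantic types. Field and class names ending in "Sketch" are illustrative, not the real code.

from pydantic import BaseModel, model_validator
from sqlalchemy.orm import Session


class AgentSearchConfigSketch(BaseModel):
    # Whether to persist data for Agentic Search (turned off for testing)
    use_agentic_persistence: bool = True
    # The database session for Agentic Search
    db_session: Session | None = None

    @model_validator(mode="after")
    def validate_db_session(self) -> "AgentSearchConfigSketch":
        if self.use_agentic_persistence and self.db_session is None:
            raise ValueError("db_session must be provided when using persistence")
        return self

    class Config:
        # Session (and the other runtime objects) are not pydantic types
        arbitrary_types_allowed = True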

View File

@ -38,7 +38,6 @@ def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
# TODO: custom events for yields
emit_packet(tool_kickoff)
tool_responses = []

View File

@ -155,26 +155,19 @@ def run_graph(
# TODO: call this once on startup, TBD where and if it should be gated based
# on dev mode or not
def load_compiled_graph(graph_name: str) -> CompiledStateGraph:
main_graph_builder = (
main_graph_builder_a if graph_name == "a" else main_graph_builder_a
)
def load_compiled_graph() -> CompiledStateGraph:
global _COMPILED_GRAPH
if _COMPILED_GRAPH is None:
graph = main_graph_builder()
graph = main_graph_builder_a()
_COMPILED_GRAPH = graph.compile()
return _COMPILED_GRAPH
def run_main_graph(
config: AgentSearchConfig,
graph_name: str = "a",
) -> AnswerStream:
compiled_graph = load_compiled_graph(graph_name)
if graph_name == "a":
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
else:
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
compiled_graph = load_compiled_graph()
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
# Agent search is not a Tool per se, but this is helpful for the frontend
yield ToolCallKickoff(
@ -202,11 +195,36 @@ if __name__ == "__main__":
now_start = datetime.now()
logger.debug(f"Start at {now_start}")
if GRAPH_VERSION_NAME == "a":
graph = main_graph_builder_a()
else:
graph = main_graph_builder_a()
compiled_graph = graph.compile()
graph = main_graph_builder_a()
now_start = datetime.now()
compiled_graph = graph.compile()
now_end = datetime.now()
logger.debug(f"Graph compiled in {now_end - now_start} seconds")
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
# query="what can you do with gitlab?",
# query="What are the guiding principles behind the development of cockroachDB",
# query="What are the temperatures in Munich, Hawaii, and New York?",
# query="When was Washington born?",
# query="What is Onyx?",
# query="What is the difference between astronomy and astrology?",
query="Do a search to tell me what is the difference between astronomy and astrology?",
)
# Joachim custom persona
with get_session_context_manager() as db_session:
config, search_tool = get_test_config(
db_session, primary_llm, fast_llm, search_request
)
# search_request.persona = get_persona_by_id(1, None, db_session)
config.use_agentic_persistence = True
# config.perform_initial_search_path_decision = False
config.perform_initial_search_decomposition = True
input = MainInput_a(
base_question=config.search_request.query, log_messages=[]
)
now_end = datetime.now()
logger.debug(f"Graph compiled in {now_end - now_start} seconds")
primary_llm, fast_llm = get_default_llms()
@ -226,7 +244,7 @@ if __name__ == "__main__":
db_session, primary_llm, fast_llm, search_request
)
# search_request.persona = get_persona_by_id(1, None, db_session)
config.use_persistence = True
config.use_agentic_persistence = True
# config.perform_initial_search_path_decision = False
config.perform_initial_search_decomposition = True
if GRAPH_VERSION_NAME == "a":
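
load_compiled_graph above drops its graph_name parameter because both branches of the deleted conditional built the same graph; what remains is a lazily compiled, process-wide singleton. A self-contained sketch of that caching pattern, with build_graph standing in for main_graph_builder_a().compile():

from typing import Any

_COMPILED_GRAPH = None  # module-level cache, compiled at most once per process


def build_graph() -> Any:
    # Placeholder for main_graph_builder_a().compile(); assumed to be expensive.
    return object()


def load_compiled_graph() -> Any:
    global _COMPILED_GRAPH
    if _COMPILED_GRAPH is None:
        _COMPILED_GRAPH = build_graph()
    return _COMPILED_GRAPH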

View File

@ -248,7 +248,7 @@ def get_test_config(
chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim
# chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan
message_id=1,
use_persistence=True,
use_agentic_persistence=True,
db_session=db_session,
tools=[search_tool],
use_agentic_search=use_agentic_search,

View File

@ -54,6 +54,7 @@ class Answer:
is_connected: Callable[[], bool] | None = None,
db_session: Session | None = None,
use_agentic_search: bool = False,
use_agentic_persistence: bool = True,
) -> None:
self.is_connected: Callable[[], bool] | None = is_connected
@ -90,14 +91,9 @@ class Answer:
if len(search_tools) > 1:
# TODO: handle multiple search tools
logger.warning("Multiple search tools found, using first one")
search_tool = search_tools[0]
raise ValueError("Multiple search tools found")
elif len(search_tools) == 1:
search_tool = search_tools[0]
else:
logger.warning("No search tool found")
if use_agentic_search:
raise ValueError("No search tool found, cannot use agentic search")
using_tool_calling_llm = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
@ -111,7 +107,7 @@ class Answer:
use_agentic_search=use_agentic_search,
chat_session_id=chat_session_id,
message_id=current_agent_message_id,
use_persistence=True,
use_agentic_persistence=use_agentic_persistence,
allow_refinement=True,
db_session=db_session,
prompt_builder=prompt_builder,
@ -137,104 +133,20 @@ class Answer:
logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
return [tool]
# TODO: delete the function and move the full body to processed_streamed_output
def _get_response(self) -> AnswerStream:
# current_llm_call = llm_calls[-1]
@property
def processed_streamed_output(self) -> AnswerStream:
if self._processed_stream is not None:
yield from self._processed_stream
return
# tool, tool_args = None, None
# # handle the case where no decision has to be made; we simply run the tool
# if (
# current_llm_call.force_use_tool.force_use
# and current_llm_call.force_use_tool.args is not None
# ):
# tool_name, tool_args = (
# current_llm_call.force_use_tool.tool_name,
# current_llm_call.force_use_tool.args,
# )
# tool = get_tool_by_name(current_llm_call.tools, tool_name)
# # special pre-logic for non-tool calling LLM case
# elif not self.using_tool_calling_llm and current_llm_call.tools:
# chosen_tool_and_args = (
# ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
# current_llm_call, self.llm
# )
# )
# if chosen_tool_and_args:
# tool, tool_args = chosen_tool_and_args
# if tool and tool_args:
# dummy_tool_call_chunk = AIMessageChunk(content="")
# dummy_tool_call_chunk.tool_calls = [
# ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
# ]
# response_handler_manager = LLMResponseHandlerManager(
# ToolResponseHandler([tool]), None, self.is_cancelled
# )
# yield from response_handler_manager.handle_llm_response(
# iter([dummy_tool_call_chunk])
# )
# tmp_call = response_handler_manager.next_llm_call(current_llm_call)
# if tmp_call is None:
# return # no more LLM calls to process
# current_llm_call = tmp_call
# # if we're skipping gen ai answer generation, we should break
# # out unless we're forcing a tool call. If we don't, we might generate an
# # answer, which is a no-no!
# if (
# self.skip_gen_ai_answer_generation
# and not current_llm_call.force_use_tool.force_use
# ):
# return
# # set up "handlers" to listen to the LLM response stream and
# # feed back the processed results + handle tool call requests
# # + figure out what the next LLM call should be
# tool_call_handler = ToolResponseHandler(current_llm_call.tools)
# final_search_results, displayed_search_results = SearchTool.get_search_result(
# current_llm_call
# ) or ([], [])
# # NEXT: we still want to handle the LLM response stream, but it is now:
# # 1. handle the tool call requests
# # 2. feed back the processed results
# # 3. handle the citations
# answer_handler = CitationResponseHandler(
# context_docs=final_search_results,
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
# )
# # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
# # In the future, we'll want to handle citations and tool calls in the langgraph graph.
# response_handler_manager = LLMResponseHandlerManager(
# tool_call_handler, answer_handler, self.is_cancelled
# )
# In langgraph, whether we do the basic thing (call llm stream) or pro search
# is based on a flag in the pro search config
if self.agent_search_config.use_agentic_search:
if (
self.agent_search_config.db_session is None
and self.agent_search_config.use_persistence
):
raise ValueError(
"db_session must be provided for pro search when using persistence"
)
stream = run_main_graph(
config=self.agent_search_config,
)
else:
stream = run_basic_graph(
config=self.agent_search_config,
)
run_langgraph = (
run_main_graph
if self.agent_search_config.use_agentic_search
else run_basic_graph
)
stream = run_langgraph(
self.agent_search_config,
)
processed_stream = []
for packet in stream:
@ -244,62 +156,6 @@ class Answer:
break
processed_stream.append(packet)
yield packet
self._processed_stream = processed_stream
return
# DEBUG: good breakpoint
# stream = self.llm.stream(
# # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
# prompt=current_llm_call.prompt_builder.build(),
# tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
# tool_choice=(
# "required"
# if current_llm_call.tools and current_llm_call.force_use_tool.force_use
# else None
# ),
# structured_response_format=self.answer_style_config.structured_response_format,
# )
# yield from response_handler_manager.handle_llm_response(stream)
# new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
# if new_llm_call:
# yield from self._get_response(llm_calls + [new_llm_call])
@property
def processed_streamed_output(self) -> AnswerStream:
if self._processed_stream is not None:
yield from self._processed_stream
return
# prompt_builder = AnswerPromptBuilder(
# user_message=default_build_user_message(
# user_query=self.question,
# prompt_config=self.prompt_config,
# files=self.latest_query_files,
# single_message_history=self.single_message_history,
# ),
# message_history=self.message_history,
# llm_config=self.llm.config,
# raw_user_query=self.question,
# raw_user_uploaded_files=self.latest_query_files or [],
# single_message_history=self.single_message_history,
# )
# prompt_builder.update_system_prompt(
# default_build_system_message(self.prompt_config)
# )
# llm_call = LLMCall(
# prompt_builder=prompt_builder,
# tools=self._get_tools_list(),
# force_use_tool=self.force_use_tool,
# files=self.latest_query_files,
# tool_call_info=[],
# using_tool_calling_llm=self.using_tool_calling_llm,
# )
processed_stream = []
for processed_packet in self._get_response():
processed_stream.append(processed_packet)
yield processed_packet
self._processed_stream = processed_stream
@ -343,6 +199,7 @@ class Answer:
return citations
# TODO: replace tuple of ints with SubQuestionId EVERYWHERE
def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]:
citations_by_subquestion: dict[
tuple[int, int], list[CitationInfo]
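
With _get_response folded away, processed_streamed_output in this file reduces to: replay the cached stream if one exists, otherwise pick run_main_graph or run_basic_graph from the use_agentic_search flag and cache whatever it yields. A self-contained sketch of that shape (the real method is a @property, takes the whole AgentSearchConfig, and also checks for client disconnects mid-stream):

from collections.abc import Iterator


def run_main_graph() -> Iterator[str]:
    # Stand-in for the agentic search graph stream.
    yield from ("agentic packet 1", "agentic packet 2")


def run_basic_graph() -> Iterator[str]:
    # Stand-in for the basic LLM-call graph stream.
    yield from ("basic packet",)


class AnswerSketch:
    def __init__(self, use_agentic_search: bool) -> None:
        self.use_agentic_search = use_agentic_search
        self._processed_stream: list[str] | None = None

    def processed_streamed_output(self) -> Iterator[str]:
        # Replay the cached stream on repeat calls.
        if self._processed_stream is not None:
            yield from self._processed_stream
            return
        # One flag decides whether the basic or the agentic graph runs.
        run_langgraph = run_main_graph if self.use_agentic_search else run_basic_graph
        processed: list[str] = []
        for packet in run_langgraph():
            processed.append(packet)
            yield packet
        self._processed_stream = processed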

View File

@ -1,7 +1,5 @@
import abc
from collections.abc import Generator
from typing import Any
from typing import cast
from langchain_core.messages import BaseMessage
@ -26,10 +24,6 @@ class AnswerResponseHandler(abc.ABC):
) -> Generator[ResponsePart, None, None]:
raise NotImplementedError
@abc.abstractmethod
def update(self, state_update: Any) -> None:
raise NotImplementedError
class PassThroughAnswerResponseHandler(AnswerResponseHandler):
def handle_response_part(
@ -40,9 +34,6 @@ class PassThroughAnswerResponseHandler(AnswerResponseHandler):
content = _message_to_str(response_item)
yield OnyxAnswerPiece(answer_piece=content)
def update(self, state_update: Any) -> None:
pass
class DummyAnswerResponseHandler(AnswerResponseHandler):
def handle_response_part(
@ -53,9 +44,6 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):
# This is a dummy handler that returns nothing
yield from []
def update(self, state_update: Any) -> None:
pass
class CitationResponseHandler(AnswerResponseHandler):
def __init__(
@ -91,20 +79,6 @@ class CitationResponseHandler(AnswerResponseHandler):
# Process the new content through the citation processor
yield from self.citation_processor.process_token(content)
def update(self, state_update: Any) -> None:
state = cast(
tuple[list[LlmDoc], DocumentIdOrderMapping, DocumentIdOrderMapping],
state_update,
)
self.context_docs = state[0]
self.final_doc_id_to_rank_map = state[1]
self.display_doc_id_to_rank_map = state[2]
self.citation_processor = CitationProcessor(
context_docs=self.context_docs,
final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
)
def _message_to_str(message: BaseMessage | str | None) -> str:
if message is None:
@ -116,80 +90,3 @@ def _message_to_str(message: BaseMessage | str | None) -> str:
logger.warning(f"Received non-string content: {type(content)}")
content = str(content) if content is not None else ""
return content
# class CitationMultiResponseHandler(AnswerResponseHandler):
# def __init__(self) -> None:
# self.channel_processors: dict[str, CitationProcessor] = {}
# self._default_channel = "__default__"
# def register_default_channel(
# self,
# context_docs: list[LlmDoc],
# final_doc_id_to_rank_map: DocumentIdOrderMapping,
# display_doc_id_to_rank_map: DocumentIdOrderMapping,
# ) -> None:
# """Register the default channel with its associated documents and ranking maps."""
# self.register_channel(
# channel_id=self._default_channel,
# context_docs=context_docs,
# final_doc_id_to_rank_map=final_doc_id_to_rank_map,
# display_doc_id_to_rank_map=display_doc_id_to_rank_map,
# )
# def register_channel(
# self,
# channel_id: str,
# context_docs: list[LlmDoc],
# final_doc_id_to_rank_map: DocumentIdOrderMapping,
# display_doc_id_to_rank_map: DocumentIdOrderMapping,
# ) -> None:
# """Register a new channel with its associated documents and ranking maps."""
# self.channel_processors[channel_id] = CitationProcessor(
# context_docs=context_docs,
# final_doc_id_to_rank_map=final_doc_id_to_rank_map,
# display_doc_id_to_rank_map=display_doc_id_to_rank_map,
# )
# def handle_response_part(
# self,
# response_item: BaseMessage | str | None,
# previous_response_items: list[BaseMessage | str],
# ) -> Generator[ResponsePart, None, None]:
# """Default implementation that uses the default channel."""
# yield from self.handle_channel_response(
# response_item=content,
# previous_response_items=previous_response_items,
# channel_id=self._default_channel,
# )
# def handle_channel_response(
# self,
# response_item: ResponsePart | str | None,
# previous_response_items: list[ResponsePart | str],
# channel_id: str,
# ) -> Generator[ResponsePart, None, None]:
# """Process a response part for a specific channel."""
# if channel_id not in self.channel_processors:
# raise ValueError(f"Attempted to process response for unregistered channel {channel_id}")
# if response_item is None:
# return
# content = (
# response_item.content if isinstance(response_item, BaseMessage) else response_item
# )
# # Ensure content is a string
# if not isinstance(content, str):
# logger.warning(f"Received non-string content: {type(content)}")
# content = str(content) if content is not None else ""
# # Process the new content through the channel's citation processor
# yield from self.channel_processors[channel_id].multi_process_token(content)
# def remove_channel(self, channel_id: str) -> None:
# """Remove a channel and its associated processor."""
# if channel_id in self.channel_processors:
# del self.channel_processors[channel_id]
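
With update() removed, the handler hierarchy in this file keeps a single abstract method; handlers only transform streamed response parts and carry no externally mutated state. A minimal sketch of the slimmed-down interface, with ResponsePart narrowed to str so the example stays self-contained:

import abc
from collections.abc import Generator


class AnswerResponseHandlerSketch(abc.ABC):
    @abc.abstractmethod
    def handle_response_part(
        self,
        response_item: str | None,
        previous_response_items: list[str],
    ) -> Generator[str, None, None]:
        raise NotImplementedError


class PassThroughHandlerSketch(AnswerResponseHandlerSketch):
    def handle_response_part(
        self,
        response_item: str | None,
        previous_response_items: list[str],
    ) -> Generator[str, None, None]:
        # Forward the content unchanged; there is no update() hook any more.
        yield response_item or ""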

View File

@ -4,7 +4,6 @@ from collections.abc import Generator
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import ResponsePart
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
from onyx.configs.chat_configs import STOP_STREAM_PAT
from onyx.prompts.constants import TRIPLE_BACKTICK
@ -41,164 +40,6 @@ class CitationProcessor:
self.current_citations: list[int] = []
self.past_cite_count = 0
# TODO: should reference previous citation processing, rework previous, or completely use new one?
def multi_process_token(
self, parsed_object: ResponsePart
) -> Generator[ResponsePart, None, None]:
# if isinstance(parsed_object,OnyxAnswerPiece):
# # standard citation processing
# yield from self.process_token(parsed_object.answer_piece)
# elif isinstance(parsed_object, AgentAnswerPiece):
# # citation processing for agent answer pieces
# for token in self.process_token(parsed_object.answer_piece):
# if isinstance(token, CitationInfo):
# yield token
# else:
# yield AgentAnswerPiece(answer_piece=token.answer_piece or '',
# answer_type=parsed_object.answer_type, level=parsed_object.level,
# level_question_nr=parsed_object.level_question_nr)
# level = getattr(parsed_object, "level", None)
# level_question_nr = getattr(parsed_object, "level_question_nr", None)
# if isinstance(parsed_object, (AgentAnswerPiece, OnyxAnswerPiece)):
# # logger.debug(f"FA {parsed_object.answer_piece}")
# if isinstance(parsed_object, AgentAnswerPiece):
# token = parsed_object.answer_piece
# level = parsed_object.level
# level_question_nr = parsed_object.level_question_nr
# else:
# yield parsed_object
# return
# # raise ValueError(
# # f"Invalid parsed object type: {type(parsed_object)}"
# # )
# if not citation_potential[level][level_question_nr] and token:
# if token.startswith(" ["):
# citation_potential[level][level_question_nr] = True
# current_yield_components[level][level_question_nr] = [token]
# else:
# yield parsed_object
# elif token and citation_potential[level][level_question_nr]:
# current_yield_components[level][level_question_nr].append(token)
# current_yield_str[level][level_question_nr] = "".join(
# current_yield_components[level][level_question_nr]
# )
# if current_yield_str[level][level_question_nr].strip().startswith(
# "[D"
# ) or current_yield_str[level][level_question_nr].strip().startswith(
# "[Q"
# ):
# citation_potential[level][level_question_nr] = True
# else:
# citation_potential[level][level_question_nr] = False
# parsed_object = _set_combined_token_value(
# current_yield_str[level][level_question_nr], parsed_object
# )
# yield parsed_object
# if (
# len(current_yield_components[level][level_question_nr]) > 15
# ): # ??? 15?
# citation_potential[level][level_question_nr] = False
# parsed_object = _set_combined_token_value(
# current_yield_str[level][level_question_nr], parsed_object
# )
# yield parsed_object
# elif "]" in current_yield_str[level][level_question_nr]:
# section_split = current_yield_str[level][level_question_nr].split(
# "]"
# )
# section_split[0] + "]" # dead code?
# start_of_next_section = "]".join(section_split[1:])
# citation_string = current_yield_str[level][level_question_nr][
# : -len(start_of_next_section)
# ]
# if "[D" in citation_string:
# cite_open_bracket_marker, cite_close_bracket_marker = (
# "[",
# "]",
# )
# cite_identifyer = "D"
# try:
# cited_document = int(
# citation_string[level][level_question_nr][2:-1]
# )
# if level and level_question_nr:
# link = agent_document_citations[int(level)][
# int(level_question_nr)
# ][cited_document].link
# else:
# link = ""
# except (ValueError, IndexError):
# link = ""
# elif "[Q" in citation_string:
# cite_open_bracket_marker, cite_close_bracket_marker = (
# "{",
# "}",
# )
# cite_identifyer = "Q"
# else:
# pass
# citation_string = citation_string.replace(
# "[" + cite_identifyer,
# cite_open_bracket_marker * 2,
# ).replace("]", cite_close_bracket_marker * 2)
# if cite_identifyer == "D":
# citation_string += f"({link})"
# parsed_object = _set_combined_token_value(
# citation_string, parsed_object
# )
# yield parsed_object
# current_yield_components[level][level_question_nr] = [
# start_of_next_section
# ]
# if not start_of_next_section.strip().startswith("["):
# citation_potential[level][level_question_nr] = False
# elif isinstance(parsed_object, ExtendedToolResponse):
# if parsed_object.id == "search_response_summary":
# level = parsed_object.level
# level_question_nr = parsed_object.level_question_nr
# for inference_section in parsed_object.response.top_sections:
# doc_link = inference_section.center_chunk.source_links[0]
# doc_title = inference_section.center_chunk.title
# doc_id = inference_section.center_chunk.document_id
# if (
# doc_id
# not in agent_question_citations_used_docs[level][
# level_question_nr
# ]
# ):
# if level not in agent_document_citations:
# agent_document_citations[level] = {}
# if level_question_nr not in agent_document_citations[level]:
# agent_document_citations[level][level_question_nr] = []
# agent_document_citations[level][level_question_nr].append(
# AgentDocumentCitations(
# document_id=doc_id,
# document_title=doc_title,
# link=doc_link,
# )
# )
# agent_question_citations_used_docs[level][
# level_question_nr
# ].append(doc_id)
yield parsed_object
def process_token(
self, token: str | None
) -> Generator[OnyxAnswerPiece | CitationInfo, None, None]:

View File

@ -41,6 +41,7 @@ DEFAULT_CC_PAIR_ID = 1
# subquestion level and question number for basic flow
BASIC_KEY = (-1, -1)
AGENT_SEARCH_INITIAL_KEY = (0, 0)
CANCEL_CHECK_INTERVAL = 20
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
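
CANCEL_CHECK_INTERVAL = 20 is added here without its call site appearing in this diff; assuming it throttles how often a streaming loop polls an is_connected()-style callback (the Answer class above accepts such a callback), a sketch of that usage might look like:

from collections.abc import Iterable, Iterator
from typing import Callable, TypeVar

CANCEL_CHECK_INTERVAL = 20  # packets between connection checks, as in the diff

T = TypeVar("T")


def stream_until_disconnected(
    packets: Iterable[T], is_connected: Callable[[], bool]
) -> Iterator[T]:
    # Poll is_connected only every CANCEL_CHECK_INTERVAL packets instead of
    # on every token, and stop streaming once the client has gone away.
    for i, packet in enumerate(packets):
        if i % CANCEL_CHECK_INTERVAL == 0 and not is_connected():
            break
        yield packet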

View File

@ -5,7 +5,6 @@ from unittest.mock import MagicMock
import pytest
from langchain_core.messages import SystemMessage
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
@ -14,22 +13,17 @@ from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.utils import explicit_tool_calling_supported
QUERY = "Test question"
DEFAULT_SEARCH_ARGS = {"query": "search"}
@ -40,43 +34,6 @@ def answer_style_config() -> AnswerStyleConfig:
return AnswerStyleConfig(citation_config=CitationConfig())
@pytest.fixture
def agent_search_config(
mock_llm: LLM, mock_search_tool: SearchTool, prompt_config: PromptConfig
) -> AgentSearchConfig:
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=QUERY,
prompt_config=prompt_config,
files=[],
single_message_history=None,
),
message_history=[],
llm_config=mock_llm.config,
raw_user_query=QUERY,
raw_user_uploaded_files=[],
single_message_history=None,
)
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
using_tool_calling_llm = explicit_tool_calling_supported(
mock_llm.config.model_provider, mock_llm.config.model_name
)
return AgentSearchConfig(
search_request=SearchRequest(query=QUERY),
primary_llm=mock_llm,
fast_llm=mock_llm,
search_tool=mock_search_tool,
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
prompt_builder=prompt_builder,
chat_session_id=None,
message_id=1,
use_persistence=True,
db_session=None,
use_agentic_search=False,
using_tool_calling_llm=using_tool_calling_llm,
)
@pytest.fixture
def prompt_config() -> PromptConfig:
return PromptConfig(
@ -89,7 +46,7 @@ def prompt_config() -> PromptConfig:
@pytest.fixture
def mock_llm() -> MagicMock:
mock_llm_obj = MagicMock()
mock_llm_obj = MagicMock(spec=LLM)
mock_llm_obj.config = LLMConfig(
model_provider="openai",
model_name="gpt-4o",
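
Switching the fixture from MagicMock() to MagicMock(spec=LLM) makes the tests stricter: a spec'd mock only exposes attributes that exist on the real LLM interface and passes isinstance checks against it. A small illustration with a stand-in class (FakeLLM is hypothetical, not the onyx interface):

from unittest.mock import MagicMock


class FakeLLM:
    def stream(self, prompt: str) -> list[str]:
        return [prompt]


mock = MagicMock(spec=FakeLLM)
assert isinstance(mock, FakeLLM)  # spec'd mocks satisfy isinstance checks
mock.stream("hi")                 # allowed: stream exists on the spec
try:
    mock.no_such_method()         # AttributeError: not on the spec
except AttributeError:
    pass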

View File

@ -65,6 +65,7 @@ def answer_instance(
search_request=SearchRequest(query=QUERY),
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
current_agent_message_id=0,
use_agentic_persistence=False,
)

View File

@ -11,6 +11,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import SearchRequest
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from tests.regression.answer_quality.run_qa import _process_and_write_query_results
@ -38,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag(
question = config["question"]
skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]
mock_llm = Mock()
mock_llm = Mock(spec=LLM)
mock_llm.config = Mock()
mock_llm.config.model_name = "gpt-4o-mini"
mock_llm.stream = Mock()
@ -66,6 +67,7 @@ def test_skip_gen_ai_answer_generation_flag(
),
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
current_agent_message_id=0,
use_agentic_persistence=False,
)
results = list(answer.processed_streamed_output)
for res in results: