first pass at dead code deletion

Evan Lohn
2025-01-29 14:28:46 -08:00
parent 3d99ad7bc4
commit 6c7f8eaefb
12 changed files with 68 additions and 494 deletions

View File

@@ -80,7 +80,7 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
        agent_metrics=combined_agent_metrics,
    )
-   if agent_a_config.use_persistence:
+   if agent_a_config.use_agentic_persistence:
        # Persist the sub-answer in the database
        db_session = agent_a_config.db_session
        chat_session_id = agent_a_config.chat_session_id

View File

@@ -1,4 +1,3 @@
-from dataclasses import dataclass
from uuid import UUID

from pydantic import BaseModel
@@ -15,8 +14,7 @@ from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool

-@dataclass
-class AgentSearchConfig:
+class AgentSearchConfig(BaseModel):
    """
    Configuration for the Agent Search feature.
    """
@@ -47,10 +45,10 @@ class AgentSearchConfig:
    # The message ID of the user message that triggered the Pro Search
    message_id: int | None = None
-   # Whether to persistence data for the Pro Search (turned off for testing)
-   use_persistence: bool = True
-   # The database session for the Pro Search
+   # Whether to persistence data for Agentic Search (turned off for testing)
+   use_agentic_persistence: bool = True
+   # The database session for Agentic Search
    db_session: Session | None = None
    # Whether to perform initial search to inform decomposition
@@ -75,7 +73,7 @@ class AgentSearchConfig:
    @model_validator(mode="after")
    def validate_db_session(self) -> "AgentSearchConfig":
-       if self.use_persistence and self.db_session is None:
+       if self.use_agentic_persistence and self.db_session is None:
            raise ValueError(
                "db_session must be provided for pro search when using persistence"
            )
@@ -87,6 +85,9 @@ class AgentSearchConfig:
            raise ValueError("search_tool must be provided for agentic search")
        return self

+   class Config:
+       arbitrary_types_allowed = True

class AgentDocumentCitations(BaseModel):
    document_id: str
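
For reference, a minimal standalone sketch (not the actual Onyx module) of the Pydantic pattern this file converges on: a BaseModel that allows non-Pydantic field types such as a SQLAlchemy Session and validates the persistence flag after construction. Only use_agentic_persistence and db_session are taken from the diff; the class name is illustrative.

from pydantic import BaseModel, model_validator
from sqlalchemy.orm import Session


class PersistenceConfigSketch(BaseModel):
    # Mirrors the renamed flag in AgentSearchConfig; defaults to persisting.
    use_agentic_persistence: bool = True
    db_session: Session | None = None

    class Config:
        # Session is not a Pydantic type, so arbitrary types must be allowed.
        arbitrary_types_allowed = True

    @model_validator(mode="after")
    def validate_db_session(self) -> "PersistenceConfigSketch":
        # Persisting without a database session is a configuration error.
        if self.use_agentic_persistence and self.db_session is None:
            raise ValueError("db_session must be provided when persistence is enabled")
        return self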

View File

@@ -38,7 +38,6 @@ def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate
    tool_runner = ToolRunner(tool, tool_args)
    tool_kickoff = tool_runner.kickoff()

-   # TODO: custom events for yields
    emit_packet(tool_kickoff)

    tool_responses = []

View File

@@ -155,26 +155,19 @@ def run_graph(
# TODO: call this once on startup, TBD where and if it should be gated based
# on dev mode or not
-def load_compiled_graph(graph_name: str) -> CompiledStateGraph:
-   main_graph_builder = (
-       main_graph_builder_a if graph_name == "a" else main_graph_builder_a
-   )
+def load_compiled_graph() -> CompiledStateGraph:
    global _COMPILED_GRAPH
    if _COMPILED_GRAPH is None:
-       graph = main_graph_builder()
+       graph = main_graph_builder_a()
        _COMPILED_GRAPH = graph.compile()
    return _COMPILED_GRAPH


def run_main_graph(
    config: AgentSearchConfig,
-   graph_name: str = "a",
) -> AnswerStream:
-   compiled_graph = load_compiled_graph(graph_name)
-   if graph_name == "a":
-       input = MainInput_a(base_question=config.search_request.query, log_messages=[])
-   else:
-       input = MainInput_a(base_question=config.search_request.query, log_messages=[])
+   compiled_graph = load_compiled_graph()
+   input = MainInput_a(base_question=config.search_request.query, log_messages=[])

    # Agent search is not a Tool per se, but this is helpful for the frontend
    yield ToolCallKickoff(
@@ -202,11 +195,36 @@ if __name__ == "__main__":
    now_start = datetime.now()
    logger.debug(f"Start at {now_start}")

-   if GRAPH_VERSION_NAME == "a":
-       graph = main_graph_builder_a()
-   else:
-       graph = main_graph_builder_a()
-   compiled_graph = graph.compile()
+   graph = main_graph_builder_a()
+   now_start = datetime.now()
+   compiled_graph = graph.compile()
+   now_end = datetime.now()
+   logger.debug(f"Graph compiled in {now_end - now_start} seconds")
+   primary_llm, fast_llm = get_default_llms()
+   search_request = SearchRequest(
+       # query="what can you do with gitlab?",
+       # query="What are the guiding principles behind the development of cockroachDB",
+       # query="What are the temperatures in Munich, Hawaii, and New York?",
+       # query="When was Washington born?",
+       # query="What is Onyx?",
+       # query="What is the difference between astronomy and astrology?",
+       query="Do a search to tell me what is the difference between astronomy and astrology?",
+   )
+   # Joachim custom persona
+   with get_session_context_manager() as db_session:
+       config, search_tool = get_test_config(
+           db_session, primary_llm, fast_llm, search_request
+       )
+       # search_request.persona = get_persona_by_id(1, None, db_session)
+       config.use_agentic_persistence = True
+       # config.perform_initial_search_path_decision = False
+       config.perform_initial_search_decomposition = True
+       input = MainInput_a(
+           base_question=config.search_request.query, log_messages=[]
+       )
    now_end = datetime.now()
    logger.debug(f"Graph compiled in {now_end - now_start} seconds")
    primary_llm, fast_llm = get_default_llms()
@@ -226,7 +244,7 @@ if __name__ == "__main__":
            db_session, primary_llm, fast_llm, search_request
        )
        # search_request.persona = get_persona_by_id(1, None, db_session)
-       config.use_persistence = True
+       config.use_agentic_persistence = True
        # config.perform_initial_search_path_decision = False
        config.perform_initial_search_decomposition = True
        if GRAPH_VERSION_NAME == "a":
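
The load_compiled_graph() change above boils down to a module-level cache for the compiled graph. A sketch of that pattern under the assumption that _COMPILED_GRAPH is declared as a module global (its declaration sits outside this hunk) and that the builder and CompiledStateGraph imports match the surrounding file:

from langgraph.graph.state import CompiledStateGraph

# main_graph_builder_a is imported from the Onyx agent graph module (import omitted here).
_COMPILED_GRAPH: CompiledStateGraph | None = None


def load_compiled_graph() -> CompiledStateGraph:
    # Compile the graph once per process and reuse it on every later call.
    global _COMPILED_GRAPH
    if _COMPILED_GRAPH is None:
        _COMPILED_GRAPH = main_graph_builder_a().compile()
    return _COMPILED_GRAPH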

View File

@@ -248,7 +248,7 @@ def get_test_config(
        chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"),  # Joachim
        # chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"),  # Evan
        message_id=1,
-       use_persistence=True,
+       use_agentic_persistence=True,
        db_session=db_session,
        tools=[search_tool],
        use_agentic_search=use_agentic_search,

View File

@@ -54,6 +54,7 @@ class Answer:
        is_connected: Callable[[], bool] | None = None,
        db_session: Session | None = None,
        use_agentic_search: bool = False,
+       use_agentic_persistence: bool = True,
    ) -> None:
        self.is_connected: Callable[[], bool] | None = is_connected
@@ -90,14 +91,9 @@ class Answer:
        if len(search_tools) > 1:
            # TODO: handle multiple search tools
-           logger.warning("Multiple search tools found, using first one")
-           search_tool = search_tools[0]
+           raise ValueError("Multiple search tools found")
        elif len(search_tools) == 1:
            search_tool = search_tools[0]
-       else:
-           logger.warning("No search tool found")
-           if use_agentic_search:
-               raise ValueError("No search tool found, cannot use agentic search")
        using_tool_calling_llm = explicit_tool_calling_supported(
            llm.config.model_provider, llm.config.model_name
@@ -111,7 +107,7 @@ class Answer:
            use_agentic_search=use_agentic_search,
            chat_session_id=chat_session_id,
            message_id=current_agent_message_id,
-           use_persistence=True,
+           use_agentic_persistence=use_agentic_persistence,
            allow_refinement=True,
            db_session=db_session,
            prompt_builder=prompt_builder,
@@ -137,104 +133,20 @@ class Answer:
            logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
            return [tool]
+   @property
+   def processed_streamed_output(self) -> AnswerStream:
+       if self._processed_stream is not None:
+           yield from self._processed_stream
+           return
+
+       run_langgraph = (
+           run_main_graph
+           if self.agent_search_config.use_agentic_search
+           else run_basic_graph
+       )
+       stream = run_langgraph(
+           self.agent_search_config,
+       )
-   # TODO: delete the function and move the full body to processed_streamed_output
-   def _get_response(self) -> AnswerStream:
-       # current_llm_call = llm_calls[-1]
-       # tool, tool_args = None, None
-       # # handle the case where no decision has to be made; we simply run the tool
-       # if (
-       #     current_llm_call.force_use_tool.force_use
-       #     and current_llm_call.force_use_tool.args is not None
-       # ):
-       #     tool_name, tool_args = (
-       #         current_llm_call.force_use_tool.tool_name,
# current_llm_call.force_use_tool.args,
# )
# tool = get_tool_by_name(current_llm_call.tools, tool_name)
# # special pre-logic for non-tool calling LLM case
# elif not self.using_tool_calling_llm and current_llm_call.tools:
# chosen_tool_and_args = (
# ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
# current_llm_call, self.llm
# )
# )
# if chosen_tool_and_args:
# tool, tool_args = chosen_tool_and_args
# if tool and tool_args:
# dummy_tool_call_chunk = AIMessageChunk(content="")
# dummy_tool_call_chunk.tool_calls = [
# ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
# ]
# response_handler_manager = LLMResponseHandlerManager(
# ToolResponseHandler([tool]), None, self.is_cancelled
# )
# yield from response_handler_manager.handle_llm_response(
# iter([dummy_tool_call_chunk])
# )
# tmp_call = response_handler_manager.next_llm_call(current_llm_call)
# if tmp_call is None:
# return # no more LLM calls to process
# current_llm_call = tmp_call
# # if we're skipping gen ai answer generation, we should break
# # out unless we're forcing a tool call. If we don't, we might generate an
# # answer, which is a no-no!
# if (
# self.skip_gen_ai_answer_generation
# and not current_llm_call.force_use_tool.force_use
# ):
# return
# # set up "handlers" to listen to the LLM response stream and
# # feed back the processed results + handle tool call requests
# # + figure out what the next LLM call should be
# tool_call_handler = ToolResponseHandler(current_llm_call.tools)
# final_search_results, displayed_search_results = SearchTool.get_search_result(
# current_llm_call
# ) or ([], [])
# # NEXT: we still want to handle the LLM response stream, but it is now:
# # 1. handle the tool call requests
# # 2. feed back the processed results
# # 3. handle the citations
# answer_handler = CitationResponseHandler(
# context_docs=final_search_results,
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
# )
# # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
# # In the future, we'll want to handle citations and tool calls in the langgraph graph.
# response_handler_manager = LLMResponseHandlerManager(
# tool_call_handler, answer_handler, self.is_cancelled
# )
# In langgraph, whether we do the basic thing (call llm stream) or pro search
# is based on a flag in the pro search config
if self.agent_search_config.use_agentic_search:
if (
self.agent_search_config.db_session is None
and self.agent_search_config.use_persistence
):
raise ValueError(
"db_session must be provided for pro search when using persistence"
)
stream = run_main_graph(
config=self.agent_search_config,
)
else:
stream = run_basic_graph(
config=self.agent_search_config,
)
        processed_stream = []
        for packet in stream:
@@ -244,62 +156,6 @@ class Answer:
                break
            processed_stream.append(packet)
            yield packet
self._processed_stream = processed_stream
return
# DEBUG: good breakpoint
# stream = self.llm.stream(
# # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
# prompt=current_llm_call.prompt_builder.build(),
# tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
# tool_choice=(
# "required"
# if current_llm_call.tools and current_llm_call.force_use_tool.force_use
# else None
# ),
# structured_response_format=self.answer_style_config.structured_response_format,
# )
# yield from response_handler_manager.handle_llm_response(stream)
# new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
# if new_llm_call:
# yield from self._get_response(llm_calls + [new_llm_call])
@property
def processed_streamed_output(self) -> AnswerStream:
if self._processed_stream is not None:
yield from self._processed_stream
return
# prompt_builder = AnswerPromptBuilder(
# user_message=default_build_user_message(
# user_query=self.question,
# prompt_config=self.prompt_config,
# files=self.latest_query_files,
# single_message_history=self.single_message_history,
# ),
# message_history=self.message_history,
# llm_config=self.llm.config,
# raw_user_query=self.question,
# raw_user_uploaded_files=self.latest_query_files or [],
# single_message_history=self.single_message_history,
# )
# prompt_builder.update_system_prompt(
# default_build_system_message(self.prompt_config)
# )
# llm_call = LLMCall(
# prompt_builder=prompt_builder,
# tools=self._get_tools_list(),
# force_use_tool=self.force_use_tool,
# files=self.latest_query_files,
# tool_call_info=[],
# using_tool_calling_llm=self.using_tool_calling_llm,
# )
processed_stream = []
for processed_packet in self._get_response():
processed_stream.append(processed_packet)
yield processed_packet
        self._processed_stream = processed_stream
@@ -343,6 +199,7 @@ class Answer:
        return citations

+   # TODO: replace tuple of ints with SubQuestionId EVERYWHERE
    def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]:
        citations_by_subquestion: dict[
            tuple[int, int], list[CitationInfo]
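
For orientation, a minimal self-contained sketch (not the real Answer class) of the caching-generator pattern the new processed_streamed_output property follows: stream packets once, then replay them from a cached list on later accesses. The _run_agentic/_run_basic names stand in for run_main_graph/run_basic_graph.

from collections.abc import Iterator
from typing import Any


class StreamCacheSketch:
    def __init__(self, use_agentic_search: bool) -> None:
        self.use_agentic_search = use_agentic_search
        self._processed_stream: list[Any] | None = None

    @property
    def processed_streamed_output(self) -> Iterator[Any]:
        # Replay the cached stream if it has already been consumed once.
        if self._processed_stream is not None:
            yield from self._processed_stream
            return

        run_graph = self._run_agentic if self.use_agentic_search else self._run_basic
        processed_stream = []
        for packet in run_graph():
            processed_stream.append(packet)
            yield packet
        # Cache so repeated iteration does not re-run the graph.
        self._processed_stream = processed_stream

    def _run_agentic(self) -> Iterator[Any]:
        yield "agentic packet"

    def _run_basic(self) -> Iterator[Any]:
        yield "basic packet"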

View File

@@ -1,7 +1,5 @@
import abc
from collections.abc import Generator
-from typing import Any
-from typing import cast

from langchain_core.messages import BaseMessage
@@ -26,10 +24,6 @@ class AnswerResponseHandler(abc.ABC):
    ) -> Generator[ResponsePart, None, None]:
        raise NotImplementedError
@abc.abstractmethod
def update(self, state_update: Any) -> None:
raise NotImplementedError
class PassThroughAnswerResponseHandler(AnswerResponseHandler):
    def handle_response_part(
@@ -40,9 +34,6 @@ class PassThroughAnswerResponseHandler(AnswerResponseHandler):
        content = _message_to_str(response_item)
        yield OnyxAnswerPiece(answer_piece=content)
def update(self, state_update: Any) -> None:
pass
class DummyAnswerResponseHandler(AnswerResponseHandler):
    def handle_response_part(
@@ -53,9 +44,6 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):
        # This is a dummy handler that returns nothing
        yield from []
def update(self, state_update: Any) -> None:
pass
class CitationResponseHandler(AnswerResponseHandler):
    def __init__(
@@ -91,20 +79,6 @@ class CitationResponseHandler(AnswerResponseHandler):
        # Process the new content through the citation processor
        yield from self.citation_processor.process_token(content)
def update(self, state_update: Any) -> None:
state = cast(
tuple[list[LlmDoc], DocumentIdOrderMapping, DocumentIdOrderMapping],
state_update,
)
self.context_docs = state[0]
self.final_doc_id_to_rank_map = state[1]
self.display_doc_id_to_rank_map = state[2]
self.citation_processor = CitationProcessor(
context_docs=self.context_docs,
final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
)
def _message_to_str(message: BaseMessage | str | None) -> str:
    if message is None:
@@ -116,80 +90,3 @@ def _message_to_str(message: BaseMessage | str | None) -> str:
        logger.warning(f"Received non-string content: {type(content)}")
        content = str(content) if content is not None else ""
    return content
# class CitationMultiResponseHandler(AnswerResponseHandler):
# def __init__(self) -> None:
# self.channel_processors: dict[str, CitationProcessor] = {}
# self._default_channel = "__default__"
# def register_default_channel(
# self,
# context_docs: list[LlmDoc],
# final_doc_id_to_rank_map: DocumentIdOrderMapping,
# display_doc_id_to_rank_map: DocumentIdOrderMapping,
# ) -> None:
# """Register the default channel with its associated documents and ranking maps."""
# self.register_channel(
# channel_id=self._default_channel,
# context_docs=context_docs,
# final_doc_id_to_rank_map=final_doc_id_to_rank_map,
# display_doc_id_to_rank_map=display_doc_id_to_rank_map,
# )
# def register_channel(
# self,
# channel_id: str,
# context_docs: list[LlmDoc],
# final_doc_id_to_rank_map: DocumentIdOrderMapping,
# display_doc_id_to_rank_map: DocumentIdOrderMapping,
# ) -> None:
# """Register a new channel with its associated documents and ranking maps."""
# self.channel_processors[channel_id] = CitationProcessor(
# context_docs=context_docs,
# final_doc_id_to_rank_map=final_doc_id_to_rank_map,
# display_doc_id_to_rank_map=display_doc_id_to_rank_map,
# )
# def handle_response_part(
# self,
# response_item: BaseMessage | str | None,
# previous_response_items: list[BaseMessage | str],
# ) -> Generator[ResponsePart, None, None]:
# """Default implementation that uses the default channel."""
# yield from self.handle_channel_response(
# response_item=content,
# previous_response_items=previous_response_items,
# channel_id=self._default_channel,
# )
# def handle_channel_response(
# self,
# response_item: ResponsePart | str | None,
# previous_response_items: list[ResponsePart | str],
# channel_id: str,
# ) -> Generator[ResponsePart, None, None]:
# """Process a response part for a specific channel."""
# if channel_id not in self.channel_processors:
# raise ValueError(f"Attempted to process response for unregistered channel {channel_id}")
# if response_item is None:
# return
# content = (
# response_item.content if isinstance(response_item, BaseMessage) else response_item
# )
# # Ensure content is a string
# if not isinstance(content, str):
# logger.warning(f"Received non-string content: {type(content)}")
# content = str(content) if content is not None else ""
# # Process the new content through the channel's citation processor
# yield from self.channel_processors[channel_id].multi_process_token(content)
# def remove_channel(self, channel_id: str) -> None:
# """Remove a channel and its associated processor."""
# if channel_id in self.channel_processors:
# del self.channel_processors[channel_id]
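
With the abstract update() hook gone, a response handler only needs handle_response_part. A self-contained sketch of that shape, with the base class and answer-piece type below as illustrative stand-ins (the real AnswerResponseHandler and OnyxAnswerPiece live in Onyx's chat modules; the signature mirrors the handle_response_part shown in the removed code above):

import abc
from collections.abc import Generator
from dataclasses import dataclass

from langchain_core.messages import BaseMessage


@dataclass
class AnswerPieceSketch:
    """Stand-in for OnyxAnswerPiece, for illustration only."""

    answer_piece: str


class AnswerResponseHandlerSketch(abc.ABC):
    @abc.abstractmethod
    def handle_response_part(
        self,
        response_item: BaseMessage | str | None,
        previous_response_items: list[BaseMessage | str],
    ) -> Generator[AnswerPieceSketch, None, None]:
        raise NotImplementedError


class UppercaseHandler(AnswerResponseHandlerSketch):
    """Toy handler: upper-cases each streamed token before yielding it."""

    def handle_response_part(
        self,
        response_item: BaseMessage | str | None,
        previous_response_items: list[BaseMessage | str],
    ) -> Generator[AnswerPieceSketch, None, None]:
        if response_item is None:
            return
        content = (
            response_item.content
            if isinstance(response_item, BaseMessage)
            else response_item
        )
        yield AnswerPieceSketch(answer_piece=str(content).upper())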

View File

@@ -4,7 +4,6 @@ from collections.abc import Generator
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
-from onyx.chat.models import ResponsePart
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
from onyx.configs.chat_configs import STOP_STREAM_PAT
from onyx.prompts.constants import TRIPLE_BACKTICK
@@ -41,164 +40,6 @@ class CitationProcessor:
        self.current_citations: list[int] = []
        self.past_cite_count = 0
# TODO: should reference previous citation processing, rework previous, or completely use new one?
def multi_process_token(
self, parsed_object: ResponsePart
) -> Generator[ResponsePart, None, None]:
# if isinstance(parsed_object,OnyxAnswerPiece):
# # standard citation processing
# yield from self.process_token(parsed_object.answer_piece)
# elif isinstance(parsed_object, AgentAnswerPiece):
# # citation processing for agent answer pieces
# for token in self.process_token(parsed_object.answer_piece):
# if isinstance(token, CitationInfo):
# yield token
# else:
# yield AgentAnswerPiece(answer_piece=token.answer_piece or '',
# answer_type=parsed_object.answer_type, level=parsed_object.level,
# level_question_nr=parsed_object.level_question_nr)
# level = getattr(parsed_object, "level", None)
# level_question_nr = getattr(parsed_object, "level_question_nr", None)
# if isinstance(parsed_object, (AgentAnswerPiece, OnyxAnswerPiece)):
# # logger.debug(f"FA {parsed_object.answer_piece}")
# if isinstance(parsed_object, AgentAnswerPiece):
# token = parsed_object.answer_piece
# level = parsed_object.level
# level_question_nr = parsed_object.level_question_nr
# else:
# yield parsed_object
# return
# # raise ValueError(
# # f"Invalid parsed object type: {type(parsed_object)}"
# # )
# if not citation_potential[level][level_question_nr] and token:
# if token.startswith(" ["):
# citation_potential[level][level_question_nr] = True
# current_yield_components[level][level_question_nr] = [token]
# else:
# yield parsed_object
# elif token and citation_potential[level][level_question_nr]:
# current_yield_components[level][level_question_nr].append(token)
# current_yield_str[level][level_question_nr] = "".join(
# current_yield_components[level][level_question_nr]
# )
# if current_yield_str[level][level_question_nr].strip().startswith(
# "[D"
# ) or current_yield_str[level][level_question_nr].strip().startswith(
# "[Q"
# ):
# citation_potential[level][level_question_nr] = True
# else:
# citation_potential[level][level_question_nr] = False
# parsed_object = _set_combined_token_value(
# current_yield_str[level][level_question_nr], parsed_object
# )
# yield parsed_object
# if (
# len(current_yield_components[level][level_question_nr]) > 15
# ): # ??? 15?
# citation_potential[level][level_question_nr] = False
# parsed_object = _set_combined_token_value(
# current_yield_str[level][level_question_nr], parsed_object
# )
# yield parsed_object
# elif "]" in current_yield_str[level][level_question_nr]:
# section_split = current_yield_str[level][level_question_nr].split(
# "]"
# )
# section_split[0] + "]" # dead code?
# start_of_next_section = "]".join(section_split[1:])
# citation_string = current_yield_str[level][level_question_nr][
# : -len(start_of_next_section)
# ]
# if "[D" in citation_string:
# cite_open_bracket_marker, cite_close_bracket_marker = (
# "[",
# "]",
# )
# cite_identifyer = "D"
# try:
# cited_document = int(
# citation_string[level][level_question_nr][2:-1]
# )
# if level and level_question_nr:
# link = agent_document_citations[int(level)][
# int(level_question_nr)
# ][cited_document].link
# else:
# link = ""
# except (ValueError, IndexError):
# link = ""
# elif "[Q" in citation_string:
# cite_open_bracket_marker, cite_close_bracket_marker = (
# "{",
# "}",
# )
# cite_identifyer = "Q"
# else:
# pass
# citation_string = citation_string.replace(
# "[" + cite_identifyer,
# cite_open_bracket_marker * 2,
# ).replace("]", cite_close_bracket_marker * 2)
# if cite_identifyer == "D":
# citation_string += f"({link})"
# parsed_object = _set_combined_token_value(
# citation_string, parsed_object
# )
# yield parsed_object
# current_yield_components[level][level_question_nr] = [
# start_of_next_section
# ]
# if not start_of_next_section.strip().startswith("["):
# citation_potential[level][level_question_nr] = False
# elif isinstance(parsed_object, ExtendedToolResponse):
# if parsed_object.id == "search_response_summary":
# level = parsed_object.level
# level_question_nr = parsed_object.level_question_nr
# for inference_section in parsed_object.response.top_sections:
# doc_link = inference_section.center_chunk.source_links[0]
# doc_title = inference_section.center_chunk.title
# doc_id = inference_section.center_chunk.document_id
# if (
# doc_id
# not in agent_question_citations_used_docs[level][
# level_question_nr
# ]
# ):
# if level not in agent_document_citations:
# agent_document_citations[level] = {}
# if level_question_nr not in agent_document_citations[level]:
# agent_document_citations[level][level_question_nr] = []
# agent_document_citations[level][level_question_nr].append(
# AgentDocumentCitations(
# document_id=doc_id,
# document_title=doc_title,
# link=doc_link,
# )
# )
# agent_question_citations_used_docs[level][
# level_question_nr
# ].append(doc_id)
yield parsed_object
    def process_token(
        self, token: str | None
    ) -> Generator[OnyxAnswerPiece | CitationInfo, None, None]:

View File

@@ -41,6 +41,7 @@ DEFAULT_CC_PAIR_ID = 1
# subquestion level and question number for basic flow
BASIC_KEY = (-1, -1)
AGENT_SEARCH_INITIAL_KEY = (0, 0)
+CANCEL_CHECK_INTERVAL = 20

# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"

View File

@@ -5,7 +5,6 @@ from unittest.mock import MagicMock
import pytest
from langchain_core.messages import SystemMessage

-from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
@@ -14,22 +13,17 @@ from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
-from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
-from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
-from onyx.context.search.models import SearchRequest
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
-from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search_like_tool_utils import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)
-from onyx.tools.utils import explicit_tool_calling_supported

QUERY = "Test question"
DEFAULT_SEARCH_ARGS = {"query": "search"}
@@ -40,43 +34,6 @@ def answer_style_config() -> AnswerStyleConfig:
    return AnswerStyleConfig(citation_config=CitationConfig())
@pytest.fixture
def agent_search_config(
mock_llm: LLM, mock_search_tool: SearchTool, prompt_config: PromptConfig
) -> AgentSearchConfig:
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=QUERY,
prompt_config=prompt_config,
files=[],
single_message_history=None,
),
message_history=[],
llm_config=mock_llm.config,
raw_user_query=QUERY,
raw_user_uploaded_files=[],
single_message_history=None,
)
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
using_tool_calling_llm = explicit_tool_calling_supported(
mock_llm.config.model_provider, mock_llm.config.model_name
)
return AgentSearchConfig(
search_request=SearchRequest(query=QUERY),
primary_llm=mock_llm,
fast_llm=mock_llm,
search_tool=mock_search_tool,
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
prompt_builder=prompt_builder,
chat_session_id=None,
message_id=1,
use_persistence=True,
db_session=None,
use_agentic_search=False,
using_tool_calling_llm=using_tool_calling_llm,
)
@pytest.fixture
def prompt_config() -> PromptConfig:
    return PromptConfig(
@@ -89,7 +46,7 @@ def prompt_config() -> PromptConfig:
@pytest.fixture
def mock_llm() -> MagicMock:
-   mock_llm_obj = MagicMock()
+   mock_llm_obj = MagicMock(spec=LLM)
    mock_llm_obj.config = LLMConfig(
        model_provider="openai",
        model_name="gpt-4o",
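
The MagicMock(spec=LLM) and Mock(spec=LLM) changes in these tests tighten the mocks so that only attributes defined on the real interface can be accessed. A small, self-contained sketch of that behavior, using a toy class rather than the real LLM interface:

from unittest.mock import MagicMock


class ToyLLM:
    """Stand-in for onyx.llm.interfaces.LLM, for illustration only."""

    def stream(self, prompt: str) -> list[str]:
        return [prompt]


specced = MagicMock(spec=ToyLLM)
specced.stream("hello")  # fine: `stream` exists on the spec

try:
    specced.invoke("hello")  # `invoke` is not defined on ToyLLM
except AttributeError:
    print("spec'd mocks reject attributes the real class does not have")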

View File

@@ -65,6 +65,7 @@ def answer_instance(
        search_request=SearchRequest(query=QUERY),
        chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
        current_agent_message_id=0,
+       use_agentic_persistence=False,
    )

View File

@@ -11,6 +11,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import SearchRequest
+from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from tests.regression.answer_quality.run_qa import _process_and_write_query_results
@@ -38,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag(
    question = config["question"]
    skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]

-   mock_llm = Mock()
+   mock_llm = Mock(spec=LLM)
    mock_llm.config = Mock()
    mock_llm.config.model_name = "gpt-4o-mini"
    mock_llm.stream = Mock()
@@ -66,6 +67,7 @@ def test_skip_gen_ai_answer_generation_flag(
        ),
        chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
        current_agent_message_id=0,
+       use_agentic_persistence=False,
    )
    results = list(answer.processed_streamed_output)
    for res in results: