mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-10 21:09:51 +02:00
first pass at dead code deletion
This commit is contained in:
parent
3d99ad7bc4
commit
6c7f8eaefb
@ -80,7 +80,7 @@ def agent_logging(state: MainState, config: RunnableConfig) -> MainOutput:
|
||||
agent_metrics=combined_agent_metrics,
|
||||
)
|
||||
|
||||
if agent_a_config.use_persistence:
|
||||
if agent_a_config.use_agentic_persistence:
|
||||
# Persist the sub-answer in the database
|
||||
db_session = agent_a_config.db_session
|
||||
chat_session_id = agent_a_config.chat_session_id
|
||||
|
@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -15,8 +14,7 @@ from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentSearchConfig:
|
||||
class AgentSearchConfig(BaseModel):
|
||||
"""
|
||||
Configuration for the Agent Search feature.
|
||||
"""
|
||||
@ -47,10 +45,10 @@ class AgentSearchConfig:
|
||||
# The message ID of the user message that triggered the Pro Search
|
||||
message_id: int | None = None
|
||||
|
||||
# Whether to persistence data for the Pro Search (turned off for testing)
|
||||
use_persistence: bool = True
|
||||
# Whether to persistence data for Agentic Search (turned off for testing)
|
||||
use_agentic_persistence: bool = True
|
||||
|
||||
# The database session for the Pro Search
|
||||
# The database session for Agentic Search
|
||||
db_session: Session | None = None
|
||||
|
||||
# Whether to perform initial search to inform decomposition
|
||||
@ -75,7 +73,7 @@ class AgentSearchConfig:
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_db_session(self) -> "AgentSearchConfig":
|
||||
if self.use_persistence and self.db_session is None:
|
||||
if self.use_agentic_persistence and self.db_session is None:
|
||||
raise ValueError(
|
||||
"db_session must be provided for pro search when using persistence"
|
||||
)
|
||||
@ -87,6 +85,9 @@ class AgentSearchConfig:
|
||||
raise ValueError("search_tool must be provided for agentic search")
|
||||
return self
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class AgentDocumentCitations(BaseModel):
|
||||
document_id: str
|
||||
|
@ -38,7 +38,6 @@ def tool_call(state: ToolChoiceUpdate, config: RunnableConfig) -> ToolCallUpdate
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
|
||||
# TODO: custom events for yields
|
||||
emit_packet(tool_kickoff)
|
||||
|
||||
tool_responses = []
|
||||
|
@ -155,26 +155,19 @@ def run_graph(
|
||||
|
||||
# TODO: call this once on startup, TBD where and if it should be gated based
|
||||
# on dev mode or not
|
||||
def load_compiled_graph(graph_name: str) -> CompiledStateGraph:
|
||||
main_graph_builder = (
|
||||
main_graph_builder_a if graph_name == "a" else main_graph_builder_a
|
||||
)
|
||||
def load_compiled_graph() -> CompiledStateGraph:
|
||||
global _COMPILED_GRAPH
|
||||
if _COMPILED_GRAPH is None:
|
||||
graph = main_graph_builder()
|
||||
graph = main_graph_builder_a()
|
||||
_COMPILED_GRAPH = graph.compile()
|
||||
return _COMPILED_GRAPH
|
||||
|
||||
|
||||
def run_main_graph(
|
||||
config: AgentSearchConfig,
|
||||
graph_name: str = "a",
|
||||
) -> AnswerStream:
|
||||
compiled_graph = load_compiled_graph(graph_name)
|
||||
if graph_name == "a":
|
||||
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
|
||||
else:
|
||||
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
|
||||
compiled_graph = load_compiled_graph()
|
||||
input = MainInput_a(base_question=config.search_request.query, log_messages=[])
|
||||
|
||||
# Agent search is not a Tool per se, but this is helpful for the frontend
|
||||
yield ToolCallKickoff(
|
||||
@ -202,11 +195,36 @@ if __name__ == "__main__":
|
||||
now_start = datetime.now()
|
||||
logger.debug(f"Start at {now_start}")
|
||||
|
||||
if GRAPH_VERSION_NAME == "a":
|
||||
graph = main_graph_builder_a()
|
||||
else:
|
||||
graph = main_graph_builder_a()
|
||||
compiled_graph = graph.compile()
|
||||
graph = main_graph_builder_a()
|
||||
now_start = datetime.now()
|
||||
compiled_graph = graph.compile()
|
||||
now_end = datetime.now()
|
||||
logger.debug(f"Graph compiled in {now_end - now_start} seconds")
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
# query="what can you do with gitlab?",
|
||||
# query="What are the guiding principles behind the development of cockroachDB",
|
||||
# query="What are the temperatures in Munich, Hawaii, and New York?",
|
||||
# query="When was Washington born?",
|
||||
# query="What is Onyx?",
|
||||
# query="What is the difference between astronomy and astrology?",
|
||||
query="Do a search to tell me what is the difference between astronomy and astrology?",
|
||||
)
|
||||
# Joachim custom persona
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
config, search_tool = get_test_config(
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
# search_request.persona = get_persona_by_id(1, None, db_session)
|
||||
config.use_agentic_persistence = True
|
||||
# config.perform_initial_search_path_decision = False
|
||||
config.perform_initial_search_decomposition = True
|
||||
|
||||
input = MainInput_a(
|
||||
base_question=config.search_request.query, log_messages=[]
|
||||
)
|
||||
|
||||
now_end = datetime.now()
|
||||
logger.debug(f"Graph compiled in {now_end - now_start} seconds")
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
@ -226,7 +244,7 @@ if __name__ == "__main__":
|
||||
db_session, primary_llm, fast_llm, search_request
|
||||
)
|
||||
# search_request.persona = get_persona_by_id(1, None, db_session)
|
||||
config.use_persistence = True
|
||||
config.use_agentic_persistence = True
|
||||
# config.perform_initial_search_path_decision = False
|
||||
config.perform_initial_search_decomposition = True
|
||||
if GRAPH_VERSION_NAME == "a":
|
||||
|
@ -248,7 +248,7 @@ def get_test_config(
|
||||
chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"), # Joachim
|
||||
# chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"), # Evan
|
||||
message_id=1,
|
||||
use_persistence=True,
|
||||
use_agentic_persistence=True,
|
||||
db_session=db_session,
|
||||
tools=[search_tool],
|
||||
use_agentic_search=use_agentic_search,
|
||||
|
@ -54,6 +54,7 @@ class Answer:
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
db_session: Session | None = None,
|
||||
use_agentic_search: bool = False,
|
||||
use_agentic_persistence: bool = True,
|
||||
) -> None:
|
||||
self.is_connected: Callable[[], bool] | None = is_connected
|
||||
|
||||
@ -90,14 +91,9 @@ class Answer:
|
||||
|
||||
if len(search_tools) > 1:
|
||||
# TODO: handle multiple search tools
|
||||
logger.warning("Multiple search tools found, using first one")
|
||||
search_tool = search_tools[0]
|
||||
raise ValueError("Multiple search tools found")
|
||||
elif len(search_tools) == 1:
|
||||
search_tool = search_tools[0]
|
||||
else:
|
||||
logger.warning("No search tool found")
|
||||
if use_agentic_search:
|
||||
raise ValueError("No search tool found, cannot use agentic search")
|
||||
|
||||
using_tool_calling_llm = explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
@ -111,7 +107,7 @@ class Answer:
|
||||
use_agentic_search=use_agentic_search,
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=current_agent_message_id,
|
||||
use_persistence=True,
|
||||
use_agentic_persistence=use_agentic_persistence,
|
||||
allow_refinement=True,
|
||||
db_session=db_session,
|
||||
prompt_builder=prompt_builder,
|
||||
@ -137,104 +133,20 @@ class Answer:
|
||||
logger.info(f"Forcefully using tool='{tool.name}'{args_str}")
|
||||
return [tool]
|
||||
|
||||
# TODO: delete the function and move the full body to processed_streamed_output
|
||||
def _get_response(self) -> AnswerStream:
|
||||
# current_llm_call = llm_calls[-1]
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
if self._processed_stream is not None:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
|
||||
# tool, tool_args = None, None
|
||||
# # handle the case where no decision has to be made; we simply run the tool
|
||||
# if (
|
||||
# current_llm_call.force_use_tool.force_use
|
||||
# and current_llm_call.force_use_tool.args is not None
|
||||
# ):
|
||||
# tool_name, tool_args = (
|
||||
# current_llm_call.force_use_tool.tool_name,
|
||||
# current_llm_call.force_use_tool.args,
|
||||
# )
|
||||
# tool = get_tool_by_name(current_llm_call.tools, tool_name)
|
||||
|
||||
# # special pre-logic for non-tool calling LLM case
|
||||
# elif not self.using_tool_calling_llm and current_llm_call.tools:
|
||||
# chosen_tool_and_args = (
|
||||
# ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
|
||||
# current_llm_call, self.llm
|
||||
# )
|
||||
# )
|
||||
# if chosen_tool_and_args:
|
||||
# tool, tool_args = chosen_tool_and_args
|
||||
|
||||
# if tool and tool_args:
|
||||
# dummy_tool_call_chunk = AIMessageChunk(content="")
|
||||
# dummy_tool_call_chunk.tool_calls = [
|
||||
# ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
|
||||
# ]
|
||||
|
||||
# response_handler_manager = LLMResponseHandlerManager(
|
||||
# ToolResponseHandler([tool]), None, self.is_cancelled
|
||||
# )
|
||||
# yield from response_handler_manager.handle_llm_response(
|
||||
# iter([dummy_tool_call_chunk])
|
||||
# )
|
||||
|
||||
# tmp_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
# if tmp_call is None:
|
||||
# return # no more LLM calls to process
|
||||
# current_llm_call = tmp_call
|
||||
|
||||
# # if we're skipping gen ai answer generation, we should break
|
||||
# # out unless we're forcing a tool call. If we don't, we might generate an
|
||||
# # answer, which is a no-no!
|
||||
# if (
|
||||
# self.skip_gen_ai_answer_generation
|
||||
# and not current_llm_call.force_use_tool.force_use
|
||||
# ):
|
||||
# return
|
||||
|
||||
# # set up "handlers" to listen to the LLM response stream and
|
||||
# # feed back the processed results + handle tool call requests
|
||||
# # + figure out what the next LLM call should be
|
||||
# tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||
|
||||
# final_search_results, displayed_search_results = SearchTool.get_search_result(
|
||||
# current_llm_call
|
||||
# ) or ([], [])
|
||||
|
||||
# # NEXT: we still want to handle the LLM response stream, but it is now:
|
||||
# # 1. handle the tool call requests
|
||||
# # 2. feed back the processed results
|
||||
# # 3. handle the citations
|
||||
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=final_search_results,
|
||||
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
# )
|
||||
|
||||
# # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
|
||||
# # In the future, we'll want to handle citations and tool calls in the langgraph graph.
|
||||
# response_handler_manager = LLMResponseHandlerManager(
|
||||
# tool_call_handler, answer_handler, self.is_cancelled
|
||||
# )
|
||||
|
||||
# In langgraph, whether we do the basic thing (call llm stream) or pro search
|
||||
# is based on a flag in the pro search config
|
||||
|
||||
if self.agent_search_config.use_agentic_search:
|
||||
if (
|
||||
self.agent_search_config.db_session is None
|
||||
and self.agent_search_config.use_persistence
|
||||
):
|
||||
raise ValueError(
|
||||
"db_session must be provided for pro search when using persistence"
|
||||
)
|
||||
|
||||
stream = run_main_graph(
|
||||
config=self.agent_search_config,
|
||||
)
|
||||
else:
|
||||
stream = run_basic_graph(
|
||||
config=self.agent_search_config,
|
||||
)
|
||||
run_langgraph = (
|
||||
run_main_graph
|
||||
if self.agent_search_config.use_agentic_search
|
||||
else run_basic_graph
|
||||
)
|
||||
stream = run_langgraph(
|
||||
self.agent_search_config,
|
||||
)
|
||||
|
||||
processed_stream = []
|
||||
for packet in stream:
|
||||
@ -244,62 +156,6 @@ class Answer:
|
||||
break
|
||||
processed_stream.append(packet)
|
||||
yield packet
|
||||
self._processed_stream = processed_stream
|
||||
return
|
||||
# DEBUG: good breakpoint
|
||||
# stream = self.llm.stream(
|
||||
# # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
|
||||
# # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
|
||||
# prompt=current_llm_call.prompt_builder.build(),
|
||||
# tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
|
||||
# tool_choice=(
|
||||
# "required"
|
||||
# if current_llm_call.tools and current_llm_call.force_use_tool.force_use
|
||||
# else None
|
||||
# ),
|
||||
# structured_response_format=self.answer_style_config.structured_response_format,
|
||||
# )
|
||||
# yield from response_handler_manager.handle_llm_response(stream)
|
||||
|
||||
# new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
# if new_llm_call:
|
||||
# yield from self._get_response(llm_calls + [new_llm_call])
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
if self._processed_stream is not None:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
|
||||
# prompt_builder = AnswerPromptBuilder(
|
||||
# user_message=default_build_user_message(
|
||||
# user_query=self.question,
|
||||
# prompt_config=self.prompt_config,
|
||||
# files=self.latest_query_files,
|
||||
# single_message_history=self.single_message_history,
|
||||
# ),
|
||||
# message_history=self.message_history,
|
||||
# llm_config=self.llm.config,
|
||||
# raw_user_query=self.question,
|
||||
# raw_user_uploaded_files=self.latest_query_files or [],
|
||||
# single_message_history=self.single_message_history,
|
||||
# )
|
||||
# prompt_builder.update_system_prompt(
|
||||
# default_build_system_message(self.prompt_config)
|
||||
# )
|
||||
# llm_call = LLMCall(
|
||||
# prompt_builder=prompt_builder,
|
||||
# tools=self._get_tools_list(),
|
||||
# force_use_tool=self.force_use_tool,
|
||||
# files=self.latest_query_files,
|
||||
# tool_call_info=[],
|
||||
# using_tool_calling_llm=self.using_tool_calling_llm,
|
||||
# )
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in self._get_response():
|
||||
processed_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
|
||||
@ -343,6 +199,7 @@ class Answer:
|
||||
|
||||
return citations
|
||||
|
||||
# TODO: replace tuple of ints with SubQuestionId EVERYWHERE
|
||||
def citations_by_subquestion(self) -> dict[tuple[int, int], list[CitationInfo]]:
|
||||
citations_by_subquestion: dict[
|
||||
tuple[int, int], list[CitationInfo]
|
||||
|
@ -1,7 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
@ -26,10 +24,6 @@ class AnswerResponseHandler(abc.ABC):
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, state_update: Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PassThroughAnswerResponseHandler(AnswerResponseHandler):
|
||||
def handle_response_part(
|
||||
@ -40,9 +34,6 @@ class PassThroughAnswerResponseHandler(AnswerResponseHandler):
|
||||
content = _message_to_str(response_item)
|
||||
yield OnyxAnswerPiece(answer_piece=content)
|
||||
|
||||
def update(self, state_update: Any) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DummyAnswerResponseHandler(AnswerResponseHandler):
|
||||
def handle_response_part(
|
||||
@ -53,9 +44,6 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):
|
||||
# This is a dummy handler that returns nothing
|
||||
yield from []
|
||||
|
||||
def update(self, state_update: Any) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class CitationResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
@ -91,20 +79,6 @@ class CitationResponseHandler(AnswerResponseHandler):
|
||||
# Process the new content through the citation processor
|
||||
yield from self.citation_processor.process_token(content)
|
||||
|
||||
def update(self, state_update: Any) -> None:
|
||||
state = cast(
|
||||
tuple[list[LlmDoc], DocumentIdOrderMapping, DocumentIdOrderMapping],
|
||||
state_update,
|
||||
)
|
||||
self.context_docs = state[0]
|
||||
self.final_doc_id_to_rank_map = state[1]
|
||||
self.display_doc_id_to_rank_map = state[2]
|
||||
self.citation_processor = CitationProcessor(
|
||||
context_docs=self.context_docs,
|
||||
final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
|
||||
display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
|
||||
)
|
||||
|
||||
|
||||
def _message_to_str(message: BaseMessage | str | None) -> str:
|
||||
if message is None:
|
||||
@ -116,80 +90,3 @@ def _message_to_str(message: BaseMessage | str | None) -> str:
|
||||
logger.warning(f"Received non-string content: {type(content)}")
|
||||
content = str(content) if content is not None else ""
|
||||
return content
|
||||
|
||||
|
||||
# class CitationMultiResponseHandler(AnswerResponseHandler):
|
||||
# def __init__(self) -> None:
|
||||
# self.channel_processors: dict[str, CitationProcessor] = {}
|
||||
# self._default_channel = "__default__"
|
||||
|
||||
# def register_default_channel(
|
||||
# self,
|
||||
# context_docs: list[LlmDoc],
|
||||
# final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
# display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
# ) -> None:
|
||||
# """Register the default channel with its associated documents and ranking maps."""
|
||||
# self.register_channel(
|
||||
# channel_id=self._default_channel,
|
||||
# context_docs=context_docs,
|
||||
# final_doc_id_to_rank_map=final_doc_id_to_rank_map,
|
||||
# display_doc_id_to_rank_map=display_doc_id_to_rank_map,
|
||||
# )
|
||||
|
||||
# def register_channel(
|
||||
# self,
|
||||
# channel_id: str,
|
||||
# context_docs: list[LlmDoc],
|
||||
# final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
# display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
# ) -> None:
|
||||
# """Register a new channel with its associated documents and ranking maps."""
|
||||
# self.channel_processors[channel_id] = CitationProcessor(
|
||||
# context_docs=context_docs,
|
||||
# final_doc_id_to_rank_map=final_doc_id_to_rank_map,
|
||||
# display_doc_id_to_rank_map=display_doc_id_to_rank_map,
|
||||
# )
|
||||
|
||||
# def handle_response_part(
|
||||
# self,
|
||||
# response_item: BaseMessage | str | None,
|
||||
# previous_response_items: list[BaseMessage | str],
|
||||
# ) -> Generator[ResponsePart, None, None]:
|
||||
# """Default implementation that uses the default channel."""
|
||||
|
||||
# yield from self.handle_channel_response(
|
||||
# response_item=content,
|
||||
# previous_response_items=previous_response_items,
|
||||
# channel_id=self._default_channel,
|
||||
# )
|
||||
|
||||
# def handle_channel_response(
|
||||
# self,
|
||||
# response_item: ResponsePart | str | None,
|
||||
# previous_response_items: list[ResponsePart | str],
|
||||
# channel_id: str,
|
||||
# ) -> Generator[ResponsePart, None, None]:
|
||||
# """Process a response part for a specific channel."""
|
||||
# if channel_id not in self.channel_processors:
|
||||
# raise ValueError(f"Attempted to process response for unregistered channel {channel_id}")
|
||||
|
||||
# if response_item is None:
|
||||
# return
|
||||
|
||||
# content = (
|
||||
# response_item.content if isinstance(response_item, BaseMessage) else response_item
|
||||
# )
|
||||
|
||||
# # Ensure content is a string
|
||||
# if not isinstance(content, str):
|
||||
# logger.warning(f"Received non-string content: {type(content)}")
|
||||
# content = str(content) if content is not None else ""
|
||||
|
||||
# # Process the new content through the channel's citation processor
|
||||
# yield from self.channel_processors[channel_id].multi_process_token(content)
|
||||
|
||||
# def remove_channel(self, channel_id: str) -> None:
|
||||
# """Remove a channel and its associated processor."""
|
||||
# if channel_id in self.channel_processors:
|
||||
# del self.channel_processors[channel_id]
|
||||
|
@ -4,7 +4,6 @@ from collections.abc import Generator
|
||||
from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import ResponsePart
|
||||
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from onyx.configs.chat_configs import STOP_STREAM_PAT
|
||||
from onyx.prompts.constants import TRIPLE_BACKTICK
|
||||
@ -41,164 +40,6 @@ class CitationProcessor:
|
||||
self.current_citations: list[int] = []
|
||||
self.past_cite_count = 0
|
||||
|
||||
# TODO: should reference previous citation processing, rework previous, or completely use new one?
|
||||
def multi_process_token(
|
||||
self, parsed_object: ResponsePart
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
# if isinstance(parsed_object,OnyxAnswerPiece):
|
||||
# # standard citation processing
|
||||
# yield from self.process_token(parsed_object.answer_piece)
|
||||
|
||||
# elif isinstance(parsed_object, AgentAnswerPiece):
|
||||
# # citation processing for agent answer pieces
|
||||
# for token in self.process_token(parsed_object.answer_piece):
|
||||
# if isinstance(token, CitationInfo):
|
||||
# yield token
|
||||
# else:
|
||||
# yield AgentAnswerPiece(answer_piece=token.answer_piece or '',
|
||||
# answer_type=parsed_object.answer_type, level=parsed_object.level,
|
||||
# level_question_nr=parsed_object.level_question_nr)
|
||||
|
||||
# level = getattr(parsed_object, "level", None)
|
||||
# level_question_nr = getattr(parsed_object, "level_question_nr", None)
|
||||
|
||||
# if isinstance(parsed_object, (AgentAnswerPiece, OnyxAnswerPiece)):
|
||||
# # logger.debug(f"FA {parsed_object.answer_piece}")
|
||||
# if isinstance(parsed_object, AgentAnswerPiece):
|
||||
# token = parsed_object.answer_piece
|
||||
# level = parsed_object.level
|
||||
# level_question_nr = parsed_object.level_question_nr
|
||||
# else:
|
||||
# yield parsed_object
|
||||
# return
|
||||
# # raise ValueError(
|
||||
# # f"Invalid parsed object type: {type(parsed_object)}"
|
||||
# # )
|
||||
|
||||
# if not citation_potential[level][level_question_nr] and token:
|
||||
# if token.startswith(" ["):
|
||||
# citation_potential[level][level_question_nr] = True
|
||||
# current_yield_components[level][level_question_nr] = [token]
|
||||
# else:
|
||||
# yield parsed_object
|
||||
# elif token and citation_potential[level][level_question_nr]:
|
||||
# current_yield_components[level][level_question_nr].append(token)
|
||||
# current_yield_str[level][level_question_nr] = "".join(
|
||||
# current_yield_components[level][level_question_nr]
|
||||
# )
|
||||
|
||||
# if current_yield_str[level][level_question_nr].strip().startswith(
|
||||
# "[D"
|
||||
# ) or current_yield_str[level][level_question_nr].strip().startswith(
|
||||
# "[Q"
|
||||
# ):
|
||||
# citation_potential[level][level_question_nr] = True
|
||||
|
||||
# else:
|
||||
# citation_potential[level][level_question_nr] = False
|
||||
# parsed_object = _set_combined_token_value(
|
||||
# current_yield_str[level][level_question_nr], parsed_object
|
||||
# )
|
||||
# yield parsed_object
|
||||
|
||||
# if (
|
||||
# len(current_yield_components[level][level_question_nr]) > 15
|
||||
# ): # ??? 15?
|
||||
# citation_potential[level][level_question_nr] = False
|
||||
# parsed_object = _set_combined_token_value(
|
||||
# current_yield_str[level][level_question_nr], parsed_object
|
||||
# )
|
||||
# yield parsed_object
|
||||
# elif "]" in current_yield_str[level][level_question_nr]:
|
||||
# section_split = current_yield_str[level][level_question_nr].split(
|
||||
# "]"
|
||||
# )
|
||||
# section_split[0] + "]" # dead code?
|
||||
# start_of_next_section = "]".join(section_split[1:])
|
||||
# citation_string = current_yield_str[level][level_question_nr][
|
||||
# : -len(start_of_next_section)
|
||||
# ]
|
||||
# if "[D" in citation_string:
|
||||
# cite_open_bracket_marker, cite_close_bracket_marker = (
|
||||
# "[",
|
||||
# "]",
|
||||
# )
|
||||
# cite_identifyer = "D"
|
||||
|
||||
# try:
|
||||
# cited_document = int(
|
||||
# citation_string[level][level_question_nr][2:-1]
|
||||
# )
|
||||
# if level and level_question_nr:
|
||||
# link = agent_document_citations[int(level)][
|
||||
# int(level_question_nr)
|
||||
# ][cited_document].link
|
||||
# else:
|
||||
# link = ""
|
||||
# except (ValueError, IndexError):
|
||||
# link = ""
|
||||
# elif "[Q" in citation_string:
|
||||
# cite_open_bracket_marker, cite_close_bracket_marker = (
|
||||
# "{",
|
||||
# "}",
|
||||
# )
|
||||
# cite_identifyer = "Q"
|
||||
# else:
|
||||
# pass
|
||||
|
||||
# citation_string = citation_string.replace(
|
||||
# "[" + cite_identifyer,
|
||||
# cite_open_bracket_marker * 2,
|
||||
# ).replace("]", cite_close_bracket_marker * 2)
|
||||
|
||||
# if cite_identifyer == "D":
|
||||
# citation_string += f"({link})"
|
||||
|
||||
# parsed_object = _set_combined_token_value(
|
||||
# citation_string, parsed_object
|
||||
# )
|
||||
|
||||
# yield parsed_object
|
||||
|
||||
# current_yield_components[level][level_question_nr] = [
|
||||
# start_of_next_section
|
||||
# ]
|
||||
# if not start_of_next_section.strip().startswith("["):
|
||||
# citation_potential[level][level_question_nr] = False
|
||||
|
||||
# elif isinstance(parsed_object, ExtendedToolResponse):
|
||||
# if parsed_object.id == "search_response_summary":
|
||||
# level = parsed_object.level
|
||||
# level_question_nr = parsed_object.level_question_nr
|
||||
# for inference_section in parsed_object.response.top_sections:
|
||||
# doc_link = inference_section.center_chunk.source_links[0]
|
||||
# doc_title = inference_section.center_chunk.title
|
||||
# doc_id = inference_section.center_chunk.document_id
|
||||
|
||||
# if (
|
||||
# doc_id
|
||||
# not in agent_question_citations_used_docs[level][
|
||||
# level_question_nr
|
||||
# ]
|
||||
# ):
|
||||
# if level not in agent_document_citations:
|
||||
# agent_document_citations[level] = {}
|
||||
# if level_question_nr not in agent_document_citations[level]:
|
||||
# agent_document_citations[level][level_question_nr] = []
|
||||
|
||||
# agent_document_citations[level][level_question_nr].append(
|
||||
# AgentDocumentCitations(
|
||||
# document_id=doc_id,
|
||||
# document_title=doc_title,
|
||||
# link=doc_link,
|
||||
# )
|
||||
# )
|
||||
# agent_question_citations_used_docs[level][
|
||||
# level_question_nr
|
||||
# ].append(doc_id)
|
||||
|
||||
yield parsed_object
|
||||
|
||||
def process_token(
|
||||
self, token: str | None
|
||||
) -> Generator[OnyxAnswerPiece | CitationInfo, None, None]:
|
||||
|
@ -41,6 +41,7 @@ DEFAULT_CC_PAIR_ID = 1
|
||||
# subquestion level and question number for basic flow
|
||||
BASIC_KEY = (-1, -1)
|
||||
AGENT_SEARCH_INITIAL_KEY = (0, 0)
|
||||
CANCEL_CHECK_INTERVAL = 20
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
|
@ -5,7 +5,6 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
from onyx.agents.agent_search.models import AgentSearchConfig
|
||||
from onyx.chat.chat_utils import llm_doc_from_inference_section
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import CitationConfig
|
||||
@ -14,22 +13,17 @@ from onyx.chat.models import OnyxContext
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
|
||||
QUERY = "Test question"
|
||||
DEFAULT_SEARCH_ARGS = {"query": "search"}
|
||||
@ -40,43 +34,6 @@ def answer_style_config() -> AnswerStyleConfig:
|
||||
return AnswerStyleConfig(citation_config=CitationConfig())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_search_config(
|
||||
mock_llm: LLM, mock_search_tool: SearchTool, prompt_config: PromptConfig
|
||||
) -> AgentSearchConfig:
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=QUERY,
|
||||
prompt_config=prompt_config,
|
||||
files=[],
|
||||
single_message_history=None,
|
||||
),
|
||||
message_history=[],
|
||||
llm_config=mock_llm.config,
|
||||
raw_user_query=QUERY,
|
||||
raw_user_uploaded_files=[],
|
||||
single_message_history=None,
|
||||
)
|
||||
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
|
||||
using_tool_calling_llm = explicit_tool_calling_supported(
|
||||
mock_llm.config.model_provider, mock_llm.config.model_name
|
||||
)
|
||||
return AgentSearchConfig(
|
||||
search_request=SearchRequest(query=QUERY),
|
||||
primary_llm=mock_llm,
|
||||
fast_llm=mock_llm,
|
||||
search_tool=mock_search_tool,
|
||||
force_use_tool=ForceUseTool(force_use=False, tool_name=""),
|
||||
prompt_builder=prompt_builder,
|
||||
chat_session_id=None,
|
||||
message_id=1,
|
||||
use_persistence=True,
|
||||
db_session=None,
|
||||
use_agentic_search=False,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def prompt_config() -> PromptConfig:
|
||||
return PromptConfig(
|
||||
@ -89,7 +46,7 @@ def prompt_config() -> PromptConfig:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm() -> MagicMock:
|
||||
mock_llm_obj = MagicMock()
|
||||
mock_llm_obj = MagicMock(spec=LLM)
|
||||
mock_llm_obj.config = LLMConfig(
|
||||
model_provider="openai",
|
||||
model_name="gpt-4o",
|
||||
|
@ -65,6 +65,7 @@ def answer_instance(
|
||||
search_request=SearchRequest(query=QUERY),
|
||||
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
|
||||
current_agent_message_id=0,
|
||||
use_agentic_persistence=False,
|
||||
)
|
||||
|
||||
|
||||
|
@ -11,6 +11,7 @@ from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from tests.regression.answer_quality.run_qa import _process_and_write_query_results
|
||||
@ -38,7 +39,7 @@ def test_skip_gen_ai_answer_generation_flag(
|
||||
question = config["question"]
|
||||
skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.config = Mock()
|
||||
mock_llm.config.model_name = "gpt-4o-mini"
|
||||
mock_llm.stream = Mock()
|
||||
@ -66,6 +67,7 @@ def test_skip_gen_ai_answer_generation_flag(
|
||||
),
|
||||
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
|
||||
current_agent_message_id=0,
|
||||
use_agentic_persistence=False,
|
||||
)
|
||||
results = list(answer.processed_streamed_output)
|
||||
for res in results:
|
||||
|
Loading…
x
Reference in New Issue
Block a user