basic search restructure: WIP on fixing tests

Evan Lohn
2025-01-22 16:15:22 -08:00
parent 8aa82be12a
commit dd260140b2
17 changed files with 403 additions and 305 deletions

View File

@@ -1,16 +1,12 @@
from collections import defaultdict
from collections.abc import Callable
from uuid import uuid4
from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream
@@ -18,18 +14,9 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
from onyx.chat.stream_processing.utils import (
map_document_id_order,
)
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.configs.constants import BASIC_KEY
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
@@ -37,7 +24,6 @@ from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
@@ -52,7 +38,7 @@ class Answer:
llm: LLM,
prompt_config: PromptConfig,
force_use_tool: ForceUseTool,
- pro_search_config: AgentSearchConfig,
+ agent_search_config: AgentSearchConfig,
# must be the same length as `docs`. If None, all docs are considered "relevant"
message_history: list[PreviousMessage] | None = None,
single_message_history: str | None = None,
@@ -114,7 +100,7 @@ class Answer:
and not skip_explicit_tool_calling
)
- self.pro_search_config = pro_search_config
+ self.agent_search_config = agent_search_config
self.db_session = db_session
def _get_tools_list(self) -> list[Tool]:
@@ -132,130 +118,132 @@ class Answer:
return [tool]
# TODO: delete the function and move the full body to processed_streamed_output
- def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
- current_llm_call = llm_calls[-1]
+ def _get_response(self) -> AnswerStream:
+ # current_llm_call = llm_calls[-1]
tool, tool_args = None, None
# handle the case where no decision has to be made; we simply run the tool
if (
current_llm_call.force_use_tool.force_use
and current_llm_call.force_use_tool.args is not None
):
tool_name, tool_args = (
current_llm_call.force_use_tool.tool_name,
current_llm_call.force_use_tool.args,
)
tool = get_tool_by_name(current_llm_call.tools, tool_name)
# tool, tool_args = None, None
# # handle the case where no decision has to be made; we simply run the tool
# if (
# current_llm_call.force_use_tool.force_use
# and current_llm_call.force_use_tool.args is not None
# ):
# tool_name, tool_args = (
# current_llm_call.force_use_tool.tool_name,
# current_llm_call.force_use_tool.args,
# )
# tool = get_tool_by_name(current_llm_call.tools, tool_name)
# special pre-logic for non-tool calling LLM case
elif not self.using_tool_calling_llm and current_llm_call.tools:
chosen_tool_and_args = (
ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
current_llm_call, self.llm
)
)
if chosen_tool_and_args:
tool, tool_args = chosen_tool_and_args
# # special pre-logic for non-tool calling LLM case
# elif not self.using_tool_calling_llm and current_llm_call.tools:
# chosen_tool_and_args = (
# ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
# current_llm_call, self.llm
# )
# )
# if chosen_tool_and_args:
# tool, tool_args = chosen_tool_and_args
if tool and tool_args:
dummy_tool_call_chunk = AIMessageChunk(content="")
dummy_tool_call_chunk.tool_calls = [
ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
]
# if tool and tool_args:
# dummy_tool_call_chunk = AIMessageChunk(content="")
# dummy_tool_call_chunk.tool_calls = [
# ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
# ]
response_handler_manager = LLMResponseHandlerManager(
ToolResponseHandler([tool]), None, self.is_cancelled
)
yield from response_handler_manager.handle_llm_response(
iter([dummy_tool_call_chunk])
)
# response_handler_manager = LLMResponseHandlerManager(
# ToolResponseHandler([tool]), None, self.is_cancelled
# )
# yield from response_handler_manager.handle_llm_response(
# iter([dummy_tool_call_chunk])
# )
tmp_call = response_handler_manager.next_llm_call(current_llm_call)
if tmp_call is None:
return # no more LLM calls to process
current_llm_call = tmp_call
# tmp_call = response_handler_manager.next_llm_call(current_llm_call)
# if tmp_call is None:
# return # no more LLM calls to process
# current_llm_call = tmp_call
# if we're skipping gen ai answer generation, we should break
# out unless we're forcing a tool call. If we don't, we might generate an
# answer, which is a no-no!
if (
self.skip_gen_ai_answer_generation
and not current_llm_call.force_use_tool.force_use
):
return
# # if we're skipping gen ai answer generation, we should break
# # out unless we're forcing a tool call. If we don't, we might generate an
# # answer, which is a no-no!
# if (
# self.skip_gen_ai_answer_generation
# and not current_llm_call.force_use_tool.force_use
# ):
# return
# set up "handlers" to listen to the LLM response stream and
# feed back the processed results + handle tool call requests
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
# # set up "handlers" to listen to the LLM response stream and
# # feed back the processed results + handle tool call requests
# # + figure out what the next LLM call should be
# tool_call_handler = ToolResponseHandler(current_llm_call.tools)
final_search_results, displayed_search_results = SearchTool.get_search_result(
current_llm_call
) or ([], [])
# final_search_results, displayed_search_results = SearchTool.get_search_result(
# current_llm_call
# ) or ([], [])
# NEXT: we still want to handle the LLM response stream, but it is now:
# 1. handle the tool call requests
# 2. feed back the processed results
# 3. handle the citations
# # NEXT: we still want to handle the LLM response stream, but it is now:
# # 1. handle the tool call requests
# # 2. feed back the processed results
# # 3. handle the citations
answer_handler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
)
# answer_handler = CitationResponseHandler(
# context_docs=final_search_results,
# final_doc_id_to_rank_map=map_document_id_order(final_search_results),
# display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
# )
# At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
# In the future, we'll want to handle citations and tool calls in the langgraph graph.
response_handler_manager = LLMResponseHandlerManager(
tool_call_handler, answer_handler, self.is_cancelled
)
# # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
# # In the future, we'll want to handle citations and tool calls in the langgraph graph.
# response_handler_manager = LLMResponseHandlerManager(
# tool_call_handler, answer_handler, self.is_cancelled
# )
# In langgraph, whether we do the basic thing (call llm stream) or pro search
# is based on a flag in the pro search config
- if self.pro_search_config.use_agentic_search:
- if self.pro_search_config.search_request is None:
- raise ValueError("Search request must be provided for pro search")
- if self.db_session is None:
- raise ValueError("db_session must be provided for pro search")
- if self.fast_llm is None:
- raise ValueError("fast_llm must be provided for pro search")
+ if self.agent_search_config.use_agentic_search:
+ if (
+ self.agent_search_config.db_session is None
+ and self.agent_search_config.use_persistence
+ ):
+ raise ValueError(
+ "db_session must be provided for pro search when using persistence"
+ )
stream = run_main_graph(
- config=self.pro_search_config,
+ config=self.agent_search_config,
)
else:
stream = run_basic_graph(
- config=self.pro_search_config,
- last_llm_call=current_llm_call,
- response_handler_manager=response_handler_manager,
+ config=self.agent_search_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
self._processed_stream = processed_stream
return
# DEBUG: good breakpoint
stream = self.llm.stream(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
tool_choice=(
"required"
if current_llm_call.tools and current_llm_call.force_use_tool.force_use
else None
),
structured_response_format=self.answer_style_config.structured_response_format,
)
yield from response_handler_manager.handle_llm_response(stream)
# stream = self.llm.stream(
# # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
# prompt=current_llm_call.prompt_builder.build(),
# tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
# tool_choice=(
# "required"
# if current_llm_call.tools and current_llm_call.force_use_tool.force_use
# else None
# ),
# structured_response_format=self.answer_style_config.structured_response_format,
# )
# yield from response_handler_manager.handle_llm_response(stream)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
# new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
# if new_llm_call:
# yield from self._get_response(llm_calls + [new_llm_call])
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -263,33 +251,33 @@ class Answer:
yield from self._processed_stream
return
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=self.question,
prompt_config=self.prompt_config,
files=self.latest_query_files,
single_message_history=self.single_message_history,
),
message_history=self.message_history,
llm_config=self.llm.config,
raw_user_query=self.question,
raw_user_uploaded_files=self.latest_query_files or [],
single_message_history=self.single_message_history,
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
llm_call = LLMCall(
prompt_builder=prompt_builder,
tools=self._get_tools_list(),
force_use_tool=self.force_use_tool,
files=self.latest_query_files,
tool_call_info=[],
using_tool_calling_llm=self.using_tool_calling_llm,
)
# prompt_builder = AnswerPromptBuilder(
# user_message=default_build_user_message(
# user_query=self.question,
# prompt_config=self.prompt_config,
# files=self.latest_query_files,
# single_message_history=self.single_message_history,
# ),
# message_history=self.message_history,
# llm_config=self.llm.config,
# raw_user_query=self.question,
# raw_user_uploaded_files=self.latest_query_files or [],
# single_message_history=self.single_message_history,
# )
# prompt_builder.update_system_prompt(
# default_build_system_message(self.prompt_config)
# )
# llm_call = LLMCall(
# prompt_builder=prompt_builder,
# tools=self._get_tools_list(),
# force_use_tool=self.force_use_tool,
# files=self.latest_query_files,
# tool_call_info=[],
# using_tool_calling_llm=self.using_tool_calling_llm,
# )
processed_stream = []
- for processed_packet in self._get_response([llm_call]):
+ for processed_packet in self._get_response():
processed_stream.append(processed_packet)
yield processed_packet
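
Taken together, the surviving (uncommented) lines reduce _get_response to a thin dispatcher over the two langgraph runners. A condensed sketch of the resulting control flow, assembled only from the lines kept above (no new behavior is implied):

def _get_response(self) -> AnswerStream:
    # both the basic flow and agentic ("pro") search are langgraph graphs now;
    # the only decision left here is which graph to run
    if self.agent_search_config.use_agentic_search:
        if (
            self.agent_search_config.db_session is None
            and self.agent_search_config.use_persistence
        ):
            raise ValueError(
                "db_session must be provided for pro search when using persistence"
            )
        stream = run_main_graph(config=self.agent_search_config)
    else:
        stream = run_basic_graph(config=self.agent_search_config)

    processed_stream = []
    for packet in stream:
        if self.is_cancelled():
            # emit a cancellation marker and stop consuming the graph output
            packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
            yield packet
            break
        processed_stream.append(packet)
        yield packet
    self._processed_stream = processed_stream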

View File

@@ -49,6 +49,7 @@ def prepare_chat_message_request(
rerank_settings: RerankingDetails | None,
db_session: Session,
use_agentic_search: bool = False,
skip_gen_ai_answer_generation: bool = False,
) -> CreateChatMessageRequest:
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
new_chat_session = create_chat_session(
@@ -74,6 +75,7 @@ def prepare_chat_message_request(
retrieval_options=retrieval_details,
rerank_settings=rerank_settings,
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
)
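
The new skip_gen_ai_answer_generation flag mirrors use_agentic_search: it defaults to False and is forwarded verbatim into CreateChatMessageRequest, so existing callers are unaffected while one-shot flows (e.g. a Slack-style bot that only needs retrieved documents) can opt out of answer generation. A hedged sketch of such a call; every argument except the two flags visible in this hunk is a stand-in for whatever the caller already passes:

request = prepare_chat_message_request(
    message_text=user_query,              # stand-in for the caller's existing arguments
    user=None,
    persona_id=persona_id,
    persona_override_config=None,
    prompt=None,
    message_ts_to_respond_to=None,
    retrieval_details=retrieval_details,
    rerank_settings=None,
    db_session=db_session,
    use_agentic_search=False,
    skip_gen_ai_answer_generation=True,   # new: run retrieval/tools but skip the LLM answer
)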

View File

@@ -33,6 +33,9 @@ from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
@@ -130,6 +133,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -137,7 +141,6 @@ from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -534,11 +537,8 @@ def stream_chat_message_objects(
files = load_all_chat_files(
history_msgs, new_msg_req.file_descriptors, db_session
)
- latest_query_files = [
- file
- for file in files
- if file.file_id in [f["id"] for f in new_msg_req.file_descriptors]
- ]
+ req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
+ latest_query_files = [file for file in files if file.file_id in req_file_ids]
if user_message:
attach_files_to_chat_message(
@@ -748,17 +748,42 @@ def stream_chat_message_objects(
# TODO: handle multiple search tools
raise ValueError("Multiple search tools found")
search_tool = search_tools[0]
- pro_search_config = AgentSearchConfig(
- use_agentic_search=new_msg_req.use_agentic_search,
- search_request=search_request,
- chat_session_id=chat_session_id,
- message_id=reserved_message_id,
+ using_tool_calling_llm = explicit_tool_calling_supported(
+ llm.config.model_provider, llm.config.model_name
)
force_use_tool = _get_force_search_settings(new_msg_req, tools)
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=final_msg.message,
prompt_config=prompt_config,
files=latest_query_files,
single_message_history=single_message_history,
),
system_message=default_build_system_message(prompt_config),
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,
raw_user_uploaded_files=latest_query_files or [],
single_message_history=single_message_history,
)
agent_search_config = AgentSearchConfig(
search_request=search_request,
primary_llm=llm,
fast_llm=fast_llm,
search_tool=search_tool,
structured_response_format=new_msg_req.structured_response_format,
force_use_tool=force_use_tool,
use_agentic_search=new_msg_req.use_agentic_search,
chat_session_id=chat_session_id,
message_id=reserved_message_id,
use_persistence=True,
allow_refinement=True,
db_session=db_session,
prompt_builder=prompt_builder,
tools=tools,
using_tool_calling_llm=using_tool_calling_llm,
files=latest_query_files,
structured_response_format=new_msg_req.structured_response_format,
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
)
# TODO: add previous messages, answer style config, tools, etc.
@@ -785,9 +810,9 @@ def stream_chat_message_objects(
fast_llm=fast_llm,
message_history=message_history,
tools=tools,
- force_use_tool=_get_force_search_settings(new_msg_req, tools),
+ force_use_tool=force_use_tool,
single_message_history=single_message_history,
- pro_search_config=pro_search_config,
+ agent_search_config=agent_search_config,
db_session=db_session,
)
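
With this hunk, AgentSearchConfig becomes the single hand-off object between the chat entrypoint and the graphs: prompt builder, tools, force-tool settings, files, and response format are all assembled once here instead of inside Answer. A sketch of what the config model would need to hold to satisfy this call site; the field names come from the keyword arguments above, while the base class, types, and defaults are assumptions:

from uuid import UUID

from pydantic import BaseModel
from sqlalchemy.orm import Session

class AgentSearchConfig(BaseModel):
    # onyx-internal types (SearchRequest, LLM, SearchTool, ForceUseTool, Tool,
    # AnswerPromptBuilder, InMemoryChatFile) are the ones imported in the hunks above
    search_request: SearchRequest
    primary_llm: LLM
    fast_llm: LLM
    search_tool: SearchTool
    force_use_tool: ForceUseTool
    use_agentic_search: bool = False
    chat_session_id: UUID | None = None
    message_id: int | None = None
    use_persistence: bool = True
    allow_refinement: bool = True
    db_session: Session | None = None
    prompt_builder: AnswerPromptBuilder | None = None
    tools: list[Tool] | None = None
    using_tool_calling_llm: bool = False
    files: list[InMemoryChatFile] | None = None
    structured_response_format: dict | None = None
    skip_gen_ai_answer_generation: bool = False

    class Config:
        arbitrary_types_allowed = True  # LLM, Session, etc. are not pydantic models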

View File

@@ -84,6 +84,7 @@ class AnswerPromptBuilder:
raw_user_query: str,
raw_user_uploaded_files: list[InMemoryChatFile],
single_message_history: str | None = None,
system_message: SystemMessage | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
@@ -108,7 +109,14 @@ class AnswerPromptBuilder:
),
)
- self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
+ self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
+ (
+ system_message,
+ check_message_tokens(system_message, self.llm_tokenizer_encode_func),
+ )
+ if system_message
+ else None
+ )
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(
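
The optional system_message constructor argument lets callers seed the system prompt at build time instead of calling update_system_prompt afterwards; the token count is computed the same way in both paths. A small sketch of the two equivalent usages, assuming question, prompt_config, and llm exist in the caller:

# previous pattern: construct, then attach the system message
builder = AnswerPromptBuilder(
    user_message=default_build_user_message(
        user_query=question,
        prompt_config=prompt_config,
        files=[],
        single_message_history=None,
    ),
    message_history=[],
    llm_config=llm.config,
    raw_user_query=question,
    raw_user_uploaded_files=[],
)
builder.update_system_prompt(default_build_system_message(prompt_config))

# new option: pass it directly, as stream_chat_message_objects now does
builder = AnswerPromptBuilder(
    user_message=default_build_user_message(
        user_query=question,
        prompt_config=prompt_config,
        files=[],
        single_message_history=None,
    ),
    message_history=[],
    llm_config=llm.config,
    raw_user_query=question,
    raw_user_uploaded_files=[],
    system_message=default_build_system_message(prompt_config),
)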

View File

@@ -5,6 +5,7 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
@@ -50,56 +51,12 @@ class ToolResponseHandler:
def get_tool_call_for_non_tool_calling_llm(
cls, llm_call: LLMCall, llm: LLM
) -> tuple[Tool, dict] | None:
- if llm_call.force_use_tool.force_use:
- # if we are forcing a tool, we don't need to check which tools to run
- tool = get_tool_by_name(llm_call.tools, llm_call.force_use_tool.tool_name)
- tool_args = (
- llm_call.force_use_tool.args
- if llm_call.force_use_tool.args is not None
- else tool.get_args_for_non_tool_calling_llm(
- query=llm_call.prompt_builder.raw_user_query,
- history=llm_call.prompt_builder.raw_message_history,
- llm=llm,
- force_run=True,
- )
- )
- if tool_args is None:
- raise RuntimeError(f"Tool '{tool.name}' did not return args")
- return (tool, tool_args)
- else:
- tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
- tools=llm_call.tools,
- query=llm_call.prompt_builder.raw_user_query,
- history=llm_call.prompt_builder.raw_message_history,
- llm=llm,
- )
- available_tools_and_args = [
- (llm_call.tools[ind], args)
- for ind, args in enumerate(tool_options)
- if args is not None
- ]
- logger.info(
- f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
- )
- chosen_tool_and_args = (
- select_single_tool_for_non_tool_calling_llm(
- tools_and_args=available_tools_and_args,
- history=llm_call.prompt_builder.raw_message_history,
- query=llm_call.prompt_builder.raw_user_query,
- llm=llm,
- )
- if available_tools_and_args
- else None
- )
- logger.notice(f"Chosen tool: {chosen_tool_and_args}")
- return chosen_tool_and_args
+ return get_tool_call_for_non_tool_calling_llm_impl(
+ force_use_tool=llm_call.force_use_tool,
+ tools=llm_call.tools,
+ prompt_builder=llm_call.prompt_builder,
+ llm=llm,
+ )
def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:
@@ -196,3 +153,61 @@ class ToolResponseHandler:
self.tool_final_result,
],
)
def get_tool_call_for_non_tool_calling_llm_impl(
force_use_tool: ForceUseTool,
tools: list[Tool],
prompt_builder: AnswerPromptBuilder,
llm: LLM,
) -> tuple[Tool, dict] | None:
if force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = get_tool_by_name(tools, force_use_tool.tool_name)
tool_args = (
force_use_tool.args
if force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=prompt_builder.raw_user_query,
history=prompt_builder.raw_message_history,
llm=llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
return (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=tools,
query=prompt_builder.raw_user_query,
history=prompt_builder.raw_message_history,
llm=llm,
)
available_tools_and_args = [
(tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=prompt_builder.raw_message_history,
query=prompt_builder.raw_user_query,
llm=llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
return chosen_tool_and_args
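
Lifting the body into the module-level get_tool_call_for_non_tool_calling_llm_impl means code that never builds an LLMCall (e.g. a langgraph node that only has the prompt builder and tool list on the AgentSearchConfig) can reuse the same selection logic, while the classmethod above just adapts its LLMCall fields. A hedged sketch of such a direct call; the config attribute names are assumed from the AgentSearchConfig construction earlier in this commit:

# hypothetical graph node deciding which tool (if any) to run for a non-tool-calling LLM
chosen = get_tool_call_for_non_tool_calling_llm_impl(
    force_use_tool=agent_search_config.force_use_tool,
    tools=agent_search_config.tools or [],
    prompt_builder=agent_search_config.prompt_builder,
    llm=agent_search_config.primary_llm,
)
if chosen is not None:
    tool, tool_args = chosen
    # run the selected tool with tool_args, then feed its results back to the LLM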