mirror of https://github.com/danswer-ai/danswer.git (synced 2025-03-17 21:32:36 +01:00)

basic search restructure: WIP on fixing tests

This commit is contained in:
parent 8aa82be12a
commit dd260140b2
@@ -197,6 +197,7 @@ def get_answer_stream(
        rerank_settings=query_request.rerank_settings,
        db_session=db_session,
        use_agentic_search=query_request.use_agentic_search,
        skip_gen_ai_answer_generation=query_request.skip_gen_ai_answer_generation,
    )

    packets = stream_chat_message_objects(
@@ -1,20 +1,18 @@
from typing import cast

from langchain_core.callbacks.manager import dispatch_custom_event
from langchain_core.runnables.config import RunnableConfig
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agents.agent_search.basic.nodes.basic_use_tool_response import (
    basic_use_tool_response,
)
from onyx.agents.agent_search.basic.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.basic.nodes.tool_call import tool_call
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.states import BasicStateUpdate
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.stream_processing.utils import (
    map_document_id_order,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger

logger = setup_logger()


def basic_graph_builder() -> StateGraph:

@@ -27,17 +25,33 @@ def basic_graph_builder() -> StateGraph:
    ### Add nodes ###

    graph.add_node(
        node="get_response",
        action=get_response,
        node="llm_tool_choice",
        action=llm_tool_choice,
    )

    graph.add_node(
        node="tool_call",
        action=tool_call,
    )

    graph.add_node(
        node="basic_use_tool_response",
        action=basic_use_tool_response,
    )

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="get_response")
    graph.add_edge(start_key=START, end_key="llm_tool_choice")

    graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])

    graph.add_conditional_edges("get_response", should_continue, ["get_response", END])
    graph.add_edge(
        start_key="get_response",
        start_key="tool_call",
        end_key="basic_use_tool_response",
    )

    graph.add_edge(
        start_key="basic_use_tool_response",
        end_key=END,
    )

@@ -46,56 +60,10 @@ def basic_graph_builder() -> StateGraph:

def should_continue(state: BasicState) -> str:
    return (
        END if state["last_llm_call"] is None or state["calls"] > 1 else "get_response"
    )


def get_response(state: BasicState, config: RunnableConfig) -> BasicStateUpdate:
    agent_a_config = cast(AgentSearchConfig, config["metadata"]["config"])
    llm = agent_a_config.primary_llm
    current_llm_call = state["last_llm_call"]
    if current_llm_call is None:
        raise ValueError("last_llm_call is None")
    structured_response_format = agent_a_config.structured_response_format
    response_handler_manager = state["response_handler_manager"]
    # DEBUG: good breakpoint
    stream = llm.stream(
        # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
        # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
        prompt=current_llm_call.prompt_builder.build(),
        tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
        tool_choice=(
            "required"
            if current_llm_call.tools and current_llm_call.force_use_tool.force_use
            else None
        ),
        structured_response_format=structured_response_format,
    )

    for response in response_handler_manager.handle_llm_response(stream):
        dispatch_custom_event(
            "basic_response",
            response,
        )

    next_call = response_handler_manager.next_llm_call(current_llm_call)
    if next_call is not None:
        final_search_results, displayed_search_results = SearchTool.get_search_result(
            next_call
        ) or ([], [])
    else:
        final_search_results, displayed_search_results = [], []

    response_handler_manager.answer_handler.update(
        (
            final_search_results,
            map_document_id_order(final_search_results),
            map_document_id_order(displayed_search_results),
        )
    )
    return BasicStateUpdate(
        last_llm_call=next_call,
        calls=state["calls"] + 1,
        # If there are no tool calls, basic graph already streamed the answer
        END
        if state["tool_choice"] is None
        else "tool_call"
    )
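For orientation, here is a minimal sketch of how the rewired basic graph might be compiled and driven. The BasicInput values, the surrounding variables (response_handler_manager, agent_search_config), and the use of LangGraph's "custom" stream mode to surface dispatch_custom_event payloads are assumptions for illustration, not code from this commit; imports are as in the module above.

# Sketch only: the wiring above is START -> llm_tool_choice,
# llm_tool_choice -> (tool_call | END), tool_call -> basic_use_tool_response -> END.
graph = basic_graph_builder()
compiled_graph = graph.compile()

basic_input = BasicInput(
    base_question="example question",                    # placeholder values
    last_llm_call=None,
    response_handler_manager=response_handler_manager,   # built by the caller
    calls=0,
    should_stream_answer=True,
)
for event in compiled_graph.stream(
    basic_input,
    config={"metadata": {"config": agent_search_config}},  # read by the nodes via config["metadata"]["config"]
    stream_mode="custom",                                   # assumed: surfaces dispatch_custom_event payloads
):
    print(event)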
@@ -1,20 +1,21 @@
from typing import TypedDict

from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall


## Update States
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool

# States contain values that change over the course of graph execution,
# Config is for values that are set at the start and never change.
# If you are using a value from the config and realize it needs to change,
# you should add it to the state and use/update the version in the state.

## Graph Input State


class BasicInput(TypedDict):
    base_question: str
    last_llm_call: LLMCall | None
    response_handler_manager: LLMResponseHandlerManager
    calls: int
    should_stream_answer: bool


## Graph Output State

@@ -24,9 +25,22 @@ class BasicOutput(TypedDict):
    pass


class BasicStateUpdate(TypedDict):
    last_llm_call: LLMCall | None
    calls: int
## Update States
class ToolCallUpdate(TypedDict):
    tool_call_summary: ToolCallSummary
    tool_call_kickoff: ToolCallKickoff
    tool_call_responses: list[ToolResponse]
    tool_call_final_result: ToolCallFinalResult


class ToolChoice(TypedDict):
    tool: Tool
    tool_args: dict
    id: str | None


class ToolChoiceUpdate(TypedDict):
    tool_choice: ToolChoice | None


## Graph State

@@ -34,6 +48,8 @@ class BasicStateUpdate(TypedDict):

class BasicState(
    BasicInput,
    ToolCallUpdate,
    ToolChoiceUpdate,
    BasicOutput,
):
    pass
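To illustrate the state-vs-config split described in the comments above, a hypothetical node can return only the partial update it owns and let LangGraph merge it into BasicState. This is a sketch under that assumption (cast / RunnableConfig imports as in the graph builder above); the tool selection is a placeholder.

# Sketch only: a node returns the slice of state it owns; LangGraph merges it into BasicState.
def choose_tool_sketch(state: BasicState, config: RunnableConfig) -> ToolChoiceUpdate:
    # read-only values (LLMs, tool list, flags) come from the config ...
    agent_config = cast(AgentSearchConfig, config["metadata"]["config"])
    chosen_tool = (agent_config.tools or [])[0]  # placeholder selection
    # ... while anything that evolves during the run goes back into the state
    return ToolChoiceUpdate(
        tool_choice=ToolChoice(
            tool=chosen_tool,
            tool_args={"query": state["base_question"]},
            id=None,
        )
    )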
@@ -2,11 +2,15 @@ from dataclasses import dataclass
from uuid import UUID

from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session

from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.context.search.models import SearchRequest
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool


@@ -22,7 +26,18 @@ class AgentSearchConfig:
    primary_llm: LLM
    fast_llm: LLM
    search_tool: SearchTool
    use_agentic_search: bool = True

    # Whether to force use of a tool, or to
    # force tool args IF the tool is used
    force_use_tool: ForceUseTool

    # contains message history for the current chat session
    # has the following (at most one is non-None)
    # message_history: list[PreviousMessage] | None = None
    # single_message_history: str | None = None
    prompt_builder: AnswerPromptBuilder

    use_agentic_search: bool = False

    # For persisting agent search data
    chat_session_id: UUID | None = None

@@ -45,11 +60,25 @@ class AgentSearchConfig:
    # Whether to allow creation of refinement questions (and entity extraction, etc.)
    allow_refinement: bool = True

    # Message history for the current chat session
    message_history: list[PreviousMessage] | None = None
    # Tools available for use
    tools: list[Tool] | None = None

    using_tool_calling_llm: bool = False

    files: list[InMemoryChatFile] | None = None

    structured_response_format: dict | None = None

    skip_gen_ai_answer_generation: bool = False

    @model_validator(mode="after")
    def validate_db_session(self) -> "AgentSearchConfig":
        if self.use_persistence and self.db_session is None:
            raise ValueError(
                "db_session must be provided for pro search when using persistence"
            )
        return self


class AgentDocumentCitations(BaseModel):
    document_id: str
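A hedged construction example for the expanded AgentSearchConfig, mirroring the get_test_config call later in this commit; the concrete values (primary_llm, fast_llm, search_tool, prompt_builder, db_session) are placeholders, and validate_db_session rejects use_persistence=True without a db_session.

# Sketch only: field names come from the class above, values are placeholders.
config = AgentSearchConfig(
    search_request=SearchRequest(query="example query"),
    primary_llm=primary_llm,
    fast_llm=fast_llm,
    search_tool=search_tool,
    force_use_tool=ForceUseTool(force_use=False, tool_name=""),
    prompt_builder=prompt_builder,   # AnswerPromptBuilder built by the caller
    use_agentic_search=False,        # False -> basic graph, True -> deep search graph
    use_persistence=True,
    db_session=db_session,           # required once use_persistence=True
)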
@@ -16,7 +16,6 @@ from onyx.agents.agent_search.deep_search_a.main.graph_builder import (
from onyx.agents.agent_search.deep_search_a.main.states import MainInput as MainInput_a
from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream

@@ -25,7 +24,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionPiece
from onyx.chat.models import ToolResponse
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.configs.dev_configs import GRAPH_NAME
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager

@@ -143,7 +141,6 @@ def run_graph(
    config: AgentSearchConfig,
    input: BasicInput | MainInput_a,
) -> AnswerStream:
    input["base_question"] = config.search_request.query if config else ""
    # TODO: add these to the environment
    config.perform_initial_search_path_decision = True
    config.perform_initial_search_decomposition = True

@@ -192,17 +189,12 @@ def run_main_graph(
# TODO: unify input types, especially prosearchconfig
def run_basic_graph(
    config: AgentSearchConfig,
    last_llm_call: LLMCall | None,
    response_handler_manager: LLMResponseHandlerManager,
) -> AnswerStream:
    graph = basic_graph_builder()
    compiled_graph = graph.compile()
    # TODO: unify basic input
    input = BasicInput(
        base_question="",
        last_llm_call=last_llm_call,
        response_handler_manager=response_handler_manager,
        calls=0,
        should_stream_answer=True,
    )
    return run_graph(compiled_graph, config, input)
@@ -11,6 +11,7 @@ from typing import cast
from uuid import UUID

from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session

from onyx.agents.agent_search.models import AgentSearchConfig

@@ -21,6 +22,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID

@@ -31,6 +33,7 @@ from onyx.context.search.models import SearchRequest
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import Persona
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import SearchTool

@@ -207,14 +210,23 @@ def get_test_config(

    config = AgentSearchConfig(
        search_request=search_request,
        primary_llm=primary_llm,
        fast_llm=fast_llm,
        search_tool=search_tool,
        force_use_tool=ForceUseTool(force_use=False, tool_name=""),
        prompt_builder=AnswerPromptBuilder(
            user_message=HumanMessage(content=search_request.query),
            message_history=[],
            llm_config=primary_llm.config,
            raw_user_query=search_request.query,
            raw_user_uploaded_files=[],
        ),
        # chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
        chat_session_id=UUID("edda10d5-6cef-45d8-acfb-39317552a1f4"),  # Joachim
        # chat_session_id=UUID("d1acd613-2692-4bc3-9d65-c6d3da62e58e"),  # Evan
        message_id=1,
        use_persistence=True,
        primary_llm=primary_llm,
        fast_llm=fast_llm,
        search_tool=search_tool,
        db_session=db_session,
    )

    return config, search_tool
@@ -1,16 +1,12 @@
from collections import defaultdict
from collections.abc import Callable
from uuid import uuid4

from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall
from sqlalchemy.orm import Session

from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.llm_response_handler import LLMResponseHandlerManager
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
from onyx.chat.models import AnswerStream

@@ -18,18 +14,9 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import (
    CitationResponseHandler,
)
from onyx.chat.stream_processing.utils import (
    map_document_id_order,
)
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.configs.constants import BASIC_KEY
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM

@@ -37,7 +24,6 @@ from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger

@@ -52,7 +38,7 @@ class Answer:
        llm: LLM,
        prompt_config: PromptConfig,
        force_use_tool: ForceUseTool,
        pro_search_config: AgentSearchConfig,
        agent_search_config: AgentSearchConfig,
        # must be the same length as `docs`. If None, all docs are considered "relevant"
        message_history: list[PreviousMessage] | None = None,
        single_message_history: str | None = None,

@@ -114,7 +100,7 @@ class Answer:
            and not skip_explicit_tool_calling
        )

        self.pro_search_config = pro_search_config
        self.agent_search_config = agent_search_config
        self.db_session = db_session

    def _get_tools_list(self) -> list[Tool]:

@@ -132,130 +118,132 @@ class Answer:
        return [tool]

    # TODO: delete the function and move the full body to processed_streamed_output
    def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
        current_llm_call = llm_calls[-1]
    def _get_response(self) -> AnswerStream:
        # current_llm_call = llm_calls[-1]

        tool, tool_args = None, None
        # handle the case where no decision has to be made; we simply run the tool
        if (
            current_llm_call.force_use_tool.force_use
            and current_llm_call.force_use_tool.args is not None
        ):
            tool_name, tool_args = (
                current_llm_call.force_use_tool.tool_name,
                current_llm_call.force_use_tool.args,
            )
            tool = get_tool_by_name(current_llm_call.tools, tool_name)
        # tool, tool_args = None, None
        # # handle the case where no decision has to be made; we simply run the tool
        # if (
        #     current_llm_call.force_use_tool.force_use
        #     and current_llm_call.force_use_tool.args is not None
        # ):
        #     tool_name, tool_args = (
        #         current_llm_call.force_use_tool.tool_name,
        #         current_llm_call.force_use_tool.args,
        #     )
        #     tool = get_tool_by_name(current_llm_call.tools, tool_name)

        # special pre-logic for non-tool calling LLM case
        elif not self.using_tool_calling_llm and current_llm_call.tools:
            chosen_tool_and_args = (
                ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
                    current_llm_call, self.llm
                )
            )
            if chosen_tool_and_args:
                tool, tool_args = chosen_tool_and_args
        # # special pre-logic for non-tool calling LLM case
        # elif not self.using_tool_calling_llm and current_llm_call.tools:
        #     chosen_tool_and_args = (
        #         ToolResponseHandler.get_tool_call_for_non_tool_calling_llm(
        #             current_llm_call, self.llm
        #         )
        #     )
        #     if chosen_tool_and_args:
        #         tool, tool_args = chosen_tool_and_args

        if tool and tool_args:
            dummy_tool_call_chunk = AIMessageChunk(content="")
            dummy_tool_call_chunk.tool_calls = [
                ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
            ]
        # if tool and tool_args:
        #     dummy_tool_call_chunk = AIMessageChunk(content="")
        #     dummy_tool_call_chunk.tool_calls = [
        #         ToolCall(name=tool.name, args=tool_args, id=str(uuid4()))
        #     ]

            response_handler_manager = LLMResponseHandlerManager(
                ToolResponseHandler([tool]), None, self.is_cancelled
            )
            yield from response_handler_manager.handle_llm_response(
                iter([dummy_tool_call_chunk])
            )
        # response_handler_manager = LLMResponseHandlerManager(
        #     ToolResponseHandler([tool]), None, self.is_cancelled
        # )
        # yield from response_handler_manager.handle_llm_response(
        #     iter([dummy_tool_call_chunk])
        # )

            tmp_call = response_handler_manager.next_llm_call(current_llm_call)
            if tmp_call is None:
                return  # no more LLM calls to process
            current_llm_call = tmp_call
        # tmp_call = response_handler_manager.next_llm_call(current_llm_call)
        # if tmp_call is None:
        #     return  # no more LLM calls to process
        # current_llm_call = tmp_call

        # if we're skipping gen ai answer generation, we should break
        # out unless we're forcing a tool call. If we don't, we might generate an
        # answer, which is a no-no!
        if (
            self.skip_gen_ai_answer_generation
            and not current_llm_call.force_use_tool.force_use
        ):
            return
        # # if we're skipping gen ai answer generation, we should break
        # # out unless we're forcing a tool call. If we don't, we might generate an
        # # answer, which is a no-no!
        # if (
        #     self.skip_gen_ai_answer_generation
        #     and not current_llm_call.force_use_tool.force_use
        # ):
        #     return

        # set up "handlers" to listen to the LLM response stream and
        # feed back the processed results + handle tool call requests
        # + figure out what the next LLM call should be
        tool_call_handler = ToolResponseHandler(current_llm_call.tools)
        # # set up "handlers" to listen to the LLM response stream and
        # # feed back the processed results + handle tool call requests
        # # + figure out what the next LLM call should be
        # tool_call_handler = ToolResponseHandler(current_llm_call.tools)

        final_search_results, displayed_search_results = SearchTool.get_search_result(
            current_llm_call
        ) or ([], [])
        # final_search_results, displayed_search_results = SearchTool.get_search_result(
        #     current_llm_call
        # ) or ([], [])

        # NEXT: we still want to handle the LLM response stream, but it is now:
        # 1. handle the tool call requests
        # 2. feed back the processed results
        # 3. handle the citations
        # # NEXT: we still want to handle the LLM response stream, but it is now:
        # # 1. handle the tool call requests
        # # 2. feed back the processed results
        # # 3. handle the citations

        answer_handler = CitationResponseHandler(
            context_docs=final_search_results,
            final_doc_id_to_rank_map=map_document_id_order(final_search_results),
            display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
        )
        # answer_handler = CitationResponseHandler(
        #     context_docs=final_search_results,
        #     final_doc_id_to_rank_map=map_document_id_order(final_search_results),
        #     display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
        # )

        # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
        # In the future, we'll want to handle citations and tool calls in the langgraph graph.
        response_handler_manager = LLMResponseHandlerManager(
            tool_call_handler, answer_handler, self.is_cancelled
        )
        # # At the moment, this wrapper class passes streamed stuff through citation and tool handlers.
        # # In the future, we'll want to handle citations and tool calls in the langgraph graph.
        # response_handler_manager = LLMResponseHandlerManager(
        #     tool_call_handler, answer_handler, self.is_cancelled
        # )

        # In langgraph, whether we do the basic thing (call llm stream) or pro search
        # is based on a flag in the pro search config

        if self.pro_search_config.use_agentic_search:
            if self.pro_search_config.search_request is None:
                raise ValueError("Search request must be provided for pro search")

            if self.db_session is None:
                raise ValueError("db_session must be provided for pro search")
            if self.fast_llm is None:
                raise ValueError("fast_llm must be provided for pro search")
        if self.agent_search_config.use_agentic_search:
            if (
                self.agent_search_config.db_session is None
                and self.agent_search_config.use_persistence
            ):
                raise ValueError(
                    "db_session must be provided for pro search when using persistence"
                )

            stream = run_main_graph(
                config=self.pro_search_config,
                config=self.agent_search_config,
            )
        else:
            stream = run_basic_graph(
                config=self.pro_search_config,
                last_llm_call=current_llm_call,
                response_handler_manager=response_handler_manager,
                config=self.agent_search_config,
            )

        processed_stream = []
        for packet in stream:
            if self.is_cancelled():
                packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
                yield packet
                break
            processed_stream.append(packet)
            yield packet
        self._processed_stream = processed_stream
        return
        # DEBUG: good breakpoint
        stream = self.llm.stream(
            # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
            # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
            prompt=current_llm_call.prompt_builder.build(),
            tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
            tool_choice=(
                "required"
                if current_llm_call.tools and current_llm_call.force_use_tool.force_use
                else None
            ),
            structured_response_format=self.answer_style_config.structured_response_format,
        )
        yield from response_handler_manager.handle_llm_response(stream)
        # stream = self.llm.stream(
        #     # For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
        #     # may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
        #     prompt=current_llm_call.prompt_builder.build(),
        #     tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
        #     tool_choice=(
        #         "required"
        #         if current_llm_call.tools and current_llm_call.force_use_tool.force_use
        #         else None
        #     ),
        #     structured_response_format=self.answer_style_config.structured_response_format,
        # )
        # yield from response_handler_manager.handle_llm_response(stream)

        new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
        if new_llm_call:
            yield from self._get_response(llm_calls + [new_llm_call])
        # new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
        # if new_llm_call:
        #     yield from self._get_response(llm_calls + [new_llm_call])

    @property
    def processed_streamed_output(self) -> AnswerStream:

@@ -263,33 +251,33 @@ class Answer:
            yield from self._processed_stream
            return

        prompt_builder = AnswerPromptBuilder(
            user_message=default_build_user_message(
                user_query=self.question,
                prompt_config=self.prompt_config,
                files=self.latest_query_files,
                single_message_history=self.single_message_history,
            ),
            message_history=self.message_history,
            llm_config=self.llm.config,
            raw_user_query=self.question,
            raw_user_uploaded_files=self.latest_query_files or [],
            single_message_history=self.single_message_history,
        )
        prompt_builder.update_system_prompt(
            default_build_system_message(self.prompt_config)
        )
        llm_call = LLMCall(
            prompt_builder=prompt_builder,
            tools=self._get_tools_list(),
            force_use_tool=self.force_use_tool,
            files=self.latest_query_files,
            tool_call_info=[],
            using_tool_calling_llm=self.using_tool_calling_llm,
        )
        # prompt_builder = AnswerPromptBuilder(
        #     user_message=default_build_user_message(
        #         user_query=self.question,
        #         prompt_config=self.prompt_config,
        #         files=self.latest_query_files,
        #         single_message_history=self.single_message_history,
        #     ),
        #     message_history=self.message_history,
        #     llm_config=self.llm.config,
        #     raw_user_query=self.question,
        #     raw_user_uploaded_files=self.latest_query_files or [],
        #     single_message_history=self.single_message_history,
        # )
        # prompt_builder.update_system_prompt(
        #     default_build_system_message(self.prompt_config)
        # )
        # llm_call = LLMCall(
        #     prompt_builder=prompt_builder,
        #     tools=self._get_tools_list(),
        #     force_use_tool=self.force_use_tool,
        #     files=self.latest_query_files,
        #     tool_call_info=[],
        #     using_tool_calling_llm=self.using_tool_calling_llm,
        # )

        processed_stream = []
        for processed_packet in self._get_response([llm_call]):
        for processed_packet in self._get_response():
            processed_stream.append(processed_packet)
            yield processed_packet
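Taken together, the Answer class is now driven by an AgentSearchConfig and streams exclusively through the graph runners. A hedged sketch mirroring the updated unit-test fixture; all objects are placeholders built by the caller.

# Sketch only: use_agentic_search on the config picks the basic graph vs. the deep search graph.
answer = Answer(
    question="example question",
    answer_style_config=answer_style_config,
    llm=llm,
    prompt_config=prompt_config,
    force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
    agent_search_config=agent_search_config,
)
for packet in answer.processed_streamed_output:
    ...  # OnyxAnswerPiece / ToolResponse / StreamStopInfo packets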
@@ -49,6 +49,7 @@ def prepare_chat_message_request(
    rerank_settings: RerankingDetails | None,
    db_session: Session,
    use_agentic_search: bool = False,
    skip_gen_ai_answer_generation: bool = False,
) -> CreateChatMessageRequest:
    # Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
    new_chat_session = create_chat_session(

@@ -74,6 +75,7 @@ def prepare_chat_message_request(
        retrieval_options=retrieval_details,
        rerank_settings=rerank_settings,
        use_agentic_search=use_agentic_search,
        skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
    )
@@ -33,6 +33,9 @@ from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT

@@ -130,6 +133,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
    SECTION_RELEVANCE_LIST_ID,
)
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry

@@ -137,7 +141,6 @@ from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()


@@ -534,11 +537,8 @@ def stream_chat_message_objects(
        files = load_all_chat_files(
            history_msgs, new_msg_req.file_descriptors, db_session
        )
        latest_query_files = [
            file
            for file in files
            if file.file_id in [f["id"] for f in new_msg_req.file_descriptors]
        ]
        req_file_ids = [f["id"] for f in new_msg_req.file_descriptors]
        latest_query_files = [file for file in files if file.file_id in req_file_ids]

        if user_message:
            attach_files_to_chat_message(

@@ -748,17 +748,42 @@ def stream_chat_message_objects(
                # TODO: handle multiple search tools
                raise ValueError("Multiple search tools found")
            search_tool = search_tools[0]
        pro_search_config = AgentSearchConfig(
            use_agentic_search=new_msg_req.use_agentic_search,
            search_request=search_request,
            chat_session_id=chat_session_id,
            message_id=reserved_message_id,
        using_tool_calling_llm = explicit_tool_calling_supported(
            llm.config.model_provider, llm.config.model_name
        )
        force_use_tool = _get_force_search_settings(new_msg_req, tools)
        prompt_builder = AnswerPromptBuilder(
            user_message=default_build_user_message(
                user_query=final_msg.message,
                prompt_config=prompt_config,
                files=latest_query_files,
                single_message_history=single_message_history,
            ),
            system_message=default_build_system_message(prompt_config),
            message_history=message_history,
            llm_config=llm.config,
            raw_user_query=final_msg.message,
            raw_user_uploaded_files=latest_query_files or [],
            single_message_history=single_message_history,
        )
        agent_search_config = AgentSearchConfig(
            search_request=search_request,
            primary_llm=llm,
            fast_llm=fast_llm,
            search_tool=search_tool,
            structured_response_format=new_msg_req.structured_response_format,
            force_use_tool=force_use_tool,
            use_agentic_search=new_msg_req.use_agentic_search,
            chat_session_id=chat_session_id,
            message_id=reserved_message_id,
            use_persistence=True,
            allow_refinement=True,
            db_session=db_session,
            prompt_builder=prompt_builder,
            tools=tools,
            using_tool_calling_llm=using_tool_calling_llm,
            files=latest_query_files,
            structured_response_format=new_msg_req.structured_response_format,
            skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
        )

        # TODO: add previous messages, answer style config, tools, etc.

@@ -785,9 +810,9 @@ def stream_chat_message_objects(
            fast_llm=fast_llm,
            message_history=message_history,
            tools=tools,
            force_use_tool=_get_force_search_settings(new_msg_req, tools),
            force_use_tool=force_use_tool,
            single_message_history=single_message_history,
            pro_search_config=pro_search_config,
            agent_search_config=agent_search_config,
            db_session=db_session,
        )
@@ -84,6 +84,7 @@ class AnswerPromptBuilder:
        raw_user_query: str,
        raw_user_uploaded_files: list[InMemoryChatFile],
        single_message_history: str | None = None,
        system_message: SystemMessage | None = None,
    ) -> None:
        self.max_tokens = compute_max_llm_input_tokens(llm_config)

@@ -108,7 +109,14 @@ class AnswerPromptBuilder:
            ),
        )

        self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
        self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
            (
                system_message,
                check_message_tokens(system_message, self.llm_tokenizer_encode_func),
            )
            if system_message
            else None
        )
        self.user_message_and_token_cnt = (
            user_message,
            check_message_tokens(
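With the new optional system_message parameter, the system prompt can be supplied at construction time and tokenized up front rather than via a later update_system_prompt call. A sketch matching the call site added in stream_chat_message_objects above; the surrounding variables are assumed to exist in the caller.

# Sketch only: mirrors the AnswerPromptBuilder call added in this commit.
prompt_builder = AnswerPromptBuilder(
    user_message=default_build_user_message(
        user_query=user_query,
        prompt_config=prompt_config,
        files=latest_query_files,
        single_message_history=single_message_history,
    ),
    system_message=default_build_system_message(prompt_config),  # token count computed in __init__
    message_history=message_history,
    llm_config=llm.config,
    raw_user_query=user_query,
    raw_user_uploaded_files=latest_query_files or [],
    single_message_history=single_message_history,
)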
@@ -5,6 +5,7 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall

from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool

@@ -50,56 +51,12 @@ class ToolResponseHandler:
    def get_tool_call_for_non_tool_calling_llm(
        cls, llm_call: LLMCall, llm: LLM
    ) -> tuple[Tool, dict] | None:
        if llm_call.force_use_tool.force_use:
            # if we are forcing a tool, we don't need to check which tools to run
            tool = get_tool_by_name(llm_call.tools, llm_call.force_use_tool.tool_name)

            tool_args = (
                llm_call.force_use_tool.args
                if llm_call.force_use_tool.args is not None
                else tool.get_args_for_non_tool_calling_llm(
                    query=llm_call.prompt_builder.raw_user_query,
                    history=llm_call.prompt_builder.raw_message_history,
                    llm=llm,
                    force_run=True,
                )
            )

            if tool_args is None:
                raise RuntimeError(f"Tool '{tool.name}' did not return args")

            return (tool, tool_args)
        else:
            tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
                tools=llm_call.tools,
                query=llm_call.prompt_builder.raw_user_query,
                history=llm_call.prompt_builder.raw_message_history,
                llm=llm,
            )

            available_tools_and_args = [
                (llm_call.tools[ind], args)
                for ind, args in enumerate(tool_options)
                if args is not None
            ]

            logger.info(
                f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
            )

            chosen_tool_and_args = (
                select_single_tool_for_non_tool_calling_llm(
                    tools_and_args=available_tools_and_args,
                    history=llm_call.prompt_builder.raw_message_history,
                    query=llm_call.prompt_builder.raw_user_query,
                    llm=llm,
                )
                if available_tools_and_args
                else None
            )

            logger.notice(f"Chosen tool: {chosen_tool_and_args}")
            return chosen_tool_and_args
        return get_tool_call_for_non_tool_calling_llm_impl(
            force_use_tool=llm_call.force_use_tool,
            tools=llm_call.tools,
            prompt_builder=llm_call.prompt_builder,
            llm=llm,
        )

    def _handle_tool_call(self) -> Generator[ResponsePart, None, None]:
        if not self.tool_call_chunk or not self.tool_call_chunk.tool_calls:

@@ -196,3 +153,61 @@ class ToolResponseHandler:
                self.tool_final_result,
            ],
        )


def get_tool_call_for_non_tool_calling_llm_impl(
    force_use_tool: ForceUseTool,
    tools: list[Tool],
    prompt_builder: AnswerPromptBuilder,
    llm: LLM,
) -> tuple[Tool, dict] | None:
    if force_use_tool.force_use:
        # if we are forcing a tool, we don't need to check which tools to run
        tool = get_tool_by_name(tools, force_use_tool.tool_name)

        tool_args = (
            force_use_tool.args
            if force_use_tool.args is not None
            else tool.get_args_for_non_tool_calling_llm(
                query=prompt_builder.raw_user_query,
                history=prompt_builder.raw_message_history,
                llm=llm,
                force_run=True,
            )
        )

        if tool_args is None:
            raise RuntimeError(f"Tool '{tool.name}' did not return args")

        return (tool, tool_args)
    else:
        tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
            tools=tools,
            query=prompt_builder.raw_user_query,
            history=prompt_builder.raw_message_history,
            llm=llm,
        )

        available_tools_and_args = [
            (tools[ind], args)
            for ind, args in enumerate(tool_options)
            if args is not None
        ]

        logger.info(
            f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
        )

        chosen_tool_and_args = (
            select_single_tool_for_non_tool_calling_llm(
                tools_and_args=available_tools_and_args,
                history=prompt_builder.raw_message_history,
                query=prompt_builder.raw_user_query,
                llm=llm,
            )
            if available_tools_and_args
            else None
        )

        logger.notice(f"Chosen tool: {chosen_tool_and_args}")
        return chosen_tool_and_args
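Because the selection logic now lives in a module-level helper, callers that do not hold a ToolResponseHandler instance (for example the new llm_tool_choice node) can presumably invoke it directly. A hedged sketch; the arguments are assumed to come from the caller's LLMCall or AgentSearchConfig.

# Sketch only: module-level helper, usable without instantiating ToolResponseHandler.
chosen = get_tool_call_for_non_tool_calling_llm_impl(
    force_use_tool=force_use_tool,   # ForceUseTool, e.g. from the AgentSearchConfig
    tools=tools,                     # list[Tool]
    prompt_builder=prompt_builder,   # AnswerPromptBuilder
    llm=llm,
)
if chosen is not None:
    tool, tool_args = chosen         # e.g. (SearchTool instance, {"query": ...})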
@@ -81,7 +81,4 @@ def check_if_need_search(

    logger.debug(f"Run search prediction: {require_search_output}")

    if (SKIP_SEARCH.split()[0]).lower() in require_search_output.lower():
        return False

    return True
    return (SKIP_SEARCH.split()[0]).lower() not in require_search_output.lower()
@@ -138,6 +138,8 @@ class CreateChatMessageRequest(ChunkContext):
    # TODO: decide how many of the above options we want to pass through to pro search
    use_agentic_search: bool = False

    skip_gen_ai_answer_generation: bool = False

    @model_validator(mode="after")
    def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
        if self.search_doc_ids is None and self.retrieval_options is None:
@@ -20,10 +20,7 @@ OPEN_AI_TOOL_CALLING_MODELS = {


def explicit_tool_calling_supported(model_provider: str, model_name: str) -> bool:
    if model_provider == "openai" and model_name in OPEN_AI_TOOL_CALLING_MODELS:
        return True

    return False
    return model_provider == "openai" and model_name in OPEN_AI_TOOL_CALLING_MODELS


def compute_tool_tokens(tool: Tool, llm_tokenizer: BaseTokenizer) -> int:
@@ -3,15 +3,20 @@ from datetime import datetime
from unittest.mock import MagicMock

import pytest
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage

from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import SearchRequest
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search_like_tool_utils import (

@@ -27,6 +32,31 @@ def answer_style_config() -> AnswerStyleConfig:
    return AnswerStyleConfig(citation_config=CitationConfig())


@pytest.fixture
def agent_search_config(
    mock_llm: LLM, mock_search_tool: SearchTool
) -> AgentSearchConfig:
    return AgentSearchConfig(
        search_request=SearchRequest(query=QUERY),
        primary_llm=mock_llm,
        fast_llm=mock_llm,
        search_tool=mock_search_tool,
        force_use_tool=ForceUseTool(force_use=False, tool_name=""),
        prompt_builder=AnswerPromptBuilder(
            user_message=HumanMessage(content=QUERY),
            message_history=[],
            llm_config=mock_llm.config,
            raw_user_query=QUERY,
            raw_user_uploaded_files=[],
        ),
        chat_session_id=None,
        message_id=1,
        use_persistence=True,
        db_session=None,
        use_agentic_search=False,
    )


@pytest.fixture
def prompt_config() -> PromptConfig:
    return PromptConfig(
@@ -11,6 +11,7 @@ from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolCall
from langchain_core.messages import ToolCallChunk

from onyx.agents.agent_search.models import AgentSearchConfig
from onyx.chat.answer import Answer
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo

@@ -30,7 +31,10 @@ from tests.unit.onyx.chat.conftest import QUERY

@pytest.fixture
def answer_instance(
    mock_llm: LLM, answer_style_config: AnswerStyleConfig, prompt_config: PromptConfig
    mock_llm: LLM,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
    agent_search_config: AgentSearchConfig,
) -> Answer:
    return Answer(
        question=QUERY,

@@ -38,6 +42,7 @@ def answer_instance(
        llm=mock_llm,
        prompt_config=prompt_config,
        force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
        agent_search_config=agent_search_config,
    )


@@ -284,7 +289,8 @@ def test_answer_with_search_no_tool_calling(
def test_is_cancelled(answer_instance: Answer) -> None:
    # Set up the LLM mock to return multiple chunks
    mock_llm = Mock()
    answer_instance.llm = mock_llm
    answer_instance.agent_search_config.primary_llm = mock_llm
    answer_instance.agent_search_config.fast_llm = mock_llm
    mock_llm.stream.return_value = [
        AIMessageChunk(content="This is the "),
        AIMessageChunk(content="first part."),

@@ -303,6 +309,7 @@ def test_is_cancelled(answer_instance: Answer) -> None:
        if i == 1:
            connection_status["connected"] = False

    print(output)
    assert len(output) == 3
    assert output[0] == OnyxAnswerPiece(answer_piece="This is the ")
    assert output[1] == OnyxAnswerPiece(answer_piece="first part.")
@@ -5,10 +5,12 @@ from unittest.mock import Mock
import pytest
from pytest_mock import MockerFixture

from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.chat.answer import Answer
from onyx.chat.answer import AnswerStream
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import PromptConfig
from onyx.context.search.models import SearchRequest
from onyx.tools.force import ForceUseTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from tests.regression.answer_quality.run_qa import _process_and_write_query_results

@@ -41,6 +43,12 @@ def test_skip_gen_ai_answer_generation_flag(
    mock_llm.config.model_name = "gpt-4o-mini"
    mock_llm.stream = Mock()
    mock_llm.stream.return_value = [Mock()]

    session = Mock()
    agent_search_config, _ = get_test_config(
        session, mock_llm, mock_llm, SearchRequest(query=question)
    )

    answer = Answer(
        question=question,
        answer_style_config=answer_style_config,

@@ -58,6 +66,7 @@ def test_skip_gen_ai_answer_generation_flag(
        skip_explicit_tool_calling=True,
        return_contexts=True,
        skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
        agent_search_config=agent_search_config,
    )
    count = 0
    for _ in cast(AnswerStream, answer.processed_streamed_output):