always persist in agent search

This commit is contained in:
Evan Lohn
2025-02-01 23:20:01 -08:00
parent 2adeaaeded
commit 71304e4228
5 changed files with 21 additions and 20 deletions

View File

@@ -83,7 +83,7 @@ class GraphConfig(BaseModel):
tooling: GraphTooling
behavior: GraphSearchConfig
# Only needed for agentic search
persistence: GraphPersistence | None = None persistence: GraphPersistence
@model_validator(mode="after")
def validate_search_tool(self) -> "GraphConfig":

View File

@@ -199,13 +199,15 @@ def get_test_config(
using_tool_calling_llm=using_tool_calling_llm,
)
graph_persistence = None chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
if chat_session_id := os.environ.get("ONYX_AS_CHAT_SESSION_ID"): assert (
graph_persistence = GraphPersistence( chat_session_id is not None
db_session=db_session, ), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
chat_session_id=UUID(chat_session_id), graph_persistence = GraphPersistence(
message_id=1, db_session=db_session,
) chat_session_id=UUID(chat_session_id),
message_id=1,
)
search_behavior_config = GraphSearchConfig(
use_agentic_search=use_agentic_search,

View File

@@ -47,6 +47,7 @@ class Answer:
search_request: SearchRequest,
chat_session_id: UUID,
current_agent_message_id: int,
db_session: Session,
# newly passed in files to include as part of this question
# TODO THIS NEEDS TO BE HANDLED
latest_query_files: list[InMemoryChatFile] | None = None,
@@ -57,9 +58,7 @@ class Answer:
skip_explicit_tool_calling: bool = False,
skip_gen_ai_answer_generation: bool = False,
is_connected: Callable[[], bool] | None = None,
db_session: Session | None = None,
use_agentic_search: bool = False,
use_agentic_persistence: bool = True,
) -> None:
self.is_connected: Callable[[], bool] | None = is_connected
self._processed_stream: (list[AnswerPacket] | None) = None
@@ -95,14 +94,12 @@ class Answer:
force_use_tool=force_use_tool,
using_tool_calling_llm=using_tool_calling_llm,
)
self.graph_persistence = None assert db_session, "db_session must be provided for agentic persistence"
if use_agentic_persistence: self.graph_persistence = GraphPersistence(
assert db_session, "db_session must be provided for agentic persistence" db_session=db_session,
self.graph_persistence = GraphPersistence( chat_session_id=chat_session_id,
db_session=db_session, message_id=current_agent_message_id,
chat_session_id=chat_session_id, )
message_id=current_agent_message_id,
)
self.search_behavior_config = GraphSearchConfig(
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,

View File

@@ -11,6 +11,7 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolCall
from langchain_core.messages import ToolCallChunk
from sqlalchemy.orm import Session
from onyx.chat.answer import Answer
from onyx.chat.models import AnswerStyleConfig
@@ -58,6 +59,7 @@ def answer_instance(
raw_user_query=QUERY,
raw_user_uploaded_files=[],
),
db_session=Mock(spec=Session),
answer_style_config=answer_style_config,
llm=mock_llm,
fast_llm=mock_llm,
@@ -65,7 +67,6 @@ def answer_instance(
search_request=SearchRequest(query=QUERY),
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
current_agent_message_id=0,
use_agentic_persistence=False,
)

View File

@@ -5,6 +5,7 @@ from uuid import UUID
import pytest
from langchain_core.messages import HumanMessage
from pytest_mock import MockerFixture
from sqlalchemy.orm import Session
from onyx.chat.answer import Answer
from onyx.chat.models import AnswerStyleConfig
@@ -46,6 +47,7 @@ def test_skip_gen_ai_answer_generation_flag(
mock_llm.stream.return_value = [Mock()]
answer = Answer(
db_session=Mock(spec=Session),
answer_style_config=answer_style_config,
llm=mock_llm,
fast_llm=mock_llm,
@@ -67,7 +69,6 @@ def test_skip_gen_ai_answer_generation_flag(
),
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
current_agent_message_id=0,
use_agentic_persistence=False,
)
results = list(answer.processed_streamed_output)
for res in results: