diff --git a/backend/onyx/agents/agent_search/models.py b/backend/onyx/agents/agent_search/models.py index 0001f56c9..d9358433b 100644 --- a/backend/onyx/agents/agent_search/models.py +++ b/backend/onyx/agents/agent_search/models.py @@ -83,7 +83,7 @@ class GraphConfig(BaseModel): tooling: GraphTooling behavior: GraphSearchConfig # Only needed for agentic search - persistence: GraphPersistence | None = None + persistence: GraphPersistence @model_validator(mode="after") def validate_search_tool(self) -> "GraphConfig": diff --git a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py index dce3c13d8..bce492c4e 100644 --- a/backend/onyx/agents/agent_search/shared_graph_utils/utils.py +++ b/backend/onyx/agents/agent_search/shared_graph_utils/utils.py @@ -199,13 +199,15 @@ def get_test_config( using_tool_calling_llm=using_tool_calling_llm, ) - graph_persistence = None - if chat_session_id := os.environ.get("ONYX_AS_CHAT_SESSION_ID"): - graph_persistence = GraphPersistence( - db_session=db_session, - chat_session_id=UUID(chat_session_id), - message_id=1, - ) + chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID") + assert ( + chat_session_id is not None + ), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests" + graph_persistence = GraphPersistence( + db_session=db_session, + chat_session_id=UUID(chat_session_id), + message_id=1, + ) search_behavior_config = GraphSearchConfig( use_agentic_search=use_agentic_search, diff --git a/backend/onyx/chat/answer.py b/backend/onyx/chat/answer.py index 88be1a9f7..303f1e6a3 100644 --- a/backend/onyx/chat/answer.py +++ b/backend/onyx/chat/answer.py @@ -47,6 +47,7 @@ class Answer: search_request: SearchRequest, chat_session_id: UUID, current_agent_message_id: int, + db_session: Session, # newly passed in files to include as part of this question # TODO THIS NEEDS TO BE HANDLED latest_query_files: list[InMemoryChatFile] | None = None, @@ -57,9 +58,7 @@ class Answer: skip_explicit_tool_calling: bool = False, skip_gen_ai_answer_generation: bool = False, is_connected: Callable[[], bool] | None = None, - db_session: Session | None = None, use_agentic_search: bool = False, - use_agentic_persistence: bool = True, ) -> None: self.is_connected: Callable[[], bool] | None = is_connected self._processed_stream: (list[AnswerPacket] | None) = None @@ -95,14 +94,12 @@ class Answer: force_use_tool=force_use_tool, using_tool_calling_llm=using_tool_calling_llm, ) - self.graph_persistence = None - if use_agentic_persistence: - assert db_session, "db_session must be provided for agentic persistence" - self.graph_persistence = GraphPersistence( - db_session=db_session, - chat_session_id=chat_session_id, - message_id=current_agent_message_id, - ) + assert db_session, "db_session must be provided for agentic persistence" + self.graph_persistence = GraphPersistence( + db_session=db_session, + chat_session_id=chat_session_id, + message_id=current_agent_message_id, + ) self.search_behavior_config = GraphSearchConfig( use_agentic_search=use_agentic_search, skip_gen_ai_answer_generation=skip_gen_ai_answer_generation, diff --git a/backend/tests/unit/onyx/chat/test_answer.py b/backend/tests/unit/onyx/chat/test_answer.py index fe2520fd6..763a3b6f6 100644 --- a/backend/tests/unit/onyx/chat/test_answer.py +++ b/backend/tests/unit/onyx/chat/test_answer.py @@ -11,6 +11,7 @@ from langchain_core.messages import HumanMessage from langchain_core.messages import SystemMessage from langchain_core.messages import ToolCall from langchain_core.messages import ToolCallChunk +from sqlalchemy.orm import Session from onyx.chat.answer import Answer from onyx.chat.models import AnswerStyleConfig @@ -58,6 +59,7 @@ def answer_instance( raw_user_query=QUERY, raw_user_uploaded_files=[], ), + db_session=Mock(spec=Session), answer_style_config=answer_style_config, llm=mock_llm, fast_llm=mock_llm, @@ -65,7 +67,6 @@ def answer_instance( search_request=SearchRequest(query=QUERY), chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), current_agent_message_id=0, - use_agentic_persistence=False, ) diff --git a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py index e00a61893..00f77279a 100644 --- a/backend/tests/unit/onyx/chat/test_skip_gen_ai.py +++ b/backend/tests/unit/onyx/chat/test_skip_gen_ai.py @@ -5,6 +5,7 @@ from uuid import UUID import pytest from langchain_core.messages import HumanMessage from pytest_mock import MockerFixture +from sqlalchemy.orm import Session from onyx.chat.answer import Answer from onyx.chat.models import AnswerStyleConfig @@ -46,6 +47,7 @@ def test_skip_gen_ai_answer_generation_flag( mock_llm.stream.return_value = [Mock()] answer = Answer( + db_session=Mock(spec=Session), answer_style_config=answer_style_config, llm=mock_llm, fast_llm=mock_llm, @@ -67,7 +69,6 @@ def test_skip_gen_ai_answer_generation_flag( ), chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"), current_agent_message_id=0, - use_agentic_persistence=False, ) results = list(answer.processed_streamed_output) for res in results: