diff --git a/backend/ee/onyx/server/query_and_chat/chat_backend.py b/backend/ee/onyx/server/query_and_chat/chat_backend.py index 5a3ba2902..a2c08b543 100644 --- a/backend/ee/onyx/server/query_and_chat/chat_backend.py +++ b/backend/ee/onyx/server/query_and_chat/chat_backend.py @@ -1,10 +1,14 @@ import re +from typing import cast from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from sqlalchemy.orm import Session +from ee.onyx.server.query_and_chat.models import AgentAnswer +from ee.onyx.server.query_and_chat.models import AgentSubQuery +from ee.onyx.server.query_and_chat.models import AgentSubQuestion from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest from ee.onyx.server.query_and_chat.models import ( BasicCreateChatMessageWithHistoryRequest, @@ -14,13 +18,19 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc from onyx.auth.users import current_user from onyx.chat.chat_utils import combine_message_thread from onyx.chat.chat_utils import create_chat_chain +from onyx.chat.models import AgentAnswerPiece from onyx.chat.models import AllCitations +from onyx.chat.models import ExtendedToolResponse from onyx.chat.models import FinalUsedContextDocsResponse from onyx.chat.models import LlmDoc from onyx.chat.models import LLMRelevanceFilterResponse from onyx.chat.models import OnyxAnswerPiece from onyx.chat.models import QADocsResponse +from onyx.chat.models import RefinedAnswerImprovement from onyx.chat.models import StreamingError +from onyx.chat.models import SubQueryPiece +from onyx.chat.models import SubQuestionIdentifier +from onyx.chat.models import SubQuestionPiece from onyx.chat.process_message import ChatPacketStream from onyx.chat.process_message import stream_chat_message_objects from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE @@ -89,6 +99,12 @@ def _convert_packet_stream_to_response( final_context_docs: list[LlmDoc] = [] answer = "" + + # accumulate stream data with these dicts + agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {} + agent_answers: dict[tuple[int, int], AgentAnswer] = {} + agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {} + for packet in packets: if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece: answer += packet.answer_piece @@ -97,6 +113,15 @@ def _convert_packet_stream_to_response( # TODO: deprecate `simple_search_docs` response.simple_search_docs = _translate_doc_response_to_simple_doc(packet) + + # This is a no-op if agent_sub_questions hasn't already been filled + if packet.level is not None and packet.level_question_num is not None: + id = (packet.level, packet.level_question_num) + if id in agent_sub_questions: + agent_sub_questions[id].document_ids = [ + saved_search_doc.document_id + for saved_search_doc in packet.top_documents + ] elif isinstance(packet, StreamingError): response.error_msg = packet.error elif isinstance(packet, ChatMessageDetail): @@ -113,11 +138,104 @@ def _convert_packet_stream_to_response( citation.citation_num: citation.document_id for citation in packet.citations } + # agentic packets + elif isinstance(packet, SubQuestionPiece): + if packet.level is not None and packet.level_question_num is not None: + id = (packet.level, packet.level_question_num) + if agent_sub_questions.get(id) is None: + agent_sub_questions[id] = AgentSubQuestion( + level=packet.level, + level_question_num=packet.level_question_num, + sub_question=packet.sub_question, + document_ids=[], + ) + else: + agent_sub_questions[id].sub_question += packet.sub_question + + elif isinstance(packet, AgentAnswerPiece): + if packet.level is not None and packet.level_question_num is not None: + id = (packet.level, packet.level_question_num) + if agent_answers.get(id) is None: + agent_answers[id] = AgentAnswer( + level=packet.level, + level_question_num=packet.level_question_num, + answer=packet.answer_piece, + answer_type=packet.answer_type, + ) + else: + agent_answers[id].answer += packet.answer_piece + elif isinstance(packet, SubQueryPiece): + if packet.level is not None and packet.level_question_num is not None: + sub_query_id = ( + packet.level, + packet.level_question_num, + packet.query_id, + ) + if agent_sub_queries.get(sub_query_id) is None: + agent_sub_queries[sub_query_id] = AgentSubQuery( + level=packet.level, + level_question_num=packet.level_question_num, + sub_query=packet.sub_query, + query_id=packet.query_id, + ) + else: + agent_sub_queries[sub_query_id].sub_query += packet.sub_query + elif isinstance(packet, ExtendedToolResponse): + # we shouldn't get this ... it gets intercepted and translated to QADocsResponse + logger.warning( + "_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!" + ) + elif isinstance(packet, RefinedAnswerImprovement): + response.agent_refined_answer_improvement = ( + packet.refined_answer_improvement + ) + else: + logger.warning( + f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}" + ) response.final_context_doc_indices = _get_final_context_doc_indices( final_context_docs, response.top_documents ) + # organize / sort agent metadata for output + if len(agent_sub_questions) > 0: + response.agent_sub_questions = cast( + dict[int, list[AgentSubQuestion]], + SubQuestionIdentifier.make_dict_by_level(agent_sub_questions), + ) + + if len(agent_answers) > 0: + # return the agent_level_answer from the first level or the last one depending + # on agent_refined_answer_improvement + response.agent_answers = cast( + dict[int, list[AgentAnswer]], + SubQuestionIdentifier.make_dict_by_level(agent_answers), + ) + if response.agent_answers: + selected_answer_level = ( + 0 + if not response.agent_refined_answer_improvement + else len(response.agent_answers) - 1 + ) + level_answers = response.agent_answers[selected_answer_level] + for level_answer in level_answers: + if level_answer.answer_type != "agent_level_answer": + continue + + answer = level_answer.answer + break + + if len(agent_sub_queries) > 0: + # subqueries are often emitted with trailing whitespace ... clean it up here + # perhaps fix at the source? + for v in agent_sub_queries.values(): + v.sub_query = v.sub_query.strip() + + response.agent_sub_queries = ( + AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries) + ) + response.answer = answer if answer: response.answer_citationless = remove_answer_citations(answer) diff --git a/backend/ee/onyx/server/query_and_chat/models.py b/backend/ee/onyx/server/query_and_chat/models.py index bf08a39eb..4f493d73a 100644 --- a/backend/ee/onyx/server/query_and_chat/models.py +++ b/backend/ee/onyx/server/query_and_chat/models.py @@ -1,3 +1,5 @@ +from collections import OrderedDict +from typing import Literal from uuid import UUID from pydantic import BaseModel @@ -9,6 +11,7 @@ from onyx.chat.models import CitationInfo from onyx.chat.models import OnyxContexts from onyx.chat.models import PersonaOverrideConfig from onyx.chat.models import QADocsResponse +from onyx.chat.models import SubQuestionIdentifier from onyx.chat.models import ThreadMessage from onyx.configs.constants import DocumentSource from onyx.context.search.enums import LLMEvaluationType @@ -88,6 +91,64 @@ class SimpleDoc(BaseModel): metadata: dict | None +class AgentSubQuestion(SubQuestionIdentifier): + sub_question: str + document_ids: list[str] + + +class AgentAnswer(SubQuestionIdentifier): + answer: str + answer_type: Literal["agent_sub_answer", "agent_level_answer"] + + +class AgentSubQuery(SubQuestionIdentifier): + sub_query: str + query_id: int + + @staticmethod + def make_dict_by_level_and_question_index( + original_dict: dict[tuple[int, int, int], "AgentSubQuery"] + ) -> dict[int, dict[int, list["AgentSubQuery"]]]: + """Takes a dict of tuple(level, question num, query_id) to sub queries. + + returns a dict of level to dict[question num to list of query_id's] + Ordering is asc for readability. + """ + # In this function, when we sort int | None, we deliberately push None to the end + + # map entries to the level_question_dict + level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {} + for k1, obj in original_dict.items(): + level = k1[0] + question = k1[1] + + if level not in level_question_dict: + level_question_dict[level] = {} + + if question not in level_question_dict[level]: + level_question_dict[level][question] = [] + + level_question_dict[level][question].append(obj) + + # sort each query_id list and question_index + for key1, obj1 in level_question_dict.items(): + for key2, value2 in obj1.items(): + # sort the query_id list of each question_index + level_question_dict[key1][key2] = sorted( + value2, key=lambda o: o.query_id + ) + # sort the question_index dict of level + level_question_dict[key1] = OrderedDict( + sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x)) + ) + + # sort the top dict of levels + sorted_dict = OrderedDict( + sorted(level_question_dict.items(), key=lambda x: (x is None, x)) + ) + return sorted_dict + + class ChatBasicResponse(BaseModel): # This is built piece by piece, any of these can be None as the flow could break answer: str | None = None @@ -107,6 +168,12 @@ class ChatBasicResponse(BaseModel): simple_search_docs: list[SimpleDoc] | None = None llm_chunks_indices: list[int] | None = None + # agentic fields + agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None + agent_answers: dict[int, list[AgentAnswer]] | None = None + agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None + agent_refined_answer_improvement: bool | None = None + class OneShotQARequest(ChunkContext): # Supports simplier APIs that don't deal with chat histories or message edits diff --git a/backend/onyx/auth/users.py b/backend/onyx/auth/users.py index 4644521b1..05a49ccab 100644 --- a/backend/onyx/auth/users.py +++ b/backend/onyx/auth/users.py @@ -895,7 +895,7 @@ async def current_limited_user( return await double_check_user(user) -async def current_chat_accesssible_user( +async def current_chat_accessible_user( user: User | None = Depends(optional_user), ) -> User | None: tenant_id = get_current_tenant_id() diff --git a/backend/onyx/chat/models.py b/backend/onyx/chat/models.py index 6b1c43437..976ed336a 100644 --- a/backend/onyx/chat/models.py +++ b/backend/onyx/chat/models.py @@ -1,10 +1,13 @@ +from collections import OrderedDict from collections.abc import Callable from collections.abc import Iterator +from collections.abc import Mapping from datetime import datetime from enum import Enum from typing import Any from typing import Literal from typing import TYPE_CHECKING +from typing import Union from pydantic import BaseModel from pydantic import ConfigDict @@ -44,9 +47,44 @@ class LlmDoc(BaseModel): class SubQuestionIdentifier(BaseModel): + """None represents references to objects in the original flow. To our understanding, + these will not be None in the packets returned from agent search. + """ + level: int | None = None level_question_num: int | None = None + @staticmethod + def make_dict_by_level( + original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"] + ) -> dict[int, list["SubQuestionIdentifier"]]: + """returns a dict of level to object list (sorted by level_question_num) + Ordering is asc for readability. + """ + + # organize by level, then sort ascending by question_index + level_dict: dict[int, list[SubQuestionIdentifier]] = {} + + # group by level + for k, obj in original_dict.items(): + level = k[0] + if level not in level_dict: + level_dict[level] = [] + level_dict[level].append(obj) + + # for each level, sort the group + for k2, value2 in level_dict.items(): + # we need to handle the none case due to SubQuestionIdentifier typing + # level_question_num as int | None, even though it should never be None here. + level_dict[k2] = sorted( + value2, + key=lambda x: (x.level_question_num is None, x.level_question_num), + ) + + # sort by level + sorted_dict = OrderedDict(sorted(level_dict.items())) + return sorted_dict + # First chunk of info for streaming QA class QADocsResponse(RetrievalDocs, SubQuestionIdentifier): @@ -336,6 +374,8 @@ class AgentAnswerPiece(SubQuestionIdentifier): class SubQuestionPiece(SubQuestionIdentifier): + """Refined sub questions generated from the initial user question.""" + sub_question: str @@ -347,13 +387,13 @@ class RefinedAnswerImprovement(BaseModel): refined_answer_improvement: bool -AgentSearchPacket = ( +AgentSearchPacket = Union[ SubQuestionPiece | AgentAnswerPiece | SubQueryPiece | ExtendedToolResponse | RefinedAnswerImprovement -) +] AnswerPacket = ( AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse diff --git a/backend/onyx/main.py b/backend/onyx/main.py index e783055ee..f4d296451 100644 --- a/backend/onyx/main.py +++ b/backend/onyx/main.py @@ -234,6 +234,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: yield + SqlEngine.reset_engine() + if AUTH_RATE_LIMITING_ENABLED: await close_auth_limiter() diff --git a/backend/onyx/server/auth_check.py b/backend/onyx/server/auth_check.py index c1bbb6b46..a337f3660 100644 --- a/backend/onyx/server/auth_check.py +++ b/backend/onyx/server/auth_check.py @@ -5,7 +5,7 @@ from fastapi.dependencies.models import Dependant from starlette.routing import BaseRoute from onyx.auth.users import current_admin_user -from onyx.auth.users import current_chat_accesssible_user +from onyx.auth.users import current_chat_accessible_user from onyx.auth.users import current_curator_or_admin_user from onyx.auth.users import current_limited_user from onyx.auth.users import current_user @@ -112,7 +112,7 @@ def check_router_auth( or depends_fn == current_curator_or_admin_user or depends_fn == api_key_dep or depends_fn == current_user_with_expired_token - or depends_fn == current_chat_accesssible_user + or depends_fn == current_chat_accessible_user or depends_fn == control_plane_dep or depends_fn == current_cloud_superuser ): diff --git a/backend/onyx/server/documents/connector.py b/backend/onyx/server/documents/connector.py index 60511ae93..920c02b18 100644 --- a/backend/onyx/server/documents/connector.py +++ b/backend/onyx/server/documents/connector.py @@ -17,7 +17,7 @@ from pydantic import BaseModel from sqlalchemy.orm import Session from onyx.auth.users import current_admin_user -from onyx.auth.users import current_chat_accesssible_user +from onyx.auth.users import current_chat_accessible_user from onyx.auth.users import current_curator_or_admin_user from onyx.auth.users import current_user from onyx.background.celery.versioned_apps.primary import app as primary_app @@ -1247,7 +1247,7 @@ class BasicCCPairInfo(BaseModel): @router.get("/connector-status") def get_basic_connector_indexing_status( - user: User = Depends(current_chat_accesssible_user), + user: User = Depends(current_chat_accessible_user), db_session: Session = Depends(get_session), ) -> list[BasicCCPairInfo]: cc_pairs = get_connector_credential_pairs_for_user( diff --git a/backend/onyx/server/features/persona/api.py b/backend/onyx/server/features/persona/api.py index d022244ea..8d6c9b014 100644 --- a/backend/onyx/server/features/persona/api.py +++ b/backend/onyx/server/features/persona/api.py @@ -11,7 +11,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from onyx.auth.users import current_admin_user -from onyx.auth.users import current_chat_accesssible_user +from onyx.auth.users import current_chat_accessible_user from onyx.auth.users import current_curator_or_admin_user from onyx.auth.users import current_limited_user from onyx.auth.users import current_user @@ -390,7 +390,7 @@ def get_image_generation_tool( @basic_router.get("") def list_personas( - user: User | None = Depends(current_chat_accesssible_user), + user: User | None = Depends(current_chat_accessible_user), db_session: Session = Depends(get_session), include_deleted: bool = False, persona_ids: list[int] = Query(None), diff --git a/backend/onyx/server/manage/llm/api.py b/backend/onyx/server/manage/llm/api.py index ad0bc9742..7a76ed196 100644 --- a/backend/onyx/server/manage/llm/api.py +++ b/backend/onyx/server/manage/llm/api.py @@ -7,7 +7,7 @@ from fastapi import Query from sqlalchemy.orm import Session from onyx.auth.users import current_admin_user -from onyx.auth.users import current_chat_accesssible_user +from onyx.auth.users import current_chat_accessible_user from onyx.db.engine import get_session from onyx.db.llm import fetch_existing_llm_providers from onyx.db.llm import fetch_existing_llm_providers_for_user @@ -191,7 +191,7 @@ def set_provider_as_default( @basic_router.get("/provider") def list_llm_provider_basics( - user: User | None = Depends(current_chat_accesssible_user), + user: User | None = Depends(current_chat_accessible_user), db_session: Session = Depends(get_session), ) -> list[LLMProviderDescriptor]: return [ diff --git a/backend/onyx/server/query_and_chat/chat_backend.py b/backend/onyx/server/query_and_chat/chat_backend.py index c2043ff78..ad3a3a18b 100644 --- a/backend/onyx/server/query_and_chat/chat_backend.py +++ b/backend/onyx/server/query_and_chat/chat_backend.py @@ -20,7 +20,7 @@ from fastapi.responses import StreamingResponse from pydantic import BaseModel from sqlalchemy.orm import Session -from onyx.auth.users import current_chat_accesssible_user +from onyx.auth.users import current_chat_accessible_user from onyx.auth.users import current_user from onyx.chat.chat_utils import create_chat_chain from onyx.chat.chat_utils import extract_headers @@ -190,7 +190,7 @@ def update_chat_session_model( def get_chat_session( session_id: UUID, is_shared: bool = False, - user: User | None = Depends(current_chat_accesssible_user), + user: User | None = Depends(current_chat_accessible_user), db_session: Session = Depends(get_session), ) -> ChatSessionDetailResponse: user_id = user.id if user is not None else None @@ -246,7 +246,7 @@ def get_chat_session( @router.post("/create-chat-session") def create_new_chat_session( chat_session_creation_request: ChatSessionCreationRequest, - user: User | None = Depends(current_chat_accesssible_user), + user: User | None = Depends(current_chat_accessible_user), db_session: Session = Depends(get_session), ) -> CreateChatSessionID: user_id = user.id if user is not None else None @@ -381,7 +381,7 @@ async def is_connected(request: Request) -> Callable[[], bool]: def handle_new_chat_message( chat_message_req: CreateChatMessageRequest, request: Request, - user: User | None = Depends(current_chat_accesssible_user), + user: User | None = Depends(current_chat_accessible_user), _rate_limit_check: None = Depends(check_token_rate_limits), is_connected_func: Callable[[], bool] = Depends(is_connected), ) -> StreamingResponse: @@ -473,7 +473,7 @@ def set_message_as_latest( @router.post("/create-chat-message-feedback") def create_chat_feedback( feedback: ChatFeedbackRequest, - user: User | None = Depends(current_chat_accesssible_user), + user: User | None = Depends(current_chat_accessible_user), db_session: Session = Depends(get_session), ) -> None: user_id = user.id if user else None diff --git a/backend/onyx/server/query_and_chat/token_limit.py b/backend/onyx/server/query_and_chat/token_limit.py index fc0bc629d..28c1494a1 100644 --- a/backend/onyx/server/query_and_chat/token_limit.py +++ b/backend/onyx/server/query_and_chat/token_limit.py @@ -11,7 +11,7 @@ from sqlalchemy import func from sqlalchemy import select from sqlalchemy.orm import Session -from onyx.auth.users import current_chat_accesssible_user +from onyx.auth.users import current_chat_accessible_user from onyx.db.engine import get_session_context_manager from onyx.db.models import ChatMessage from onyx.db.models import ChatSession @@ -29,7 +29,7 @@ TOKEN_BUDGET_UNIT = 1_000 def check_token_rate_limits( - user: User | None = Depends(current_chat_accesssible_user), + user: User | None = Depends(current_chat_accessible_user), ) -> None: # short circuit if no rate limits are set up # NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time diff --git a/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py b/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py index dc4361301..cc507b137 100644 --- a/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py +++ b/backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py @@ -1,3 +1,8 @@ +from typing import Any + +import pytest + +from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.managers.chat import ChatSessionManager from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.managers.user import UserManager @@ -17,3 +22,58 @@ def test_send_message_simple_with_history(reset: None) -> None: ) assert len(response.full_message) > 0 + + +@pytest.mark.skip( + reason="enable for autorun when we have a testing environment with semantically useful data" +) +def test_send_message_simple_with_history_buffered() -> None: + import requests + + API_KEY = "" # fill in for this to work + headers = {} + headers["Authorization"] = f"Bearer {API_KEY}" + + req: dict[str, Any] = {} + + req["persona_id"] = 0 + req["description"] = "test_send_message_simple_with_history_buffered" + response = requests.post( + f"{API_SERVER_URL}/chat/create-chat-session", headers=headers, json=req + ) + chat_session_id = response.json()["chat_session_id"] + + req = {} + req["chat_session_id"] = chat_session_id + req["message"] = "What does onyx do?" + req["use_agentic_search"] = True + + response = requests.post( + f"{API_SERVER_URL}/chat/send-message-simple-api", headers=headers, json=req + ) + + r_json = response.json() + + # all of these should exist and be greater than length 1 + assert len(r_json.get("answer", "")) > 0 + assert len(r_json.get("agent_sub_questions", "")) > 0 + assert len(r_json.get("agent_answers")) > 0 + assert len(r_json.get("agent_sub_queries")) > 0 + assert "agent_refined_answer_improvement" in r_json + + # top level answer should match the one we select out of agent_answers + answer_level = 0 + agent_level_answer = "" + + agent_refined_answer_improvement = r_json.get("agent_refined_answer_improvement") + if agent_refined_answer_improvement: + answer_level = len(r_json["agent_answers"]) - 1 + + answers = r_json["agent_answers"][str(answer_level)] + for answer in answers: + if answer["answer_type"] == "agent_level_answer": + agent_level_answer = answer["answer"] + break + + assert r_json["answer"] == agent_level_answer + assert response.status_code == 200 diff --git a/backend/tests/regression/answer_quality/agent_test.py b/backend/tests/regression/answer_quality/agent_test_script.py similarity index 100% rename from backend/tests/regression/answer_quality/agent_test.py rename to backend/tests/regression/answer_quality/agent_test_script.py