Feature/agentic buffered (#4231)

* rename agent test script to prevent pytest autodiscovery

* first cut

* fix log message

* fix up typing

* add a sample test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
rkuo-danswer 2025-03-10 08:48:42 -07:00 committed by GitHub
parent 6ca400ced9
commit 426883bbf5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 305 additions and 18 deletions

View File

@ -1,10 +1,14 @@
import re
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
@ -14,13 +18,19 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
@ -89,6 +99,12 @@ def _convert_packet_stream_to_response(
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
@ -97,6 +113,15 @@ def _convert_packet_stream_to_response(
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@ -113,11 +138,104 @@ def _convert_packet_stream_to_response(
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)

View File

@ -1,3 +1,5 @@
from collections import OrderedDict
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
@ -9,6 +11,7 @@ from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@ -88,6 +91,64 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@staticmethod
def make_dict_by_level_and_question_index(
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
returns a dict of level to dict[question num to list of query_id's]
Ordering is asc for readability.
"""
# In this function, when we sort int | None, we deliberately push None to the end
# map entries to the level_question_dict
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
for k1, obj in original_dict.items():
level = k1[0]
question = k1[1]
if level not in level_question_dict:
level_question_dict[level] = {}
if question not in level_question_dict[level]:
level_question_dict[level][question] = []
level_question_dict[level][question].append(obj)
# sort each query_id list and question_index
for key1, obj1 in level_question_dict.items():
for key2, value2 in obj1.items():
# sort the query_id list of each question_index
level_question_dict[key1][key2] = sorted(
value2, key=lambda o: o.query_id
)
# sort the question_index dict of level
level_question_dict[key1] = OrderedDict(
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
)
# sort the top dict of levels
sorted_dict = OrderedDict(
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
@ -107,6 +168,12 @@ class ChatBasicResponse(BaseModel):
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@ -895,7 +895,7 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accesssible_user(
async def current_chat_accessible_user(
user: User | None = Depends(optional_user),
) -> User | None:
tenant_id = get_current_tenant_id()

View File

@ -1,10 +1,13 @@
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from typing import Union
from pydantic import BaseModel
from pydantic import ConfigDict
@ -44,9 +47,44 @@ class LlmDoc(BaseModel):
class SubQuestionIdentifier(BaseModel):
"""None represents references to objects in the original flow. To our understanding,
these will not be None in the packets returned from agent search.
"""
level: int | None = None
level_question_num: int | None = None
@staticmethod
def make_dict_by_level(
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"]
) -> dict[int, list["SubQuestionIdentifier"]]:
"""returns a dict of level to object list (sorted by level_question_num)
Ordering is asc for readability.
"""
# organize by level, then sort ascending by question_index
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
# group by level
for k, obj in original_dict.items():
level = k[0]
if level not in level_dict:
level_dict[level] = []
level_dict[level].append(obj)
# for each level, sort the group
for k2, value2 in level_dict.items():
# we need to handle the none case due to SubQuestionIdentifier typing
# level_question_num as int | None, even though it should never be None here.
level_dict[k2] = sorted(
value2,
key=lambda x: (x.level_question_num is None, x.level_question_num),
)
# sort by level
sorted_dict = OrderedDict(sorted(level_dict.items()))
return sorted_dict
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
@ -336,6 +374,8 @@ class AgentAnswerPiece(SubQuestionIdentifier):
class SubQuestionPiece(SubQuestionIdentifier):
"""Refined sub questions generated from the initial user question."""
sub_question: str
@ -347,13 +387,13 @@ class RefinedAnswerImprovement(BaseModel):
refined_answer_improvement: bool
AgentSearchPacket = (
AgentSearchPacket = Union[
SubQuestionPiece
| AgentAnswerPiece
| SubQueryPiece
| ExtendedToolResponse
| RefinedAnswerImprovement
)
]
AnswerPacket = (
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse

View File

@ -234,6 +234,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
yield
SqlEngine.reset_engine()
if AUTH_RATE_LIMITING_ENABLED:
await close_auth_limiter()

View File

@ -5,7 +5,7 @@ from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@ -112,7 +112,7 @@ def check_router_auth(
or depends_fn == current_curator_or_admin_user
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
or depends_fn == current_chat_accesssible_user
or depends_fn == current_chat_accessible_user
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
):

View File

@ -17,7 +17,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
@ -1247,7 +1247,7 @@ class BasicCCPairInfo(BaseModel):
@router.get("/connector-status")
def get_basic_connector_indexing_status(
user: User = Depends(current_chat_accesssible_user),
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[BasicCCPairInfo]:
cc_pairs = get_connector_credential_pairs_for_user(

View File

@ -11,7 +11,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@ -390,7 +390,7 @@ def get_image_generation_tool(
@basic_router.get("")
def list_personas(
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),

View File

@ -7,7 +7,7 @@ from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
@ -191,7 +191,7 @@ def set_provider_as_default(
@basic_router.get("/provider")
def list_llm_provider_basics(
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
return [

View File

@ -20,7 +20,7 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_user
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import extract_headers
@ -190,7 +190,7 @@ def update_chat_session_model(
def get_chat_session(
session_id: UUID,
is_shared: bool = False,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> ChatSessionDetailResponse:
user_id = user.id if user is not None else None
@ -246,7 +246,7 @@ def get_chat_session(
@router.post("/create-chat-session")
def create_new_chat_session(
chat_session_creation_request: ChatSessionCreationRequest,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> CreateChatSessionID:
user_id = user.id if user is not None else None
@ -381,7 +381,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
_rate_limit_check: None = Depends(check_token_rate_limits),
is_connected_func: Callable[[], bool] = Depends(is_connected),
) -> StreamingResponse:
@ -473,7 +473,7 @@ def set_message_as_latest(
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None

View File

@ -11,7 +11,7 @@ from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session_context_manager
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@ -29,7 +29,7 @@ TOKEN_BUDGET_UNIT = 1_000
def check_token_rate_limits(
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
) -> None:
# short circuit if no rate limits are set up
# NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time

View File

@ -1,3 +1,8 @@
from typing import Any
import pytest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
@ -17,3 +22,58 @@ def test_send_message_simple_with_history(reset: None) -> None:
)
assert len(response.full_message) > 0
@pytest.mark.skip(
reason="enable for autorun when we have a testing environment with semantically useful data"
)
def test_send_message_simple_with_history_buffered() -> None:
import requests
API_KEY = "" # fill in for this to work
headers = {}
headers["Authorization"] = f"Bearer {API_KEY}"
req: dict[str, Any] = {}
req["persona_id"] = 0
req["description"] = "test_send_message_simple_with_history_buffered"
response = requests.post(
f"{API_SERVER_URL}/chat/create-chat-session", headers=headers, json=req
)
chat_session_id = response.json()["chat_session_id"]
req = {}
req["chat_session_id"] = chat_session_id
req["message"] = "What does onyx do?"
req["use_agentic_search"] = True
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-api", headers=headers, json=req
)
r_json = response.json()
# all of these should exist and be greater than length 1
assert len(r_json.get("answer", "")) > 0
assert len(r_json.get("agent_sub_questions", "")) > 0
assert len(r_json.get("agent_answers")) > 0
assert len(r_json.get("agent_sub_queries")) > 0
assert "agent_refined_answer_improvement" in r_json
# top level answer should match the one we select out of agent_answers
answer_level = 0
agent_level_answer = ""
agent_refined_answer_improvement = r_json.get("agent_refined_answer_improvement")
if agent_refined_answer_improvement:
answer_level = len(r_json["agent_answers"]) - 1
answers = r_json["agent_answers"][str(answer_level)]
for answer in answers:
if answer["answer_type"] == "agent_level_answer":
agent_level_answer = answer["answer"]
break
assert r_json["answer"] == agent_level_answer
assert response.status_code == 200