mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-17 13:22:42 +01:00
Feature/agentic buffered (#4231)
* rename agent test script to prevent pytest autodiscovery * first cut * fix log message * fix up typing * add a sample test --------- Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
parent
6ca400ced9
commit
426883bbf5
@ -1,10 +1,14 @@
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import AgentAnswer
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuery
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
|
||||
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||
from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
@ -14,13 +18,19 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import FinalUsedContextDocsResponse
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
@ -89,6 +99,12 @@ def _convert_packet_stream_to_response(
|
||||
final_context_docs: list[LlmDoc] = []
|
||||
|
||||
answer = ""
|
||||
|
||||
# accumulate stream data with these dicts
|
||||
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
|
||||
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
|
||||
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
|
||||
|
||||
for packet in packets:
|
||||
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
@ -97,6 +113,15 @@ def _convert_packet_stream_to_response(
|
||||
|
||||
# TODO: deprecate `simple_search_docs`
|
||||
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
|
||||
|
||||
# This is a no-op if agent_sub_questions hasn't already been filled
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if id in agent_sub_questions:
|
||||
agent_sub_questions[id].document_ids = [
|
||||
saved_search_doc.document_id
|
||||
for saved_search_doc in packet.top_documents
|
||||
]
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
@ -113,11 +138,104 @@ def _convert_packet_stream_to_response(
|
||||
citation.citation_num: citation.document_id
|
||||
for citation in packet.citations
|
||||
}
|
||||
# agentic packets
|
||||
elif isinstance(packet, SubQuestionPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_sub_questions.get(id) is None:
|
||||
agent_sub_questions[id] = AgentSubQuestion(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_question=packet.sub_question,
|
||||
document_ids=[],
|
||||
)
|
||||
else:
|
||||
agent_sub_questions[id].sub_question += packet.sub_question
|
||||
|
||||
elif isinstance(packet, AgentAnswerPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_answers.get(id) is None:
|
||||
agent_answers[id] = AgentAnswer(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
answer=packet.answer_piece,
|
||||
answer_type=packet.answer_type,
|
||||
)
|
||||
else:
|
||||
agent_answers[id].answer += packet.answer_piece
|
||||
elif isinstance(packet, SubQueryPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
sub_query_id = (
|
||||
packet.level,
|
||||
packet.level_question_num,
|
||||
packet.query_id,
|
||||
)
|
||||
if agent_sub_queries.get(sub_query_id) is None:
|
||||
agent_sub_queries[sub_query_id] = AgentSubQuery(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_query=packet.sub_query,
|
||||
query_id=packet.query_id,
|
||||
)
|
||||
else:
|
||||
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
|
||||
elif isinstance(packet, ExtendedToolResponse):
|
||||
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
|
||||
logger.warning(
|
||||
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
|
||||
)
|
||||
elif isinstance(packet, RefinedAnswerImprovement):
|
||||
response.agent_refined_answer_improvement = (
|
||||
packet.refined_answer_improvement
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
|
||||
)
|
||||
|
||||
response.final_context_doc_indices = _get_final_context_doc_indices(
|
||||
final_context_docs, response.top_documents
|
||||
)
|
||||
|
||||
# organize / sort agent metadata for output
|
||||
if len(agent_sub_questions) > 0:
|
||||
response.agent_sub_questions = cast(
|
||||
dict[int, list[AgentSubQuestion]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
|
||||
)
|
||||
|
||||
if len(agent_answers) > 0:
|
||||
# return the agent_level_answer from the first level or the last one depending
|
||||
# on agent_refined_answer_improvement
|
||||
response.agent_answers = cast(
|
||||
dict[int, list[AgentAnswer]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_answers),
|
||||
)
|
||||
if response.agent_answers:
|
||||
selected_answer_level = (
|
||||
0
|
||||
if not response.agent_refined_answer_improvement
|
||||
else len(response.agent_answers) - 1
|
||||
)
|
||||
level_answers = response.agent_answers[selected_answer_level]
|
||||
for level_answer in level_answers:
|
||||
if level_answer.answer_type != "agent_level_answer":
|
||||
continue
|
||||
|
||||
answer = level_answer.answer
|
||||
break
|
||||
|
||||
if len(agent_sub_queries) > 0:
|
||||
# subqueries are often emitted with trailing whitespace ... clean it up here
|
||||
# perhaps fix at the source?
|
||||
for v in agent_sub_queries.values():
|
||||
v.sub_query = v.sub_query.strip()
|
||||
|
||||
response.agent_sub_queries = (
|
||||
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
|
||||
)
|
||||
|
||||
response.answer = answer
|
||||
if answer:
|
||||
response.answer_citationless = remove_answer_citations(answer)
|
||||
|
@ -1,3 +1,5 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@ -9,6 +11,7 @@ from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
@ -88,6 +91,64 @@ class SimpleDoc(BaseModel):
|
||||
metadata: dict | None
|
||||
|
||||
|
||||
class AgentSubQuestion(SubQuestionIdentifier):
|
||||
sub_question: str
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class AgentAnswer(SubQuestionIdentifier):
|
||||
answer: str
|
||||
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
|
||||
|
||||
|
||||
class AgentSubQuery(SubQuestionIdentifier):
|
||||
sub_query: str
|
||||
query_id: int
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level_and_question_index(
|
||||
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
|
||||
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
|
||||
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
|
||||
|
||||
returns a dict of level to dict[question num to list of query_id's]
|
||||
Ordering is asc for readability.
|
||||
"""
|
||||
# In this function, when we sort int | None, we deliberately push None to the end
|
||||
|
||||
# map entries to the level_question_dict
|
||||
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
|
||||
for k1, obj in original_dict.items():
|
||||
level = k1[0]
|
||||
question = k1[1]
|
||||
|
||||
if level not in level_question_dict:
|
||||
level_question_dict[level] = {}
|
||||
|
||||
if question not in level_question_dict[level]:
|
||||
level_question_dict[level][question] = []
|
||||
|
||||
level_question_dict[level][question].append(obj)
|
||||
|
||||
# sort each query_id list and question_index
|
||||
for key1, obj1 in level_question_dict.items():
|
||||
for key2, value2 in obj1.items():
|
||||
# sort the query_id list of each question_index
|
||||
level_question_dict[key1][key2] = sorted(
|
||||
value2, key=lambda o: o.query_id
|
||||
)
|
||||
# sort the question_index dict of level
|
||||
level_question_dict[key1] = OrderedDict(
|
||||
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
|
||||
# sort the top dict of levels
|
||||
sorted_dict = OrderedDict(
|
||||
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
return sorted_dict
|
||||
|
||||
|
||||
class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
@ -107,6 +168,12 @@ class ChatBasicResponse(BaseModel):
|
||||
simple_search_docs: list[SimpleDoc] | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
|
||||
# agentic fields
|
||||
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
|
||||
agent_answers: dict[int, list[AgentAnswer]] | None = None
|
||||
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
|
||||
agent_refined_answer_improvement: bool | None = None
|
||||
|
||||
|
||||
class OneShotQARequest(ChunkContext):
|
||||
# Supports simplier APIs that don't deal with chat histories or message edits
|
||||
|
@ -895,7 +895,7 @@ async def current_limited_user(
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_chat_accesssible_user(
|
||||
async def current_chat_accessible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
@ -1,10 +1,13 @@
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
@ -44,9 +47,44 @@ class LlmDoc(BaseModel):
|
||||
|
||||
|
||||
class SubQuestionIdentifier(BaseModel):
|
||||
"""None represents references to objects in the original flow. To our understanding,
|
||||
these will not be None in the packets returned from agent search.
|
||||
"""
|
||||
|
||||
level: int | None = None
|
||||
level_question_num: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level(
|
||||
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"]
|
||||
) -> dict[int, list["SubQuestionIdentifier"]]:
|
||||
"""returns a dict of level to object list (sorted by level_question_num)
|
||||
Ordering is asc for readability.
|
||||
"""
|
||||
|
||||
# organize by level, then sort ascending by question_index
|
||||
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
|
||||
|
||||
# group by level
|
||||
for k, obj in original_dict.items():
|
||||
level = k[0]
|
||||
if level not in level_dict:
|
||||
level_dict[level] = []
|
||||
level_dict[level].append(obj)
|
||||
|
||||
# for each level, sort the group
|
||||
for k2, value2 in level_dict.items():
|
||||
# we need to handle the none case due to SubQuestionIdentifier typing
|
||||
# level_question_num as int | None, even though it should never be None here.
|
||||
level_dict[k2] = sorted(
|
||||
value2,
|
||||
key=lambda x: (x.level_question_num is None, x.level_question_num),
|
||||
)
|
||||
|
||||
# sort by level
|
||||
sorted_dict = OrderedDict(sorted(level_dict.items()))
|
||||
return sorted_dict
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
|
||||
@ -336,6 +374,8 @@ class AgentAnswerPiece(SubQuestionIdentifier):
|
||||
|
||||
|
||||
class SubQuestionPiece(SubQuestionIdentifier):
|
||||
"""Refined sub questions generated from the initial user question."""
|
||||
|
||||
sub_question: str
|
||||
|
||||
|
||||
@ -347,13 +387,13 @@ class RefinedAnswerImprovement(BaseModel):
|
||||
refined_answer_improvement: bool
|
||||
|
||||
|
||||
AgentSearchPacket = (
|
||||
AgentSearchPacket = Union[
|
||||
SubQuestionPiece
|
||||
| AgentAnswerPiece
|
||||
| SubQueryPiece
|
||||
| ExtendedToolResponse
|
||||
| RefinedAnswerImprovement
|
||||
)
|
||||
]
|
||||
|
||||
AnswerPacket = (
|
||||
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse
|
||||
|
@ -234,6 +234,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
yield
|
||||
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await close_auth_limiter()
|
||||
|
||||
|
@ -5,7 +5,7 @@ from fastapi.dependencies.models import Dependant
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
@ -112,7 +112,7 @@ def check_router_auth(
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == api_key_dep
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accesssible_user
|
||||
or depends_fn == current_chat_accessible_user
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
):
|
||||
|
@ -17,7 +17,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.versioned_apps.primary import app as primary_app
|
||||
@ -1247,7 +1247,7 @@ class BasicCCPairInfo(BaseModel):
|
||||
|
||||
@router.get("/connector-status")
|
||||
def get_basic_connector_indexing_status(
|
||||
user: User = Depends(current_chat_accesssible_user),
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BasicCCPairInfo]:
|
||||
cc_pairs = get_connector_credential_pairs_for_user(
|
||||
|
@ -11,7 +11,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
@ -390,7 +390,7 @@ def get_image_generation_tool(
|
||||
|
||||
@basic_router.get("")
|
||||
def list_personas(
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
include_deleted: bool = False,
|
||||
persona_ids: list[int] = Query(None),
|
||||
|
@ -7,7 +7,7 @@ from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_llm_providers_for_user
|
||||
@ -191,7 +191,7 @@ def set_provider_as_default(
|
||||
|
||||
@basic_router.get("/provider")
|
||||
def list_llm_provider_basics(
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
return [
|
||||
|
@ -20,7 +20,7 @@ from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import extract_headers
|
||||
@ -190,7 +190,7 @@ def update_chat_session_model(
|
||||
def get_chat_session(
|
||||
session_id: UUID,
|
||||
is_shared: bool = False,
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionDetailResponse:
|
||||
user_id = user.id if user is not None else None
|
||||
@ -246,7 +246,7 @@ def get_chat_session(
|
||||
@router.post("/create-chat-session")
|
||||
def create_new_chat_session(
|
||||
chat_session_creation_request: ChatSessionCreationRequest,
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateChatSessionID:
|
||||
user_id = user.id if user is not None else None
|
||||
@ -381,7 +381,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
|
||||
def handle_new_chat_message(
|
||||
chat_message_req: CreateChatMessageRequest,
|
||||
request: Request,
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
_rate_limit_check: None = Depends(check_token_rate_limits),
|
||||
is_connected_func: Callable[[], bool] = Depends(is_connected),
|
||||
) -> StreamingResponse:
|
||||
@ -473,7 +473,7 @@ def set_message_as_latest(
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
user_id = user.id if user else None
|
||||
|
@ -11,7 +11,7 @@ from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
@ -29,7 +29,7 @@ TOKEN_BUDGET_UNIT = 1_000
|
||||
|
||||
|
||||
def check_token_rate_limits(
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
) -> None:
|
||||
# short circuit if no rate limits are set up
|
||||
# NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time
|
||||
|
@ -1,3 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
@ -17,3 +22,58 @@ def test_send_message_simple_with_history(reset: None) -> None:
|
||||
)
|
||||
|
||||
assert len(response.full_message) > 0
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="enable for autorun when we have a testing environment with semantically useful data"
|
||||
)
|
||||
def test_send_message_simple_with_history_buffered() -> None:
|
||||
import requests
|
||||
|
||||
API_KEY = "" # fill in for this to work
|
||||
headers = {}
|
||||
headers["Authorization"] = f"Bearer {API_KEY}"
|
||||
|
||||
req: dict[str, Any] = {}
|
||||
|
||||
req["persona_id"] = 0
|
||||
req["description"] = "test_send_message_simple_with_history_buffered"
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/create-chat-session", headers=headers, json=req
|
||||
)
|
||||
chat_session_id = response.json()["chat_session_id"]
|
||||
|
||||
req = {}
|
||||
req["chat_session_id"] = chat_session_id
|
||||
req["message"] = "What does onyx do?"
|
||||
req["use_agentic_search"] = True
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/chat/send-message-simple-api", headers=headers, json=req
|
||||
)
|
||||
|
||||
r_json = response.json()
|
||||
|
||||
# all of these should exist and be greater than length 1
|
||||
assert len(r_json.get("answer", "")) > 0
|
||||
assert len(r_json.get("agent_sub_questions", "")) > 0
|
||||
assert len(r_json.get("agent_answers")) > 0
|
||||
assert len(r_json.get("agent_sub_queries")) > 0
|
||||
assert "agent_refined_answer_improvement" in r_json
|
||||
|
||||
# top level answer should match the one we select out of agent_answers
|
||||
answer_level = 0
|
||||
agent_level_answer = ""
|
||||
|
||||
agent_refined_answer_improvement = r_json.get("agent_refined_answer_improvement")
|
||||
if agent_refined_answer_improvement:
|
||||
answer_level = len(r_json["agent_answers"]) - 1
|
||||
|
||||
answers = r_json["agent_answers"][str(answer_level)]
|
||||
for answer in answers:
|
||||
if answer["answer_type"] == "agent_level_answer":
|
||||
agent_level_answer = answer["answer"]
|
||||
break
|
||||
|
||||
assert r_json["answer"] == agent_level_answer
|
||||
assert response.status_code == 200
|
||||
|
Loading…
x
Reference in New Issue
Block a user